18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package compiler.vectorapi.reshape.utils;
25
26 import compiler.lib.ir_framework.ForceInline;
27 import compiler.lib.ir_framework.IRNode;
28 import compiler.lib.ir_framework.TestFramework;
29 import java.lang.invoke.MethodHandles;
30 import java.lang.invoke.MethodType;
31 import java.lang.invoke.VarHandle;
32 import java.lang.reflect.Array;
33 import java.nio.ByteOrder;
34 import java.util.List;
35 import java.util.random.RandomGenerator;
36 import java.util.stream.Collectors;
37 import java.util.stream.Stream;
38 import jdk.incubator.vector.*;
39 import jdk.test.lib.Asserts;
40 import jdk.test.lib.Utils;
41
42 public class VectorReshapeHelper {
43 public static final int INVOCATIONS = 10_000;
44
45 public static final VectorSpecies<Byte> BSPEC64 = ByteVector.SPECIES_64;
46 public static final VectorSpecies<Short> SSPEC64 = ShortVector.SPECIES_64;
47 public static final VectorSpecies<Integer> ISPEC64 = IntVector.SPECIES_64;
48 public static final VectorSpecies<Long> LSPEC64 = LongVector.SPECIES_64;
49 public static final VectorSpecies<Float> FSPEC64 = FloatVector.SPECIES_64;
50 public static final VectorSpecies<Double> DSPEC64 = DoubleVector.SPECIES_64;
51
52 public static final VectorSpecies<Byte> BSPEC128 = ByteVector.SPECIES_128;
53 public static final VectorSpecies<Short> SSPEC128 = ShortVector.SPECIES_128;
54 public static final VectorSpecies<Integer> ISPEC128 = IntVector.SPECIES_128;
55 public static final VectorSpecies<Long> LSPEC128 = LongVector.SPECIES_128;
56 public static final VectorSpecies<Float> FSPEC128 = FloatVector.SPECIES_128;
57 public static final VectorSpecies<Double> DSPEC128 = DoubleVector.SPECIES_128;
196 case "float" -> (float)0;
197 case "double" -> (double)0;
198 default -> throw new AssertionError();
199 };
200 }
201 actual = switch (osp.elementType().getName()) {
202 case "byte" -> UnsafeUtils.getByte(output, obase, i);
203 case "short" -> UnsafeUtils.getShort(output, obase, i);
204 case "int" -> UnsafeUtils.getInt(output, obase, i);
205 case "long" -> UnsafeUtils.getLong(output, obase, i);
206 case "float" -> UnsafeUtils.getFloat(output, obase, i);
207 case "double" -> UnsafeUtils.getDouble(output, obase, i);
208 default -> throw new AssertionError();
209 };
210 Asserts.assertEquals(expected, actual);
211 }
212 }
213 }
214
215 @ForceInline
216 public static void vectorExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, byte[] input, byte[] output) {
217 isp.fromByteArray(input, 0, ByteOrder.nativeOrder())
218 .reinterpretShape(osp, 0)
219 .intoByteArray(output, 0, ByteOrder.nativeOrder());
220 }
221
222 public static void runExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
223 var random = Utils.getRandomInstance();
224 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
225 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
226 var testMethod = MethodHandles.lookup().findStatic(caller,
227 testMethodName,
228 MethodType.methodType(void.class, byte.class.arrayType(), byte.class.arrayType()));
229 byte[] input = new byte[isp.vectorByteSize()];
230 byte[] output = new byte[osp.vectorByteSize()];
231 for (int iter = 0; iter < INVOCATIONS; iter++) {
232 random.nextBytes(input);
233
234 testMethod.invokeExact(input, output);
235
236 for (int i = 0; i < osp.vectorByteSize(); i++) {
237 int expected = i < isp.vectorByteSize() ? input[i] : 0;
238 int actual = output[i];
239 Asserts.assertEquals(expected, actual);
240 }
241 }
242 }
243
244 @ForceInline
245 public static void vectorDoubleExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, byte[] input, byte[] output) {
246 isp.fromByteArray(input, 0, ByteOrder.nativeOrder())
247 .reinterpretShape(osp, 0)
248 .reinterpretShape(isp, 0)
249 .intoByteArray(output, 0, ByteOrder.nativeOrder());
250 }
251
252 public static void runDoubleExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
253 var random = Utils.getRandomInstance();
254 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
255 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
256 var testMethod = MethodHandles.lookup().findStatic(caller,
257 testMethodName,
258 MethodType.methodType(void.class, byte.class.arrayType(), byte.class.arrayType()));
259 byte[] input = new byte[isp.vectorByteSize()];
260 byte[] output = new byte[isp.vectorByteSize()];
261 for (int iter = 0; iter < INVOCATIONS; iter++) {
262 random.nextBytes(input);
263
264 testMethod.invokeExact(input, output);
265
266 for (int i = 0; i < isp.vectorByteSize(); i++) {
267 int expected = i < osp.vectorByteSize() ? input[i] : 0;
268 int actual = output[i];
269 Asserts.assertEquals(expected, actual);
270 }
271 }
272 }
273
274 @ForceInline
275 public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
276 var outputVector = readVector(isp, input)
277 .reinterpretShape(osp, 0);
278 writeVector(osp, outputVector, output);
279 }
280
281 public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable {
282 var random = Utils.getRandomInstance();
283 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
284 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
|
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package compiler.vectorapi.reshape.utils;
25
26 import compiler.lib.ir_framework.ForceInline;
27 import compiler.lib.ir_framework.IRNode;
28 import compiler.lib.ir_framework.TestFramework;
29 import java.lang.invoke.MethodHandles;
30 import java.lang.invoke.MethodType;
31 import java.lang.invoke.VarHandle;
32 import java.lang.reflect.Array;
33 import java.nio.ByteOrder;
34 import java.util.List;
35 import java.util.random.RandomGenerator;
36 import java.util.stream.Collectors;
37 import java.util.stream.Stream;
38 import jdk.incubator.foreign.MemorySegment;
39 import jdk.incubator.vector.*;
40 import jdk.test.lib.Asserts;
41 import jdk.test.lib.Utils;
42
43 public class VectorReshapeHelper {
44 public static final int INVOCATIONS = 10_000;
45
46 public static final VectorSpecies<Byte> BSPEC64 = ByteVector.SPECIES_64;
47 public static final VectorSpecies<Short> SSPEC64 = ShortVector.SPECIES_64;
48 public static final VectorSpecies<Integer> ISPEC64 = IntVector.SPECIES_64;
49 public static final VectorSpecies<Long> LSPEC64 = LongVector.SPECIES_64;
50 public static final VectorSpecies<Float> FSPEC64 = FloatVector.SPECIES_64;
51 public static final VectorSpecies<Double> DSPEC64 = DoubleVector.SPECIES_64;
52
53 public static final VectorSpecies<Byte> BSPEC128 = ByteVector.SPECIES_128;
54 public static final VectorSpecies<Short> SSPEC128 = ShortVector.SPECIES_128;
55 public static final VectorSpecies<Integer> ISPEC128 = IntVector.SPECIES_128;
56 public static final VectorSpecies<Long> LSPEC128 = LongVector.SPECIES_128;
57 public static final VectorSpecies<Float> FSPEC128 = FloatVector.SPECIES_128;
58 public static final VectorSpecies<Double> DSPEC128 = DoubleVector.SPECIES_128;
197 case "float" -> (float)0;
198 case "double" -> (double)0;
199 default -> throw new AssertionError();
200 };
201 }
202 actual = switch (osp.elementType().getName()) {
203 case "byte" -> UnsafeUtils.getByte(output, obase, i);
204 case "short" -> UnsafeUtils.getShort(output, obase, i);
205 case "int" -> UnsafeUtils.getInt(output, obase, i);
206 case "long" -> UnsafeUtils.getLong(output, obase, i);
207 case "float" -> UnsafeUtils.getFloat(output, obase, i);
208 case "double" -> UnsafeUtils.getDouble(output, obase, i);
209 default -> throw new AssertionError();
210 };
211 Asserts.assertEquals(expected, actual);
212 }
213 }
214 }
215
216 @ForceInline
217 public static void vectorExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp,
218 MemorySegment input, MemorySegment output) {
219 isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder())
220 .reinterpretShape(osp, 0)
221 .intoMemorySegment(output, 0, ByteOrder.nativeOrder());
222 }
223
224 public static void runExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
225 var random = Utils.getRandomInstance();
226 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
227 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
228 var testMethod = MethodHandles.lookup().findStatic(caller,
229 testMethodName,
230 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class));
231 byte[] input = new byte[isp.vectorByteSize()];
232 byte[] output = new byte[osp.vectorByteSize()];
233 MemorySegment msInput = MemorySegment.ofArray(input);
234 MemorySegment msOutput = MemorySegment.ofArray(output);
235 for (int iter = 0; iter < INVOCATIONS; iter++) {
236 random.nextBytes(input);
237
238 testMethod.invokeExact(msInput, msOutput);
239
240 for (int i = 0; i < osp.vectorByteSize(); i++) {
241 int expected = i < isp.vectorByteSize() ? input[i] : 0;
242 int actual = output[i];
243 Asserts.assertEquals(expected, actual);
244 }
245 }
246 }
247
248 @ForceInline
249 public static void vectorDoubleExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp,
250 MemorySegment input, MemorySegment output) {
251 isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder())
252 .reinterpretShape(osp, 0)
253 .reinterpretShape(isp, 0)
254 .intoMemorySegment(output, 0, ByteOrder.nativeOrder());
255 }
256
257 public static void runDoubleExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
258 var random = Utils.getRandomInstance();
259 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
260 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
261 var testMethod = MethodHandles.lookup().findStatic(caller,
262 testMethodName,
263 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class));
264 byte[] input = new byte[isp.vectorByteSize()];
265 byte[] output = new byte[isp.vectorByteSize()];
266 MemorySegment msInput = MemorySegment.ofArray(input);
267 MemorySegment msOutput = MemorySegment.ofArray(output);
268 for (int iter = 0; iter < INVOCATIONS; iter++) {
269 random.nextBytes(input);
270
271 testMethod.invokeExact(msInput, msOutput);
272
273 for (int i = 0; i < isp.vectorByteSize(); i++) {
274 int expected = i < osp.vectorByteSize() ? input[i] : 0;
275 int actual = output[i];
276 Asserts.assertEquals(expected, actual);
277 }
278 }
279 }
280
281 @ForceInline
282 public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
283 var outputVector = readVector(isp, input)
284 .reinterpretShape(osp, 0);
285 writeVector(osp, outputVector, output);
286 }
287
288 public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable {
289 var random = Utils.getRandomInstance();
290 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
291 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
|