1 /*
  2  * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package compiler.vectorapi.reshape.utils;
 25 
 26 import compiler.lib.ir_framework.ForceInline;
 27 import compiler.lib.ir_framework.IRNode;
 28 import compiler.lib.ir_framework.TestFramework;
 29 import java.lang.invoke.MethodHandles;
 30 import java.lang.invoke.MethodType;
 31 import java.lang.invoke.VarHandle;
 32 import java.lang.reflect.Array;
 33 import java.nio.ByteOrder;
 34 import java.util.List;
 35 import java.util.random.RandomGenerator;
 36 import java.util.stream.Collectors;
 37 import java.util.stream.Stream;
 38 import jdk.incubator.foreign.MemorySegment;
 39 import jdk.incubator.vector.*;
 40 import jdk.test.lib.Asserts;
 41 import jdk.test.lib.Utils;
 42 
 43 public class VectorReshapeHelper {
 44     public static final int INVOCATIONS = 10_000;
 45 
 46     public static final VectorSpecies<Byte>    BSPEC64  =   ByteVector.SPECIES_64;
 47     public static final VectorSpecies<Short>   SSPEC64  =  ShortVector.SPECIES_64;
 48     public static final VectorSpecies<Integer> ISPEC64  =    IntVector.SPECIES_64;
 49     public static final VectorSpecies<Long>    LSPEC64  =   LongVector.SPECIES_64;
 50     public static final VectorSpecies<Float>   FSPEC64  =  FloatVector.SPECIES_64;
 51     public static final VectorSpecies<Double>  DSPEC64  = DoubleVector.SPECIES_64;
 52 
 53     public static final VectorSpecies<Byte>    BSPEC128 =   ByteVector.SPECIES_128;
 54     public static final VectorSpecies<Short>   SSPEC128 =  ShortVector.SPECIES_128;
 55     public static final VectorSpecies<Integer> ISPEC128 =    IntVector.SPECIES_128;
 56     public static final VectorSpecies<Long>    LSPEC128 =   LongVector.SPECIES_128;
 57     public static final VectorSpecies<Float>   FSPEC128 =  FloatVector.SPECIES_128;
 58     public static final VectorSpecies<Double>  DSPEC128 = DoubleVector.SPECIES_128;
 59 
 60     public static final VectorSpecies<Byte>    BSPEC256 =   ByteVector.SPECIES_256;
 61     public static final VectorSpecies<Short>   SSPEC256 =  ShortVector.SPECIES_256;
 62     public static final VectorSpecies<Integer> ISPEC256 =    IntVector.SPECIES_256;
 63     public static final VectorSpecies<Long>    LSPEC256 =   LongVector.SPECIES_256;
 64     public static final VectorSpecies<Float>   FSPEC256 =  FloatVector.SPECIES_256;
 65     public static final VectorSpecies<Double>  DSPEC256 = DoubleVector.SPECIES_256;
 66 
 67     public static final VectorSpecies<Byte>    BSPEC512 =   ByteVector.SPECIES_512;
 68     public static final VectorSpecies<Short>   SSPEC512 =  ShortVector.SPECIES_512;
 69     public static final VectorSpecies<Integer> ISPEC512 =    IntVector.SPECIES_512;
 70     public static final VectorSpecies<Long>    LSPEC512 =   LongVector.SPECIES_512;
 71     public static final VectorSpecies<Float>   FSPEC512 =  FloatVector.SPECIES_512;
 72     public static final VectorSpecies<Double>  DSPEC512 = DoubleVector.SPECIES_512;
 73 
 74     public static final String B2X_NODE  = IRNode.VECTOR_CAST_B2X;
 75     public static final String S2X_NODE  = IRNode.VECTOR_CAST_S2X;
 76     public static final String I2X_NODE  = IRNode.VECTOR_CAST_I2X;
 77     public static final String L2X_NODE  = IRNode.VECTOR_CAST_L2X;
 78     public static final String F2X_NODE  = IRNode.VECTOR_CAST_F2X;
 79     public static final String D2X_NODE  = IRNode.VECTOR_CAST_D2X;
 80     public static final String UB2X_NODE = IRNode.VECTOR_UCAST_B2X;
 81     public static final String US2X_NODE = IRNode.VECTOR_UCAST_S2X;
 82     public static final String UI2X_NODE = IRNode.VECTOR_UCAST_I2X;
 83     public static final String REINTERPRET_NODE = IRNode.VECTOR_REINTERPRET;
 84 
 85     public static void runMainHelper(Class<?> testClass, Stream<VectorSpeciesPair> testMethods, String... flags) {
 86         var test = new TestFramework(testClass);
 87         test.setDefaultWarmup(1);
 88         test.addHelperClasses(VectorReshapeHelper.class);
 89         test.addFlags("--add-modules=jdk.incubator.vector", "--add-exports=java.base/jdk.internal.misc=ALL-UNNAMED");
 90         test.addFlags(flags);
 91         String testMethodNames = testMethods
 92                 .filter(p -> p.isp().length() <= VectorSpecies.ofLargestShape(p.isp().elementType()).length())
 93                 .filter(p -> p.osp().length() <= VectorSpecies.ofLargestShape(p.osp().elementType()).length())
 94                 .map(VectorSpeciesPair::format)
 95                 .collect(Collectors.joining(","));
 96         test.addFlags("-DTest=" + testMethodNames);
 97         test.start();
 98     }
 99 
100     @ForceInline
101     public static <T, U> void vectorCast(VectorOperators.Conversion<T, U> cop,
102                                          VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
103         var outputVector = readVector(isp, input)
104                 .convertShape(cop, osp, 0);
105         writeVector(osp, outputVector, output);
106     }
107 
108     public static <T, U> void runCastHelper(VectorOperators.Conversion<T, U> castOp,
109                                             VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable {
110         var random = Utils.getRandomInstance();
111         boolean isUnsignedCast = castOp.name().startsWith("ZERO");
112         String testMethodName = VectorSpeciesPair.makePair(isp, osp, isUnsignedCast).format();
113         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
114         var testMethod = MethodHandles.lookup().findStatic(caller,
115                     testMethodName,
116                     MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType()))
117                 .asType(MethodType.methodType(void.class, Object.class, Object.class));
118         Object input = Array.newInstance(isp.elementType(), isp.length());
119         Object output = Array.newInstance(osp.elementType(), osp.length());
120         long ibase = UnsafeUtils.arrayBase(isp.elementType());
121         long obase = UnsafeUtils.arrayBase(osp.elementType());
122         for (int iter = 0; iter < INVOCATIONS; iter++) {
123             // We need to generate arrays with NaN or very large values occasionally
124             boolean normalArray = random.nextBoolean();
125             var abnormalValue = List.of(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, -1e30, 1e30);
126             for (int i = 0; i < isp.length(); i++) {
127                 switch (isp.elementType().getName()) {
128                     case "byte"   -> UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt());
129                     case "short"  -> UnsafeUtils.putShort(input, ibase, i, (short)random.nextInt());
130                     case "int"    -> UnsafeUtils.putInt(input, ibase, i, random.nextInt());
131                     case "long"   -> UnsafeUtils.putLong(input, ibase, i, random.nextLong());
132                     case "float"  -> {
133                         if (normalArray || random.nextBoolean()) {
134                             UnsafeUtils.putFloat(input, ibase, i, random.nextFloat(Byte.MIN_VALUE, Byte.MAX_VALUE));
135                         } else {
136                             UnsafeUtils.putFloat(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())).floatValue());
137                         }
138                     }
139                     case "double" -> {
140                         if (normalArray || random.nextBoolean()) {
141                             UnsafeUtils.putDouble(input, ibase, i, random.nextDouble(Byte.MIN_VALUE, Byte.MAX_VALUE));
142                         } else {
143                             UnsafeUtils.putDouble(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())));
144                         }
145                     }
146                     default -> throw new AssertionError();
147                 }
148             }
149 
150             testMethod.invokeExact(input, output);
151 
152             for (int i = 0; i < osp.length(); i++) {
153                 Number expected, actual;
154                 if (i < isp.length()) {
155                     Number initial = switch (isp.elementType().getName()) {
156                         case "byte"   -> UnsafeUtils.getByte(input, ibase, i);
157                         case "short"  -> UnsafeUtils.getShort(input, ibase, i);
158                         case "int"    -> UnsafeUtils.getInt(input, ibase, i);
159                         case "long"   -> UnsafeUtils.getLong(input, ibase, i);
160                         case "float"  -> UnsafeUtils.getFloat(input, ibase, i);
161                         case "double" -> UnsafeUtils.getDouble(input, ibase, i);
162                         default -> throw new AssertionError();
163                     };
164                     expected = switch (osp.elementType().getName()) {
165                         case "byte" -> initial.byteValue();
166                         case "short" -> {
167                             if (isUnsignedCast) {
168                                 yield (short) (initial.longValue() & ((1L << isp.elementSize()) - 1));
169                             } else {
170                                 yield initial.shortValue();
171                             }
172                         }
173                         case "int" -> {
174                             if (isUnsignedCast) {
175                                 yield (int) (initial.longValue() & ((1L << isp.elementSize()) - 1));
176                             } else {
177                                 yield initial.intValue();
178                             }
179                         }
180                         case "long" -> {
181                             if (isUnsignedCast) {
182                                 yield (long) (initial.longValue() & ((1L << isp.elementSize()) - 1));
183                             } else {
184                                 yield initial.longValue();
185                             }
186                         }
187                         case "float" -> initial.floatValue();
188                         case "double" -> initial.doubleValue();
189                         default -> throw new AssertionError();
190                     };
191                 } else {
192                     expected = switch (osp.elementType().getName()) {
193                         case "byte"   -> (byte)0;
194                         case "short"  -> (short)0;
195                         case "int"    -> (int)0;
196                         case "long"   -> (long)0;
197                         case "float"  -> (float)0;
198                         case "double" -> (double)0;
199                         default -> throw new AssertionError();
200                     };
201                 }
202                 actual = switch (osp.elementType().getName()) {
203                     case "byte"   -> UnsafeUtils.getByte(output, obase, i);
204                     case "short"  -> UnsafeUtils.getShort(output, obase, i);
205                     case "int"    -> UnsafeUtils.getInt(output, obase, i);
206                     case "long"   -> UnsafeUtils.getLong(output, obase, i);
207                     case "float"  -> UnsafeUtils.getFloat(output, obase, i);
208                     case "double" -> UnsafeUtils.getDouble(output, obase, i);
209                     default -> throw new AssertionError();
210                 };
211                 Asserts.assertEquals(expected, actual);
212             }
213         }
214     }
215 
216     @ForceInline
217     public static void vectorExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp,
218                                           MemorySegment input, MemorySegment output) {
219         isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder())
220                 .reinterpretShape(osp, 0)
221                 .intoMemorySegment(output, 0, ByteOrder.nativeOrder());
222     }
223 
224     public static void runExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
225         var random = Utils.getRandomInstance();
226         String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
227         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
228         var testMethod = MethodHandles.lookup().findStatic(caller,
229                 testMethodName,
230                 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class));
231         byte[] input = new byte[isp.vectorByteSize()];
232         byte[] output = new byte[osp.vectorByteSize()];
233         MemorySegment msInput = MemorySegment.ofArray(input);
234         MemorySegment msOutput = MemorySegment.ofArray(output);
235         for (int iter = 0; iter < INVOCATIONS; iter++) {
236             random.nextBytes(input);
237 
238             testMethod.invokeExact(msInput, msOutput);
239 
240             for (int i = 0; i < osp.vectorByteSize(); i++) {
241                 int expected = i < isp.vectorByteSize() ? input[i] : 0;
242                 int actual = output[i];
243                 Asserts.assertEquals(expected, actual);
244             }
245         }
246     }
247 
248     @ForceInline
249     public static void vectorDoubleExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp,
250                                                 MemorySegment input, MemorySegment output) {
251         isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder())
252                 .reinterpretShape(osp, 0)
253                 .reinterpretShape(isp, 0)
254                 .intoMemorySegment(output, 0, ByteOrder.nativeOrder());
255     }
256 
257     public static void runDoubleExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
258         var random = Utils.getRandomInstance();
259         String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
260         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
261         var testMethod = MethodHandles.lookup().findStatic(caller,
262                 testMethodName,
263                 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class));
264         byte[] input = new byte[isp.vectorByteSize()];
265         byte[] output = new byte[isp.vectorByteSize()];
266         MemorySegment msInput = MemorySegment.ofArray(input);
267         MemorySegment msOutput = MemorySegment.ofArray(output);
268         for (int iter = 0; iter < INVOCATIONS; iter++) {
269             random.nextBytes(input);
270 
271             testMethod.invokeExact(msInput, msOutput);
272 
273             for (int i = 0; i < isp.vectorByteSize(); i++) {
274                 int expected = i < osp.vectorByteSize() ? input[i] : 0;
275                 int actual = output[i];
276                 Asserts.assertEquals(expected, actual);
277             }
278         }
279     }
280 
281     @ForceInline
282     public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
283         var outputVector = readVector(isp, input)
284                 .reinterpretShape(osp, 0);
285         writeVector(osp, outputVector, output);
286     }
287 
288     public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable {
289         var random = Utils.getRandomInstance();
290         String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
291         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
292         var testMethod = MethodHandles.lookup().findStatic(caller,
293                     testMethodName,
294                     MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType()))
295                 .asType(MethodType.methodType(void.class, Object.class, Object.class));
296         Object input = Array.newInstance(isp.elementType(), isp.length());
297         Object output = Array.newInstance(osp.elementType(), osp.length());
298         long ibase = UnsafeUtils.arrayBase(isp.elementType());
299         long obase = UnsafeUtils.arrayBase(osp.elementType());
300         for (int iter = 0; iter < INVOCATIONS; iter++) {
301             for (int i = 0; i < isp.vectorByteSize(); i++) {
302                 UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt());
303             }
304 
305             testMethod.invokeExact(input, output);
306 
307             for (int i = 0; i < osp.vectorByteSize(); i++) {
308                 int expected = i < isp.vectorByteSize() ? UnsafeUtils.getByte(input, ibase, i) : 0;
309                 int actual = UnsafeUtils.getByte(output, obase, i);
310                 Asserts.assertEquals(expected, actual);
311             }
312         }
313     }
314 
315     @ForceInline
316     private static <T> Vector<T> readVector(VectorSpecies<T> isp, Object input) {
317         return isp.fromArray(input, 0);
318     }
319 
320     @ForceInline
321     private static <U> void writeVector(VectorSpecies<U> osp, Vector<U> vector, Object output) {
322         var otype = osp.elementType();
323         if (otype == byte.class) {
324             ((ByteVector)vector).intoArray((byte[])output, 0);
325         } else if (otype == short.class) {
326             ((ShortVector)vector).intoArray((short[])output, 0);
327         } else if (otype == int.class) {
328             ((IntVector)vector).intoArray((int[])output, 0);
329         } else if (otype == long.class) {
330             ((LongVector)vector).intoArray((long[])output, 0);
331         } else if (otype == float.class) {
332             ((FloatVector)vector).intoArray((float[])output, 0);
333         } else if (otype == double.class) {
334             ((DoubleVector)vector).intoArray((double[])output, 0);
335         } else {
336             throw new AssertionError();
337         }
338     }
339 }
--- EOF ---