1 /*
  2  * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package compiler.vectorapi.reshape.utils;
 25 
 26 import compiler.lib.ir_framework.ForceInline;
 27 import compiler.lib.ir_framework.IRNode;
 28 import compiler.lib.ir_framework.TestFramework;
 29 import java.lang.invoke.MethodHandles;
 30 import java.lang.invoke.MethodType;
 31 import java.lang.invoke.VarHandle;
 32 import java.lang.reflect.Array;
 33 import java.nio.ByteOrder;
 34 import java.util.List;
 35 import java.util.random.RandomGenerator;
 36 import java.util.stream.Collectors;
 37 import java.util.stream.Stream;
 38 import jdk.incubator.vector.*;
 39 import jdk.test.lib.Asserts;
 40 import jdk.test.lib.Utils;
 41 
 42 public class VectorReshapeHelper {
 43     public static final int INVOCATIONS = 10_000;
 44 
 45     public static final VectorSpecies<Byte>    BSPEC64  =   ByteVector.SPECIES_64;
 46     public static final VectorSpecies<Short>   SSPEC64  =  ShortVector.SPECIES_64;
 47     public static final VectorSpecies<Integer> ISPEC64  =    IntVector.SPECIES_64;
 48     public static final VectorSpecies<Long>    LSPEC64  =   LongVector.SPECIES_64;
 49     public static final VectorSpecies<Float>   FSPEC64  =  FloatVector.SPECIES_64;
 50     public static final VectorSpecies<Double>  DSPEC64  = DoubleVector.SPECIES_64;
 51 
 52     public static final VectorSpecies<Byte>    BSPEC128 =   ByteVector.SPECIES_128;
 53     public static final VectorSpecies<Short>   SSPEC128 =  ShortVector.SPECIES_128;
 54     public static final VectorSpecies<Integer> ISPEC128 =    IntVector.SPECIES_128;
 55     public static final VectorSpecies<Long>    LSPEC128 =   LongVector.SPECIES_128;
 56     public static final VectorSpecies<Float>   FSPEC128 =  FloatVector.SPECIES_128;
 57     public static final VectorSpecies<Double>  DSPEC128 = DoubleVector.SPECIES_128;
 58 
 59     public static final VectorSpecies<Byte>    BSPEC256 =   ByteVector.SPECIES_256;
 60     public static final VectorSpecies<Short>   SSPEC256 =  ShortVector.SPECIES_256;
 61     public static final VectorSpecies<Integer> ISPEC256 =    IntVector.SPECIES_256;
 62     public static final VectorSpecies<Long>    LSPEC256 =   LongVector.SPECIES_256;
 63     public static final VectorSpecies<Float>   FSPEC256 =  FloatVector.SPECIES_256;
 64     public static final VectorSpecies<Double>  DSPEC256 = DoubleVector.SPECIES_256;
 65 
 66     public static final VectorSpecies<Byte>    BSPEC512 =   ByteVector.SPECIES_512;
 67     public static final VectorSpecies<Short>   SSPEC512 =  ShortVector.SPECIES_512;
 68     public static final VectorSpecies<Integer> ISPEC512 =    IntVector.SPECIES_512;
 69     public static final VectorSpecies<Long>    LSPEC512 =   LongVector.SPECIES_512;
 70     public static final VectorSpecies<Float>   FSPEC512 =  FloatVector.SPECIES_512;
 71     public static final VectorSpecies<Double>  DSPEC512 = DoubleVector.SPECIES_512;
 72 
 73     public static final String B2X_NODE  = IRNode.VECTOR_CAST_B2X;
 74     public static final String S2X_NODE  = IRNode.VECTOR_CAST_S2X;
 75     public static final String I2X_NODE  = IRNode.VECTOR_CAST_I2X;
 76     public static final String L2X_NODE  = IRNode.VECTOR_CAST_L2X;
 77     public static final String F2X_NODE  = IRNode.VECTOR_CAST_F2X;
 78     public static final String D2X_NODE  = IRNode.VECTOR_CAST_D2X;
 79     public static final String UB2X_NODE = IRNode.VECTOR_UCAST_B2X;
 80     public static final String US2X_NODE = IRNode.VECTOR_UCAST_S2X;
 81     public static final String UI2X_NODE = IRNode.VECTOR_UCAST_I2X;
 82     public static final String REINTERPRET_NODE = IRNode.VECTOR_REINTERPRET;
 83 
 84     public static void runMainHelper(Class<?> testClass, Stream<VectorSpeciesPair> testMethods, String... flags) {
 85         var test = new TestFramework(testClass);
 86         test.setDefaultWarmup(1);
 87         test.addHelperClasses(VectorReshapeHelper.class);
 88         test.addFlags("--add-modules=jdk.incubator.vector", "--add-exports=java.base/jdk.internal.misc=ALL-UNNAMED");
 89         test.addFlags(flags);
 90         String testMethodNames = testMethods
 91                 .filter(p -> p.isp().length() <= VectorSpecies.ofLargestShape(p.isp().elementType()).length())
 92                 .filter(p -> p.osp().length() <= VectorSpecies.ofLargestShape(p.osp().elementType()).length())
 93                 .map(VectorSpeciesPair::format)
 94                 .collect(Collectors.joining(","));
 95         test.addFlags("-DTest=" + testMethodNames);
 96         test.start();
 97     }
 98 
 99     @ForceInline
100     public static <T, U> void vectorCast(VectorOperators.Conversion<T, U> cop,
101                                          VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
102         var outputVector = readVector(isp, input)
103                 .convertShape(cop, osp, 0);
104         writeVector(osp, outputVector, output);
105     }
106 
107     public static <T, U> void runCastHelper(VectorOperators.Conversion<T, U> castOp,
108                                             VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable {
109         var random = Utils.getRandomInstance();
110         boolean isUnsignedCast = castOp.name().startsWith("ZERO");
111         String testMethodName = VectorSpeciesPair.makePair(isp, osp, isUnsignedCast).format();
112         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
113         var testMethod = MethodHandles.lookup().findStatic(caller,
114                     testMethodName,
115                     MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType()))
116                 .asType(MethodType.methodType(void.class, Object.class, Object.class));
117         Object input = Array.newInstance(isp.elementType(), isp.length());
118         Object output = Array.newInstance(osp.elementType(), osp.length());
119         long ibase = UnsafeUtils.arrayBase(isp.elementType());
120         long obase = UnsafeUtils.arrayBase(osp.elementType());
121         for (int iter = 0; iter < INVOCATIONS; iter++) {
122             // We need to generate arrays with NaN or very large values occasionally
123             boolean normalArray = random.nextBoolean();
124             var abnormalValue = List.of(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, -1e30, 1e30);
125             for (int i = 0; i < isp.length(); i++) {
126                 switch (isp.elementType().getName()) {
127                     case "byte"   -> UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt());
128                     case "short"  -> UnsafeUtils.putShort(input, ibase, i, (short)random.nextInt());
129                     case "int"    -> UnsafeUtils.putInt(input, ibase, i, random.nextInt());
130                     case "long"   -> UnsafeUtils.putLong(input, ibase, i, random.nextLong());
131                     case "float"  -> {
132                         if (normalArray || random.nextBoolean()) {
133                             UnsafeUtils.putFloat(input, ibase, i, random.nextFloat(Byte.MIN_VALUE, Byte.MAX_VALUE));
134                         } else {
135                             UnsafeUtils.putFloat(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())).floatValue());
136                         }
137                     }
138                     case "double" -> {
139                         if (normalArray || random.nextBoolean()) {
140                             UnsafeUtils.putDouble(input, ibase, i, random.nextDouble(Byte.MIN_VALUE, Byte.MAX_VALUE));
141                         } else {
142                             UnsafeUtils.putDouble(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())));
143                         }
144                     }
145                     default -> throw new AssertionError();
146                 }
147             }
148 
149             testMethod.invokeExact(input, output);
150 
151             for (int i = 0; i < osp.length(); i++) {
152                 Number expected, actual;
153                 if (i < isp.length()) {
154                     Number initial = switch (isp.elementType().getName()) {
155                         case "byte"   -> UnsafeUtils.getByte(input, ibase, i);
156                         case "short"  -> UnsafeUtils.getShort(input, ibase, i);
157                         case "int"    -> UnsafeUtils.getInt(input, ibase, i);
158                         case "long"   -> UnsafeUtils.getLong(input, ibase, i);
159                         case "float"  -> UnsafeUtils.getFloat(input, ibase, i);
160                         case "double" -> UnsafeUtils.getDouble(input, ibase, i);
161                         default -> throw new AssertionError();
162                     };
163                     expected = switch (osp.elementType().getName()) {
164                         case "byte" -> initial.byteValue();
165                         case "short" -> {
166                             if (isUnsignedCast) {
167                                 yield (short) (initial.longValue() & ((1L << isp.elementSize()) - 1));
168                             } else {
169                                 yield initial.shortValue();
170                             }
171                         }
172                         case "int" -> {
173                             if (isUnsignedCast) {
174                                 yield (int) (initial.longValue() & ((1L << isp.elementSize()) - 1));
175                             } else {
176                                 yield initial.intValue();
177                             }
178                         }
179                         case "long" -> {
180                             if (isUnsignedCast) {
181                                 yield (long) (initial.longValue() & ((1L << isp.elementSize()) - 1));
182                             } else {
183                                 yield initial.longValue();
184                             }
185                         }
186                         case "float" -> initial.floatValue();
187                         case "double" -> initial.doubleValue();
188                         default -> throw new AssertionError();
189                     };
190                 } else {
191                     expected = switch (osp.elementType().getName()) {
192                         case "byte"   -> (byte)0;
193                         case "short"  -> (short)0;
194                         case "int"    -> (int)0;
195                         case "long"   -> (long)0;
196                         case "float"  -> (float)0;
197                         case "double" -> (double)0;
198                         default -> throw new AssertionError();
199                     };
200                 }
201                 actual = switch (osp.elementType().getName()) {
202                     case "byte"   -> UnsafeUtils.getByte(output, obase, i);
203                     case "short"  -> UnsafeUtils.getShort(output, obase, i);
204                     case "int"    -> UnsafeUtils.getInt(output, obase, i);
205                     case "long"   -> UnsafeUtils.getLong(output, obase, i);
206                     case "float"  -> UnsafeUtils.getFloat(output, obase, i);
207                     case "double" -> UnsafeUtils.getDouble(output, obase, i);
208                     default -> throw new AssertionError();
209                 };
210                 Asserts.assertEquals(expected, actual);
211             }
212         }
213     }
214 
215     @ForceInline
216     public static void vectorExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, byte[] input, byte[] output) {
217         isp.fromByteArray(input, 0, ByteOrder.nativeOrder())
218                 .reinterpretShape(osp, 0)
219                 .intoByteArray(output, 0, ByteOrder.nativeOrder());
220     }
221 
222     public static void runExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
223         var random = Utils.getRandomInstance();
224         String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
225         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
226         var testMethod = MethodHandles.lookup().findStatic(caller,
227                 testMethodName,
228                 MethodType.methodType(void.class, byte.class.arrayType(), byte.class.arrayType()));
229         byte[] input = new byte[isp.vectorByteSize()];
230         byte[] output = new byte[osp.vectorByteSize()];
231         for (int iter = 0; iter < INVOCATIONS; iter++) {
232             random.nextBytes(input);
233 
234             testMethod.invokeExact(input, output);
235 
236             for (int i = 0; i < osp.vectorByteSize(); i++) {
237                 int expected = i < isp.vectorByteSize() ? input[i] : 0;
238                 int actual = output[i];
239                 Asserts.assertEquals(expected, actual);
240             }
241         }
242     }
243 
244     @ForceInline
245     public static void vectorDoubleExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, byte[] input, byte[] output) {
246         isp.fromByteArray(input, 0, ByteOrder.nativeOrder())
247                 .reinterpretShape(osp, 0)
248                 .reinterpretShape(isp, 0)
249                 .intoByteArray(output, 0, ByteOrder.nativeOrder());
250     }
251 
252     public static void runDoubleExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
253         var random = Utils.getRandomInstance();
254         String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
255         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
256         var testMethod = MethodHandles.lookup().findStatic(caller,
257                 testMethodName,
258                 MethodType.methodType(void.class, byte.class.arrayType(), byte.class.arrayType()));
259         byte[] input = new byte[isp.vectorByteSize()];
260         byte[] output = new byte[isp.vectorByteSize()];
261         for (int iter = 0; iter < INVOCATIONS; iter++) {
262             random.nextBytes(input);
263 
264             testMethod.invokeExact(input, output);
265 
266             for (int i = 0; i < isp.vectorByteSize(); i++) {
267                 int expected = i < osp.vectorByteSize() ? input[i] : 0;
268                 int actual = output[i];
269                 Asserts.assertEquals(expected, actual);
270             }
271         }
272     }
273 
274     @ForceInline
275     public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
276         var outputVector = readVector(isp, input)
277                 .reinterpretShape(osp, 0);
278         writeVector(osp, outputVector, output);
279     }
280 
281     public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable {
282         var random = Utils.getRandomInstance();
283         String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
284         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
285         var testMethod = MethodHandles.lookup().findStatic(caller,
286                     testMethodName,
287                     MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType()))
288                 .asType(MethodType.methodType(void.class, Object.class, Object.class));
289         Object input = Array.newInstance(isp.elementType(), isp.length());
290         Object output = Array.newInstance(osp.elementType(), osp.length());
291         long ibase = UnsafeUtils.arrayBase(isp.elementType());
292         long obase = UnsafeUtils.arrayBase(osp.elementType());
293         for (int iter = 0; iter < INVOCATIONS; iter++) {
294             for (int i = 0; i < isp.vectorByteSize(); i++) {
295                 UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt());
296             }
297 
298             testMethod.invokeExact(input, output);
299 
300             for (int i = 0; i < osp.vectorByteSize(); i++) {
301                 int expected = i < isp.vectorByteSize() ? UnsafeUtils.getByte(input, ibase, i) : 0;
302                 int actual = UnsafeUtils.getByte(output, obase, i);
303                 Asserts.assertEquals(expected, actual);
304             }
305         }
306     }
307 
308     @ForceInline
309     private static <T> Vector<T> readVector(VectorSpecies<T> isp, Object input) {
310         return isp.fromArray(input, 0);
311     }
312 
313     @ForceInline
314     private static <U> void writeVector(VectorSpecies<U> osp, Vector<U> vector, Object output) {
315         var otype = osp.elementType();
316         if (otype == byte.class) {
317             ((ByteVector)vector).intoArray((byte[])output, 0);
318         } else if (otype == short.class) {
319             ((ShortVector)vector).intoArray((short[])output, 0);
320         } else if (otype == int.class) {
321             ((IntVector)vector).intoArray((int[])output, 0);
322         } else if (otype == long.class) {
323             ((LongVector)vector).intoArray((long[])output, 0);
324         } else if (otype == float.class) {
325             ((FloatVector)vector).intoArray((float[])output, 0);
326         } else if (otype == double.class) {
327             ((DoubleVector)vector).intoArray((double[])output, 0);
328         } else {
329             throw new AssertionError();
330         }
331     }
332 }