1 /*
  2  * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package compiler.vectorapi.reshape.utils;
 25 
 26 import compiler.lib.ir_framework.ForceInline;
 27 import compiler.lib.ir_framework.IRNode;
 28 import compiler.lib.ir_framework.TestFramework;
 29 import java.lang.invoke.MethodHandles;
 30 import java.lang.invoke.MethodType;
 31 import java.lang.invoke.VarHandle;
 32 import java.lang.reflect.Array;
 33 import java.nio.ByteOrder;
 34 import java.util.List;
 35 import java.util.random.RandomGenerator;
 36 import java.util.stream.Collectors;
 37 import java.util.stream.Stream;
 38 import java.lang.foreign.MemorySegment;
 39 import jdk.incubator.vector.*;
 40 import jdk.test.lib.Asserts;
 41 import jdk.test.lib.Utils;
 42 
 43 public class VectorReshapeHelper {
 44     public static final int INVOCATIONS = 10_000;
 45 
 46     public static final VectorSpecies<Byte>    BSPEC64  =   ByteVector.SPECIES_64;
 47     public static final VectorSpecies<Short>   SSPEC64  =  ShortVector.SPECIES_64;
 48     public static final VectorSpecies<Integer> ISPEC64  =    IntVector.SPECIES_64;
 49     public static final VectorSpecies<Long>    LSPEC64  =   LongVector.SPECIES_64;
 50     public static final VectorSpecies<Float>   FSPEC64  =  FloatVector.SPECIES_64;
 51     public static final VectorSpecies<Double>  DSPEC64  = DoubleVector.SPECIES_64;
 52 
 53     public static final VectorSpecies<Byte>    BSPEC128 =   ByteVector.SPECIES_128;
 54     public static final VectorSpecies<Short>   SSPEC128 =  ShortVector.SPECIES_128;
 55     public static final VectorSpecies<Integer> ISPEC128 =    IntVector.SPECIES_128;
 56     public static final VectorSpecies<Long>    LSPEC128 =   LongVector.SPECIES_128;
 57     public static final VectorSpecies<Float>   FSPEC128 =  FloatVector.SPECIES_128;
 58     public static final VectorSpecies<Double>  DSPEC128 = DoubleVector.SPECIES_128;
 59 
 60     public static final VectorSpecies<Byte>    BSPEC256 =   ByteVector.SPECIES_256;
 61     public static final VectorSpecies<Short>   SSPEC256 =  ShortVector.SPECIES_256;
 62     public static final VectorSpecies<Integer> ISPEC256 =    IntVector.SPECIES_256;
 63     public static final VectorSpecies<Long>    LSPEC256 =   LongVector.SPECIES_256;
 64     public static final VectorSpecies<Float>   FSPEC256 =  FloatVector.SPECIES_256;
 65     public static final VectorSpecies<Double>  DSPEC256 = DoubleVector.SPECIES_256;
 66 
 67     public static final VectorSpecies<Byte>    BSPEC512 =   ByteVector.SPECIES_512;
 68     public static final VectorSpecies<Short>   SSPEC512 =  ShortVector.SPECIES_512;
 69     public static final VectorSpecies<Integer> ISPEC512 =    IntVector.SPECIES_512;
 70     public static final VectorSpecies<Long>    LSPEC512 =   LongVector.SPECIES_512;
 71     public static final VectorSpecies<Float>   FSPEC512 =  FloatVector.SPECIES_512;
 72     public static final VectorSpecies<Double>  DSPEC512 = DoubleVector.SPECIES_512;
 73 
 74     public static final String REINTERPRET_NODE = IRNode.VECTOR_REINTERPRET;
 75 
 76     public static void runMainHelper(Class<?> testClass, Stream<VectorSpeciesPair> testMethods, String... flags) {
 77         var test = new TestFramework(testClass);
 78         test.setDefaultWarmup(1);
 79         test.addHelperClasses(VectorReshapeHelper.class);
 80         test.addFlags("--add-modules=jdk.incubator.vector", "--add-exports=java.base/jdk.internal.misc=ALL-UNNAMED", "--enable-preview");
 81         test.addFlags(flags);
 82         String testMethodNames = testMethods
 83                 .filter(p -> p.isp().length() <= VectorSpecies.ofLargestShape(p.isp().elementType()).length())
 84                 .filter(p -> p.osp().length() <= VectorSpecies.ofLargestShape(p.osp().elementType()).length())
 85                 .map(VectorSpeciesPair::format)
 86                 .collect(Collectors.joining(","));
 87         test.addFlags("-DTest=" + testMethodNames);
 88         test.start();
 89     }
 90 
 91     @ForceInline
 92     public static <T, U> void vectorCast(VectorOperators.Conversion<T, U> cop,
 93                                          VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
 94         var outputVector = readVector(isp, input)
 95                 .convertShape(cop, osp, 0);
 96         writeVector(osp, outputVector, output);
 97     }
 98 
 99     public static <T, U> void runCastHelper(VectorOperators.Conversion<T, U> castOp,
100                                             VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable {
101         var random = Utils.getRandomInstance();
102         boolean isUnsignedCast = castOp.name().startsWith("ZERO");
103         String testMethodName = VectorSpeciesPair.makePair(isp, osp, isUnsignedCast).format();
104         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
105         var testMethod = MethodHandles.lookup().findStatic(caller,
106                     testMethodName,
107                     MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType()))
108                 .asType(MethodType.methodType(void.class, Object.class, Object.class));
109         Object input = Array.newInstance(isp.elementType(), isp.length());
110         Object output = Array.newInstance(osp.elementType(), osp.length());
111         long ibase = UnsafeUtils.arrayBase(isp.elementType());
112         long obase = UnsafeUtils.arrayBase(osp.elementType());
113         for (int iter = 0; iter < INVOCATIONS; iter++) {
114             // We need to generate arrays with NaN or very large values occasionally
115             boolean normalArray = random.nextBoolean();
116             var abnormalValue = List.of(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, -1e30, 1e30);
117             for (int i = 0; i < isp.length(); i++) {
118                 switch (isp.elementType().getName()) {
119                     case "byte"   -> UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt());
120                     case "short"  -> UnsafeUtils.putShort(input, ibase, i, (short)random.nextInt());
121                     case "int"    -> UnsafeUtils.putInt(input, ibase, i, random.nextInt());
122                     case "long"   -> UnsafeUtils.putLong(input, ibase, i, random.nextLong());
123                     case "float"  -> {
124                         if (normalArray || random.nextBoolean()) {
125                             UnsafeUtils.putFloat(input, ibase, i, random.nextFloat(Byte.MIN_VALUE, Byte.MAX_VALUE));
126                         } else {
127                             UnsafeUtils.putFloat(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())).floatValue());
128                         }
129                     }
130                     case "double" -> {
131                         if (normalArray || random.nextBoolean()) {
132                             UnsafeUtils.putDouble(input, ibase, i, random.nextDouble(Byte.MIN_VALUE, Byte.MAX_VALUE));
133                         } else {
134                             UnsafeUtils.putDouble(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())));
135                         }
136                     }
137                     default -> throw new AssertionError();
138                 }
139             }
140 
141             testMethod.invokeExact(input, output);
142 
143             for (int i = 0; i < osp.length(); i++) {
144                 Number expected, actual;
145                 if (i < isp.length()) {
146                     Number initial = switch (isp.elementType().getName()) {
147                         case "byte"   -> UnsafeUtils.getByte(input, ibase, i);
148                         case "short"  -> UnsafeUtils.getShort(input, ibase, i);
149                         case "int"    -> UnsafeUtils.getInt(input, ibase, i);
150                         case "long"   -> UnsafeUtils.getLong(input, ibase, i);
151                         case "float"  -> UnsafeUtils.getFloat(input, ibase, i);
152                         case "double" -> UnsafeUtils.getDouble(input, ibase, i);
153                         default -> throw new AssertionError();
154                     };
155                     expected = switch (osp.elementType().getName()) {
156                         case "byte" -> initial.byteValue();
157                         case "short" -> {
158                             if (isUnsignedCast) {
159                                 yield (short) (initial.longValue() & ((1L << isp.elementSize()) - 1));
160                             } else {
161                                 yield initial.shortValue();
162                             }
163                         }
164                         case "int" -> {
165                             if (isUnsignedCast) {
166                                 yield (int) (initial.longValue() & ((1L << isp.elementSize()) - 1));
167                             } else {
168                                 yield initial.intValue();
169                             }
170                         }
171                         case "long" -> {
172                             if (isUnsignedCast) {
173                                 yield (long) (initial.longValue() & ((1L << isp.elementSize()) - 1));
174                             } else {
175                                 yield initial.longValue();
176                             }
177                         }
178                         case "float" -> initial.floatValue();
179                         case "double" -> initial.doubleValue();
180                         default -> throw new AssertionError();
181                     };
182                 } else {
183                     expected = switch (osp.elementType().getName()) {
184                         case "byte"   -> (byte)0;
185                         case "short"  -> (short)0;
186                         case "int"    -> (int)0;
187                         case "long"   -> (long)0;
188                         case "float"  -> (float)0;
189                         case "double" -> (double)0;
190                         default -> throw new AssertionError();
191                     };
192                 }
193                 actual = switch (osp.elementType().getName()) {
194                     case "byte"   -> UnsafeUtils.getByte(output, obase, i);
195                     case "short"  -> UnsafeUtils.getShort(output, obase, i);
196                     case "int"    -> UnsafeUtils.getInt(output, obase, i);
197                     case "long"   -> UnsafeUtils.getLong(output, obase, i);
198                     case "float"  -> UnsafeUtils.getFloat(output, obase, i);
199                     case "double" -> UnsafeUtils.getDouble(output, obase, i);
200                     default -> throw new AssertionError();
201                 };
202                 Asserts.assertEquals(expected, actual);
203             }
204         }
205     }
206 
207     @ForceInline
208     public static void vectorExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp,
209                                           MemorySegment input, MemorySegment output) {
210         isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder())
211                 .reinterpretShape(osp, 0)
212                 .intoMemorySegment(output, 0, ByteOrder.nativeOrder());
213     }
214 
215     public static void runExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
216         var random = Utils.getRandomInstance();
217         String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
218         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
219         var testMethod = MethodHandles.lookup().findStatic(caller,
220                 testMethodName,
221                 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class));
222         byte[] input = new byte[isp.vectorByteSize()];
223         byte[] output = new byte[osp.vectorByteSize()];
224         MemorySegment msInput = MemorySegment.ofArray(input);
225         MemorySegment msOutput = MemorySegment.ofArray(output);
226         for (int iter = 0; iter < INVOCATIONS; iter++) {
227             random.nextBytes(input);
228 
229             testMethod.invokeExact(msInput, msOutput);
230 
231             for (int i = 0; i < osp.vectorByteSize(); i++) {
232                 int expected = i < isp.vectorByteSize() ? input[i] : 0;
233                 int actual = output[i];
234                 Asserts.assertEquals(expected, actual);
235             }
236         }
237     }
238 
239     @ForceInline
240     public static void vectorDoubleExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp,
241                                                 MemorySegment input, MemorySegment output) {
242         isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder())
243                 .reinterpretShape(osp, 0)
244                 .reinterpretShape(isp, 0)
245                 .intoMemorySegment(output, 0, ByteOrder.nativeOrder());
246     }
247 
248     public static void runDoubleExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
249         var random = Utils.getRandomInstance();
250         String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
251         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
252         var testMethod = MethodHandles.lookup().findStatic(caller,
253                 testMethodName,
254                 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class));
255         byte[] input = new byte[isp.vectorByteSize()];
256         byte[] output = new byte[isp.vectorByteSize()];
257         MemorySegment msInput = MemorySegment.ofArray(input);
258         MemorySegment msOutput = MemorySegment.ofArray(output);
259         for (int iter = 0; iter < INVOCATIONS; iter++) {
260             random.nextBytes(input);
261 
262             testMethod.invokeExact(msInput, msOutput);
263 
264             for (int i = 0; i < isp.vectorByteSize(); i++) {
265                 int expected = i < osp.vectorByteSize() ? input[i] : 0;
266                 int actual = output[i];
267                 Asserts.assertEquals(expected, actual);
268             }
269         }
270     }
271 
272     @ForceInline
273     public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
274         var outputVector = readVector(isp, input)
275                 .reinterpretShape(osp, 0);
276         writeVector(osp, outputVector, output);
277     }
278 
279     public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable {
280         var random = Utils.getRandomInstance();
281         String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
282         var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
283         var testMethod = MethodHandles.lookup().findStatic(caller,
284                     testMethodName,
285                     MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType()))
286                 .asType(MethodType.methodType(void.class, Object.class, Object.class));
287         Object input = Array.newInstance(isp.elementType(), isp.length());
288         Object output = Array.newInstance(osp.elementType(), osp.length());
289         long ibase = UnsafeUtils.arrayBase(isp.elementType());
290         long obase = UnsafeUtils.arrayBase(osp.elementType());
291         for (int iter = 0; iter < INVOCATIONS; iter++) {
292             for (int i = 0; i < isp.vectorByteSize(); i++) {
293                 UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt());
294             }
295 
296             testMethod.invokeExact(input, output);
297 
298             for (int i = 0; i < osp.vectorByteSize(); i++) {
299                 int expected = i < isp.vectorByteSize() ? UnsafeUtils.getByte(input, ibase, i) : 0;
300                 int actual = UnsafeUtils.getByte(output, obase, i);
301                 Asserts.assertEquals(expected, actual);
302             }
303         }
304     }
305 
306     @ForceInline
307     private static <T> Vector<T> readVector(VectorSpecies<T> isp, Object input) {
308         return isp.fromArray(input, 0);
309     }
310 
311     @ForceInline
312     private static <U> void writeVector(VectorSpecies<U> osp, Vector<U> vector, Object output) {
313         var otype = osp.elementType();
314         if (otype == byte.class) {
315             ((ByteVector)vector).intoArray((byte[])output, 0);
316         } else if (otype == short.class) {
317             ((ShortVector)vector).intoArray((short[])output, 0);
318         } else if (otype == int.class) {
319             ((IntVector)vector).intoArray((int[])output, 0);
320         } else if (otype == long.class) {
321             ((LongVector)vector).intoArray((long[])output, 0);
322         } else if (otype == float.class) {
323             ((FloatVector)vector).intoArray((float[])output, 0);
324         } else if (otype == double.class) {
325             ((DoubleVector)vector).intoArray((double[])output, 0);
326         } else {
327             throw new AssertionError();
328         }
329     }
330 }