1 /* 2 * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 package compiler.vectorapi.reshape.utils; 25 26 import compiler.lib.ir_framework.ForceInline; 27 import compiler.lib.ir_framework.IRNode; 28 import compiler.lib.ir_framework.TestFramework; 29 import java.lang.invoke.MethodHandles; 30 import java.lang.invoke.MethodType; 31 import java.lang.invoke.VarHandle; 32 import java.lang.reflect.Array; 33 import java.nio.ByteOrder; 34 import java.util.List; 35 import java.util.random.RandomGenerator; 36 import java.util.stream.Collectors; 37 import java.util.stream.Stream; 38 import jdk.incubator.vector.*; 39 import jdk.test.lib.Asserts; 40 import jdk.test.lib.Utils; 41 42 public class VectorReshapeHelper { 43 public static final int INVOCATIONS = 10_000; 44 45 public static final VectorSpecies<Byte> BSPEC64 = ByteVector.SPECIES_64; 46 public static final VectorSpecies<Short> SSPEC64 = ShortVector.SPECIES_64; 47 public static final VectorSpecies<Integer> ISPEC64 = IntVector.SPECIES_64; 48 public static final VectorSpecies<Long> LSPEC64 = LongVector.SPECIES_64; 49 public static final VectorSpecies<Float> FSPEC64 = FloatVector.SPECIES_64; 50 public static final VectorSpecies<Double> DSPEC64 = DoubleVector.SPECIES_64; 51 52 public static final VectorSpecies<Byte> BSPEC128 = ByteVector.SPECIES_128; 53 public static final VectorSpecies<Short> SSPEC128 = ShortVector.SPECIES_128; 54 public static final VectorSpecies<Integer> ISPEC128 = IntVector.SPECIES_128; 55 public static final VectorSpecies<Long> LSPEC128 = LongVector.SPECIES_128; 56 public static final VectorSpecies<Float> FSPEC128 = FloatVector.SPECIES_128; 57 public static final VectorSpecies<Double> DSPEC128 = DoubleVector.SPECIES_128; 58 59 public static final VectorSpecies<Byte> BSPEC256 = ByteVector.SPECIES_256; 60 public static final VectorSpecies<Short> SSPEC256 = ShortVector.SPECIES_256; 61 public static final VectorSpecies<Integer> ISPEC256 = IntVector.SPECIES_256; 62 public static final VectorSpecies<Long> LSPEC256 = LongVector.SPECIES_256; 63 public static final VectorSpecies<Float> FSPEC256 = FloatVector.SPECIES_256; 64 public static final VectorSpecies<Double> DSPEC256 = DoubleVector.SPECIES_256; 65 66 public static final VectorSpecies<Byte> BSPEC512 = ByteVector.SPECIES_512; 67 public static final VectorSpecies<Short> SSPEC512 = ShortVector.SPECIES_512; 68 public static final VectorSpecies<Integer> ISPEC512 = IntVector.SPECIES_512; 69 public static final VectorSpecies<Long> LSPEC512 = LongVector.SPECIES_512; 70 public static final VectorSpecies<Float> FSPEC512 = FloatVector.SPECIES_512; 71 public static final VectorSpecies<Double> DSPEC512 = DoubleVector.SPECIES_512; 72 73 public static final String B2X_NODE = IRNode.VECTOR_CAST_B2X; 74 public static final String S2X_NODE = IRNode.VECTOR_CAST_S2X; 75 public static final String I2X_NODE = IRNode.VECTOR_CAST_I2X; 76 public static final String L2X_NODE = IRNode.VECTOR_CAST_L2X; 77 public static final String F2X_NODE = IRNode.VECTOR_CAST_F2X; 78 public static final String D2X_NODE = IRNode.VECTOR_CAST_D2X; 79 public static final String UB2X_NODE = IRNode.VECTOR_UCAST_B2X; 80 public static final String US2X_NODE = IRNode.VECTOR_UCAST_S2X; 81 public static final String UI2X_NODE = IRNode.VECTOR_UCAST_I2X; 82 public static final String REINTERPRET_NODE = IRNode.VECTOR_REINTERPRET; 83 84 public static void runMainHelper(Class<?> testClass, Stream<VectorSpeciesPair> testMethods, String... flags) { 85 var test = new TestFramework(testClass); 86 test.setDefaultWarmup(1); 87 test.addHelperClasses(VectorReshapeHelper.class); 88 test.addFlags("--add-modules=jdk.incubator.vector", "--add-exports=java.base/jdk.internal.misc=ALL-UNNAMED"); 89 test.addFlags(flags); 90 String testMethodNames = testMethods 91 .filter(p -> p.isp().length() <= VectorSpecies.ofLargestShape(p.isp().elementType()).length()) 92 .filter(p -> p.osp().length() <= VectorSpecies.ofLargestShape(p.osp().elementType()).length()) 93 .map(VectorSpeciesPair::format) 94 .collect(Collectors.joining(",")); 95 test.addFlags("-DTest=" + testMethodNames); 96 test.start(); 97 } 98 99 @ForceInline 100 public static <T, U> void vectorCast(VectorOperators.Conversion<T, U> cop, 101 VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) { 102 var outputVector = readVector(isp, input) 103 .convertShape(cop, osp, 0); 104 writeVector(osp, outputVector, output); 105 } 106 107 public static <T, U> void runCastHelper(VectorOperators.Conversion<T, U> castOp, 108 VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable { 109 var random = Utils.getRandomInstance(); 110 boolean isUnsignedCast = castOp.name().startsWith("ZERO"); 111 String testMethodName = VectorSpeciesPair.makePair(isp, osp, isUnsignedCast).format(); 112 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 113 var testMethod = MethodHandles.lookup().findStatic(caller, 114 testMethodName, 115 MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType())) 116 .asType(MethodType.methodType(void.class, Object.class, Object.class)); 117 Object input = Array.newInstance(isp.elementType(), isp.length()); 118 Object output = Array.newInstance(osp.elementType(), osp.length()); 119 long ibase = UnsafeUtils.arrayBase(isp.elementType()); 120 long obase = UnsafeUtils.arrayBase(osp.elementType()); 121 for (int iter = 0; iter < INVOCATIONS; iter++) { 122 // We need to generate arrays with NaN or very large values occasionally 123 boolean normalArray = random.nextBoolean(); 124 var abnormalValue = List.of(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, -1e30, 1e30); 125 for (int i = 0; i < isp.length(); i++) { 126 switch (isp.elementType().getName()) { 127 case "byte" -> UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt()); 128 case "short" -> UnsafeUtils.putShort(input, ibase, i, (short)random.nextInt()); 129 case "int" -> UnsafeUtils.putInt(input, ibase, i, random.nextInt()); 130 case "long" -> UnsafeUtils.putLong(input, ibase, i, random.nextLong()); 131 case "float" -> { 132 if (normalArray || random.nextBoolean()) { 133 UnsafeUtils.putFloat(input, ibase, i, random.nextFloat(Byte.MIN_VALUE, Byte.MAX_VALUE)); 134 } else { 135 UnsafeUtils.putFloat(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())).floatValue()); 136 } 137 } 138 case "double" -> { 139 if (normalArray || random.nextBoolean()) { 140 UnsafeUtils.putDouble(input, ibase, i, random.nextDouble(Byte.MIN_VALUE, Byte.MAX_VALUE)); 141 } else { 142 UnsafeUtils.putDouble(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size()))); 143 } 144 } 145 default -> throw new AssertionError(); 146 } 147 } 148 149 testMethod.invokeExact(input, output); 150 151 for (int i = 0; i < osp.length(); i++) { 152 Number expected, actual; 153 if (i < isp.length()) { 154 Number initial = switch (isp.elementType().getName()) { 155 case "byte" -> UnsafeUtils.getByte(input, ibase, i); 156 case "short" -> UnsafeUtils.getShort(input, ibase, i); 157 case "int" -> UnsafeUtils.getInt(input, ibase, i); 158 case "long" -> UnsafeUtils.getLong(input, ibase, i); 159 case "float" -> UnsafeUtils.getFloat(input, ibase, i); 160 case "double" -> UnsafeUtils.getDouble(input, ibase, i); 161 default -> throw new AssertionError(); 162 }; 163 expected = switch (osp.elementType().getName()) { 164 case "byte" -> initial.byteValue(); 165 case "short" -> { 166 if (isUnsignedCast) { 167 yield (short) (initial.longValue() & ((1L << isp.elementSize()) - 1)); 168 } else { 169 yield initial.shortValue(); 170 } 171 } 172 case "int" -> { 173 if (isUnsignedCast) { 174 yield (int) (initial.longValue() & ((1L << isp.elementSize()) - 1)); 175 } else { 176 yield initial.intValue(); 177 } 178 } 179 case "long" -> { 180 if (isUnsignedCast) { 181 yield (long) (initial.longValue() & ((1L << isp.elementSize()) - 1)); 182 } else { 183 yield initial.longValue(); 184 } 185 } 186 case "float" -> initial.floatValue(); 187 case "double" -> initial.doubleValue(); 188 default -> throw new AssertionError(); 189 }; 190 } else { 191 expected = switch (osp.elementType().getName()) { 192 case "byte" -> (byte)0; 193 case "short" -> (short)0; 194 case "int" -> (int)0; 195 case "long" -> (long)0; 196 case "float" -> (float)0; 197 case "double" -> (double)0; 198 default -> throw new AssertionError(); 199 }; 200 } 201 actual = switch (osp.elementType().getName()) { 202 case "byte" -> UnsafeUtils.getByte(output, obase, i); 203 case "short" -> UnsafeUtils.getShort(output, obase, i); 204 case "int" -> UnsafeUtils.getInt(output, obase, i); 205 case "long" -> UnsafeUtils.getLong(output, obase, i); 206 case "float" -> UnsafeUtils.getFloat(output, obase, i); 207 case "double" -> UnsafeUtils.getDouble(output, obase, i); 208 default -> throw new AssertionError(); 209 }; 210 Asserts.assertEquals(expected, actual); 211 } 212 } 213 } 214 215 @ForceInline 216 public static void vectorExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, byte[] input, byte[] output) { 217 isp.fromByteArray(input, 0, ByteOrder.nativeOrder()) 218 .reinterpretShape(osp, 0) 219 .intoByteArray(output, 0, ByteOrder.nativeOrder()); 220 } 221 222 public static void runExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable { 223 var random = Utils.getRandomInstance(); 224 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); 225 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 226 var testMethod = MethodHandles.lookup().findStatic(caller, 227 testMethodName, 228 MethodType.methodType(void.class, byte.class.arrayType(), byte.class.arrayType())); 229 byte[] input = new byte[isp.vectorByteSize()]; 230 byte[] output = new byte[osp.vectorByteSize()]; 231 for (int iter = 0; iter < INVOCATIONS; iter++) { 232 random.nextBytes(input); 233 234 testMethod.invokeExact(input, output); 235 236 for (int i = 0; i < osp.vectorByteSize(); i++) { 237 int expected = i < isp.vectorByteSize() ? input[i] : 0; 238 int actual = output[i]; 239 Asserts.assertEquals(expected, actual); 240 } 241 } 242 } 243 244 @ForceInline 245 public static void vectorDoubleExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, byte[] input, byte[] output) { 246 isp.fromByteArray(input, 0, ByteOrder.nativeOrder()) 247 .reinterpretShape(osp, 0) 248 .reinterpretShape(isp, 0) 249 .intoByteArray(output, 0, ByteOrder.nativeOrder()); 250 } 251 252 public static void runDoubleExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable { 253 var random = Utils.getRandomInstance(); 254 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); 255 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 256 var testMethod = MethodHandles.lookup().findStatic(caller, 257 testMethodName, 258 MethodType.methodType(void.class, byte.class.arrayType(), byte.class.arrayType())); 259 byte[] input = new byte[isp.vectorByteSize()]; 260 byte[] output = new byte[isp.vectorByteSize()]; 261 for (int iter = 0; iter < INVOCATIONS; iter++) { 262 random.nextBytes(input); 263 264 testMethod.invokeExact(input, output); 265 266 for (int i = 0; i < isp.vectorByteSize(); i++) { 267 int expected = i < osp.vectorByteSize() ? input[i] : 0; 268 int actual = output[i]; 269 Asserts.assertEquals(expected, actual); 270 } 271 } 272 } 273 274 @ForceInline 275 public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) { 276 var outputVector = readVector(isp, input) 277 .reinterpretShape(osp, 0); 278 writeVector(osp, outputVector, output); 279 } 280 281 public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable { 282 var random = Utils.getRandomInstance(); 283 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); 284 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 285 var testMethod = MethodHandles.lookup().findStatic(caller, 286 testMethodName, 287 MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType())) 288 .asType(MethodType.methodType(void.class, Object.class, Object.class)); 289 Object input = Array.newInstance(isp.elementType(), isp.length()); 290 Object output = Array.newInstance(osp.elementType(), osp.length()); 291 long ibase = UnsafeUtils.arrayBase(isp.elementType()); 292 long obase = UnsafeUtils.arrayBase(osp.elementType()); 293 for (int iter = 0; iter < INVOCATIONS; iter++) { 294 for (int i = 0; i < isp.vectorByteSize(); i++) { 295 UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt()); 296 } 297 298 testMethod.invokeExact(input, output); 299 300 for (int i = 0; i < osp.vectorByteSize(); i++) { 301 int expected = i < isp.vectorByteSize() ? UnsafeUtils.getByte(input, ibase, i) : 0; 302 int actual = UnsafeUtils.getByte(output, obase, i); 303 Asserts.assertEquals(expected, actual); 304 } 305 } 306 } 307 308 @ForceInline 309 private static <T> Vector<T> readVector(VectorSpecies<T> isp, Object input) { 310 return isp.fromArray(input, 0); 311 } 312 313 @ForceInline 314 private static <U> void writeVector(VectorSpecies<U> osp, Vector<U> vector, Object output) { 315 var otype = osp.elementType(); 316 if (otype == byte.class) { 317 ((ByteVector)vector).intoArray((byte[])output, 0); 318 } else if (otype == short.class) { 319 ((ShortVector)vector).intoArray((short[])output, 0); 320 } else if (otype == int.class) { 321 ((IntVector)vector).intoArray((int[])output, 0); 322 } else if (otype == long.class) { 323 ((LongVector)vector).intoArray((long[])output, 0); 324 } else if (otype == float.class) { 325 ((FloatVector)vector).intoArray((float[])output, 0); 326 } else if (otype == double.class) { 327 ((DoubleVector)vector).intoArray((double[])output, 0); 328 } else { 329 throw new AssertionError(); 330 } 331 } 332 }