1 /* 2 * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 package compiler.vectorapi.reshape.utils; 25 26 import compiler.lib.ir_framework.ForceInline; 27 import compiler.lib.ir_framework.IRNode; 28 import compiler.lib.ir_framework.TestFramework; 29 import java.lang.invoke.MethodHandles; 30 import java.lang.invoke.MethodType; 31 import java.lang.invoke.VarHandle; 32 import java.lang.reflect.Array; 33 import java.nio.ByteOrder; 34 import java.util.List; 35 import java.util.random.RandomGenerator; 36 import java.util.stream.Collectors; 37 import java.util.stream.Stream; 38 import jdk.incubator.foreign.MemorySegment; 39 import jdk.incubator.vector.*; 40 import jdk.test.lib.Asserts; 41 import jdk.test.lib.Utils; 42 43 public class VectorReshapeHelper { 44 public static final int INVOCATIONS = 10_000; 45 46 public static final VectorSpecies<Byte> BSPEC64 = ByteVector.SPECIES_64; 47 public static final VectorSpecies<Short> SSPEC64 = ShortVector.SPECIES_64; 48 public static final VectorSpecies<Integer> ISPEC64 = IntVector.SPECIES_64; 49 public static final VectorSpecies<Long> LSPEC64 = LongVector.SPECIES_64; 50 public static final VectorSpecies<Float> FSPEC64 = FloatVector.SPECIES_64; 51 public static final VectorSpecies<Double> DSPEC64 = DoubleVector.SPECIES_64; 52 53 public static final VectorSpecies<Byte> BSPEC128 = ByteVector.SPECIES_128; 54 public static final VectorSpecies<Short> SSPEC128 = ShortVector.SPECIES_128; 55 public static final VectorSpecies<Integer> ISPEC128 = IntVector.SPECIES_128; 56 public static final VectorSpecies<Long> LSPEC128 = LongVector.SPECIES_128; 57 public static final VectorSpecies<Float> FSPEC128 = FloatVector.SPECIES_128; 58 public static final VectorSpecies<Double> DSPEC128 = DoubleVector.SPECIES_128; 59 60 public static final VectorSpecies<Byte> BSPEC256 = ByteVector.SPECIES_256; 61 public static final VectorSpecies<Short> SSPEC256 = ShortVector.SPECIES_256; 62 public static final VectorSpecies<Integer> ISPEC256 = IntVector.SPECIES_256; 63 public static final VectorSpecies<Long> LSPEC256 = LongVector.SPECIES_256; 64 public static final VectorSpecies<Float> FSPEC256 = FloatVector.SPECIES_256; 65 public static final VectorSpecies<Double> DSPEC256 = DoubleVector.SPECIES_256; 66 67 public static final VectorSpecies<Byte> BSPEC512 = ByteVector.SPECIES_512; 68 public static final VectorSpecies<Short> SSPEC512 = ShortVector.SPECIES_512; 69 public static final VectorSpecies<Integer> ISPEC512 = IntVector.SPECIES_512; 70 public static final VectorSpecies<Long> LSPEC512 = LongVector.SPECIES_512; 71 public static final VectorSpecies<Float> FSPEC512 = FloatVector.SPECIES_512; 72 public static final VectorSpecies<Double> DSPEC512 = DoubleVector.SPECIES_512; 73 74 public static final String B2X_NODE = IRNode.VECTOR_CAST_B2X; 75 public static final String S2X_NODE = IRNode.VECTOR_CAST_S2X; 76 public static final String I2X_NODE = IRNode.VECTOR_CAST_I2X; 77 public static final String L2X_NODE = IRNode.VECTOR_CAST_L2X; 78 public static final String F2X_NODE = IRNode.VECTOR_CAST_F2X; 79 public static final String D2X_NODE = IRNode.VECTOR_CAST_D2X; 80 public static final String UB2X_NODE = IRNode.VECTOR_UCAST_B2X; 81 public static final String US2X_NODE = IRNode.VECTOR_UCAST_S2X; 82 public static final String UI2X_NODE = IRNode.VECTOR_UCAST_I2X; 83 public static final String REINTERPRET_NODE = IRNode.VECTOR_REINTERPRET; 84 85 public static void runMainHelper(Class<?> testClass, Stream<VectorSpeciesPair> testMethods, String... flags) { 86 var test = new TestFramework(testClass); 87 test.setDefaultWarmup(1); 88 test.addHelperClasses(VectorReshapeHelper.class); 89 test.addFlags("--add-modules=jdk.incubator.vector", "--add-exports=java.base/jdk.internal.misc=ALL-UNNAMED"); 90 test.addFlags(flags); 91 String testMethodNames = testMethods 92 .filter(p -> p.isp().length() <= VectorSpecies.ofLargestShape(p.isp().elementType()).length()) 93 .filter(p -> p.osp().length() <= VectorSpecies.ofLargestShape(p.osp().elementType()).length()) 94 .map(VectorSpeciesPair::format) 95 .collect(Collectors.joining(",")); 96 test.addFlags("-DTest=" + testMethodNames); 97 test.start(); 98 } 99 100 @ForceInline 101 public static <T, U> void vectorCast(VectorOperators.Conversion<T, U> cop, 102 VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) { 103 var outputVector = readVector(isp, input) 104 .convertShape(cop, osp, 0); 105 writeVector(osp, outputVector, output); 106 } 107 108 public static <T, U> void runCastHelper(VectorOperators.Conversion<T, U> castOp, 109 VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable { 110 var random = Utils.getRandomInstance(); 111 boolean isUnsignedCast = castOp.name().startsWith("ZERO"); 112 String testMethodName = VectorSpeciesPair.makePair(isp, osp, isUnsignedCast).format(); 113 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 114 var testMethod = MethodHandles.lookup().findStatic(caller, 115 testMethodName, 116 MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType())) 117 .asType(MethodType.methodType(void.class, Object.class, Object.class)); 118 Object input = Array.newInstance(isp.elementType(), isp.length()); 119 Object output = Array.newInstance(osp.elementType(), osp.length()); 120 long ibase = UnsafeUtils.arrayBase(isp.elementType()); 121 long obase = UnsafeUtils.arrayBase(osp.elementType()); 122 for (int iter = 0; iter < INVOCATIONS; iter++) { 123 // We need to generate arrays with NaN or very large values occasionally 124 boolean normalArray = random.nextBoolean(); 125 var abnormalValue = List.of(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, -1e30, 1e30); 126 for (int i = 0; i < isp.length(); i++) { 127 switch (isp.elementType().getName()) { 128 case "byte" -> UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt()); 129 case "short" -> UnsafeUtils.putShort(input, ibase, i, (short)random.nextInt()); 130 case "int" -> UnsafeUtils.putInt(input, ibase, i, random.nextInt()); 131 case "long" -> UnsafeUtils.putLong(input, ibase, i, random.nextLong()); 132 case "float" -> { 133 if (normalArray || random.nextBoolean()) { 134 UnsafeUtils.putFloat(input, ibase, i, random.nextFloat(Byte.MIN_VALUE, Byte.MAX_VALUE)); 135 } else { 136 UnsafeUtils.putFloat(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())).floatValue()); 137 } 138 } 139 case "double" -> { 140 if (normalArray || random.nextBoolean()) { 141 UnsafeUtils.putDouble(input, ibase, i, random.nextDouble(Byte.MIN_VALUE, Byte.MAX_VALUE)); 142 } else { 143 UnsafeUtils.putDouble(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size()))); 144 } 145 } 146 default -> throw new AssertionError(); 147 } 148 } 149 150 testMethod.invokeExact(input, output); 151 152 for (int i = 0; i < osp.length(); i++) { 153 Number expected, actual; 154 if (i < isp.length()) { 155 Number initial = switch (isp.elementType().getName()) { 156 case "byte" -> UnsafeUtils.getByte(input, ibase, i); 157 case "short" -> UnsafeUtils.getShort(input, ibase, i); 158 case "int" -> UnsafeUtils.getInt(input, ibase, i); 159 case "long" -> UnsafeUtils.getLong(input, ibase, i); 160 case "float" -> UnsafeUtils.getFloat(input, ibase, i); 161 case "double" -> UnsafeUtils.getDouble(input, ibase, i); 162 default -> throw new AssertionError(); 163 }; 164 expected = switch (osp.elementType().getName()) { 165 case "byte" -> initial.byteValue(); 166 case "short" -> { 167 if (isUnsignedCast) { 168 yield (short) (initial.longValue() & ((1L << isp.elementSize()) - 1)); 169 } else { 170 yield initial.shortValue(); 171 } 172 } 173 case "int" -> { 174 if (isUnsignedCast) { 175 yield (int) (initial.longValue() & ((1L << isp.elementSize()) - 1)); 176 } else { 177 yield initial.intValue(); 178 } 179 } 180 case "long" -> { 181 if (isUnsignedCast) { 182 yield (long) (initial.longValue() & ((1L << isp.elementSize()) - 1)); 183 } else { 184 yield initial.longValue(); 185 } 186 } 187 case "float" -> initial.floatValue(); 188 case "double" -> initial.doubleValue(); 189 default -> throw new AssertionError(); 190 }; 191 } else { 192 expected = switch (osp.elementType().getName()) { 193 case "byte" -> (byte)0; 194 case "short" -> (short)0; 195 case "int" -> (int)0; 196 case "long" -> (long)0; 197 case "float" -> (float)0; 198 case "double" -> (double)0; 199 default -> throw new AssertionError(); 200 }; 201 } 202 actual = switch (osp.elementType().getName()) { 203 case "byte" -> UnsafeUtils.getByte(output, obase, i); 204 case "short" -> UnsafeUtils.getShort(output, obase, i); 205 case "int" -> UnsafeUtils.getInt(output, obase, i); 206 case "long" -> UnsafeUtils.getLong(output, obase, i); 207 case "float" -> UnsafeUtils.getFloat(output, obase, i); 208 case "double" -> UnsafeUtils.getDouble(output, obase, i); 209 default -> throw new AssertionError(); 210 }; 211 Asserts.assertEquals(expected, actual); 212 } 213 } 214 } 215 216 @ForceInline 217 public static void vectorExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, 218 MemorySegment input, MemorySegment output) { 219 isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder()) 220 .reinterpretShape(osp, 0) 221 .intoMemorySegment(output, 0, ByteOrder.nativeOrder()); 222 } 223 224 public static void runExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable { 225 var random = Utils.getRandomInstance(); 226 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); 227 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 228 var testMethod = MethodHandles.lookup().findStatic(caller, 229 testMethodName, 230 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class)); 231 byte[] input = new byte[isp.vectorByteSize()]; 232 byte[] output = new byte[osp.vectorByteSize()]; 233 MemorySegment msInput = MemorySegment.ofArray(input); 234 MemorySegment msOutput = MemorySegment.ofArray(output); 235 for (int iter = 0; iter < INVOCATIONS; iter++) { 236 random.nextBytes(input); 237 238 testMethod.invokeExact(msInput, msOutput); 239 240 for (int i = 0; i < osp.vectorByteSize(); i++) { 241 int expected = i < isp.vectorByteSize() ? input[i] : 0; 242 int actual = output[i]; 243 Asserts.assertEquals(expected, actual); 244 } 245 } 246 } 247 248 @ForceInline 249 public static void vectorDoubleExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, 250 MemorySegment input, MemorySegment output) { 251 isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder()) 252 .reinterpretShape(osp, 0) 253 .reinterpretShape(isp, 0) 254 .intoMemorySegment(output, 0, ByteOrder.nativeOrder()); 255 } 256 257 public static void runDoubleExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable { 258 var random = Utils.getRandomInstance(); 259 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); 260 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 261 var testMethod = MethodHandles.lookup().findStatic(caller, 262 testMethodName, 263 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class)); 264 byte[] input = new byte[isp.vectorByteSize()]; 265 byte[] output = new byte[isp.vectorByteSize()]; 266 MemorySegment msInput = MemorySegment.ofArray(input); 267 MemorySegment msOutput = MemorySegment.ofArray(output); 268 for (int iter = 0; iter < INVOCATIONS; iter++) { 269 random.nextBytes(input); 270 271 testMethod.invokeExact(msInput, msOutput); 272 273 for (int i = 0; i < isp.vectorByteSize(); i++) { 274 int expected = i < osp.vectorByteSize() ? input[i] : 0; 275 int actual = output[i]; 276 Asserts.assertEquals(expected, actual); 277 } 278 } 279 } 280 281 @ForceInline 282 public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) { 283 var outputVector = readVector(isp, input) 284 .reinterpretShape(osp, 0); 285 writeVector(osp, outputVector, output); 286 } 287 288 public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable { 289 var random = Utils.getRandomInstance(); 290 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); 291 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 292 var testMethod = MethodHandles.lookup().findStatic(caller, 293 testMethodName, 294 MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType())) 295 .asType(MethodType.methodType(void.class, Object.class, Object.class)); 296 Object input = Array.newInstance(isp.elementType(), isp.length()); 297 Object output = Array.newInstance(osp.elementType(), osp.length()); 298 long ibase = UnsafeUtils.arrayBase(isp.elementType()); 299 long obase = UnsafeUtils.arrayBase(osp.elementType()); 300 for (int iter = 0; iter < INVOCATIONS; iter++) { 301 for (int i = 0; i < isp.vectorByteSize(); i++) { 302 UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt()); 303 } 304 305 testMethod.invokeExact(input, output); 306 307 for (int i = 0; i < osp.vectorByteSize(); i++) { 308 int expected = i < isp.vectorByteSize() ? UnsafeUtils.getByte(input, ibase, i) : 0; 309 int actual = UnsafeUtils.getByte(output, obase, i); 310 Asserts.assertEquals(expected, actual); 311 } 312 } 313 } 314 315 @ForceInline 316 private static <T> Vector<T> readVector(VectorSpecies<T> isp, Object input) { 317 return isp.fromArray(input, 0); 318 } 319 320 @ForceInline 321 private static <U> void writeVector(VectorSpecies<U> osp, Vector<U> vector, Object output) { 322 var otype = osp.elementType(); 323 if (otype == byte.class) { 324 ((ByteVector)vector).intoArray((byte[])output, 0); 325 } else if (otype == short.class) { 326 ((ShortVector)vector).intoArray((short[])output, 0); 327 } else if (otype == int.class) { 328 ((IntVector)vector).intoArray((int[])output, 0); 329 } else if (otype == long.class) { 330 ((LongVector)vector).intoArray((long[])output, 0); 331 } else if (otype == float.class) { 332 ((FloatVector)vector).intoArray((float[])output, 0); 333 } else if (otype == double.class) { 334 ((DoubleVector)vector).intoArray((double[])output, 0); 335 } else { 336 throw new AssertionError(); 337 } 338 } 339 }