1 /* 2 * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 package compiler.vectorapi.reshape.utils; 25 26 import compiler.lib.ir_framework.ForceInline; 27 import compiler.lib.ir_framework.IRNode; 28 import compiler.lib.ir_framework.TestFramework; 29 import java.lang.invoke.MethodHandles; 30 import java.lang.invoke.MethodType; 31 import java.lang.invoke.VarHandle; 32 import java.lang.reflect.Array; 33 import java.nio.ByteOrder; 34 import java.util.List; 35 import java.util.random.RandomGenerator; 36 import java.util.stream.Collectors; 37 import java.util.stream.Stream; 38 import java.lang.foreign.MemorySegment; 39 import jdk.incubator.vector.*; 40 import jdk.test.lib.Asserts; 41 import jdk.test.lib.Utils; 42 43 public class VectorReshapeHelper { 44 public static final int INVOCATIONS = 10_000; 45 46 public static final VectorSpecies<Byte> BSPEC64 = ByteVector.SPECIES_64; 47 public static final VectorSpecies<Short> SSPEC64 = ShortVector.SPECIES_64; 48 public static final VectorSpecies<Integer> ISPEC64 = IntVector.SPECIES_64; 49 public static final VectorSpecies<Long> LSPEC64 = LongVector.SPECIES_64; 50 public static final VectorSpecies<Float> FSPEC64 = FloatVector.SPECIES_64; 51 public static final VectorSpecies<Double> DSPEC64 = DoubleVector.SPECIES_64; 52 53 public static final VectorSpecies<Byte> BSPEC128 = ByteVector.SPECIES_128; 54 public static final VectorSpecies<Short> SSPEC128 = ShortVector.SPECIES_128; 55 public static final VectorSpecies<Integer> ISPEC128 = IntVector.SPECIES_128; 56 public static final VectorSpecies<Long> LSPEC128 = LongVector.SPECIES_128; 57 public static final VectorSpecies<Float> FSPEC128 = FloatVector.SPECIES_128; 58 public static final VectorSpecies<Double> DSPEC128 = DoubleVector.SPECIES_128; 59 60 public static final VectorSpecies<Byte> BSPEC256 = ByteVector.SPECIES_256; 61 public static final VectorSpecies<Short> SSPEC256 = ShortVector.SPECIES_256; 62 public static final VectorSpecies<Integer> ISPEC256 = IntVector.SPECIES_256; 63 public static final VectorSpecies<Long> LSPEC256 = LongVector.SPECIES_256; 64 public static final VectorSpecies<Float> FSPEC256 = FloatVector.SPECIES_256; 65 public static final VectorSpecies<Double> DSPEC256 = DoubleVector.SPECIES_256; 66 67 public static final VectorSpecies<Byte> BSPEC512 = ByteVector.SPECIES_512; 68 public static final VectorSpecies<Short> SSPEC512 = ShortVector.SPECIES_512; 69 public static final VectorSpecies<Integer> ISPEC512 = IntVector.SPECIES_512; 70 public static final VectorSpecies<Long> LSPEC512 = LongVector.SPECIES_512; 71 public static final VectorSpecies<Float> FSPEC512 = FloatVector.SPECIES_512; 72 public static final VectorSpecies<Double> DSPEC512 = DoubleVector.SPECIES_512; 73 74 public static final String REINTERPRET_NODE = IRNode.VECTOR_REINTERPRET; 75 76 public static void runMainHelper(Class<?> testClass, Stream<VectorSpeciesPair> testMethods, String... flags) { 77 var test = new TestFramework(testClass); 78 test.setDefaultWarmup(1); 79 test.addHelperClasses(VectorReshapeHelper.class); 80 test.addFlags("--add-modules=jdk.incubator.vector", "--add-exports=java.base/jdk.internal.misc=ALL-UNNAMED"); 81 test.addFlags(flags); 82 String testMethodNames = testMethods 83 .filter(p -> p.isp().length() <= VectorSpecies.ofLargestShape(p.isp().elementType()).length()) 84 .filter(p -> p.osp().length() <= VectorSpecies.ofLargestShape(p.osp().elementType()).length()) 85 .map(VectorSpeciesPair::format) 86 .collect(Collectors.joining(",")); 87 test.addFlags("-DTest=" + testMethodNames); 88 test.start(); 89 } 90 91 @ForceInline 92 public static <T, U> void vectorCast(VectorOperators.Conversion<T, U> cop, 93 VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) { 94 var outputVector = readVector(isp, input) 95 .convertShape(cop, osp, 0); 96 writeVector(osp, outputVector, output); 97 } 98 99 public static <T, U> void runCastHelper(VectorOperators.Conversion<T, U> castOp, 100 VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable { 101 var random = Utils.getRandomInstance(); 102 boolean isUnsignedCast = castOp.name().startsWith("ZERO"); 103 String testMethodName = VectorSpeciesPair.makePair(isp, osp, isUnsignedCast).format(); 104 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 105 var testMethod = MethodHandles.lookup().findStatic(caller, 106 testMethodName, 107 MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType())) 108 .asType(MethodType.methodType(void.class, Object.class, Object.class)); 109 Object input = Array.newInstance(isp.elementType(), isp.length()); 110 Object output = Array.newInstance(osp.elementType(), osp.length()); 111 long ibase = UnsafeUtils.arrayBase(isp.elementType()); 112 long obase = UnsafeUtils.arrayBase(osp.elementType()); 113 for (int iter = 0; iter < INVOCATIONS; iter++) { 114 // We need to generate arrays with NaN or very large values occasionally 115 boolean normalArray = random.nextBoolean(); 116 var abnormalValue = List.of(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, -1e30, 1e30); 117 for (int i = 0; i < isp.length(); i++) { 118 switch (isp.elementType().getName()) { 119 case "byte" -> UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt()); 120 case "short" -> UnsafeUtils.putShort(input, ibase, i, (short)random.nextInt()); 121 case "int" -> UnsafeUtils.putInt(input, ibase, i, random.nextInt()); 122 case "long" -> UnsafeUtils.putLong(input, ibase, i, random.nextLong()); 123 case "float" -> { 124 if (normalArray || random.nextBoolean()) { 125 UnsafeUtils.putFloat(input, ibase, i, random.nextFloat(Byte.MIN_VALUE, Byte.MAX_VALUE)); 126 } else { 127 UnsafeUtils.putFloat(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())).floatValue()); 128 } 129 } 130 case "double" -> { 131 if (normalArray || random.nextBoolean()) { 132 UnsafeUtils.putDouble(input, ibase, i, random.nextDouble(Byte.MIN_VALUE, Byte.MAX_VALUE)); 133 } else { 134 UnsafeUtils.putDouble(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size()))); 135 } 136 } 137 default -> throw new AssertionError(); 138 } 139 } 140 141 testMethod.invokeExact(input, output); 142 143 for (int i = 0; i < osp.length(); i++) { 144 Number expected, actual; 145 if (i < isp.length()) { 146 Number initial = switch (isp.elementType().getName()) { 147 case "byte" -> UnsafeUtils.getByte(input, ibase, i); 148 case "short" -> UnsafeUtils.getShort(input, ibase, i); 149 case "int" -> UnsafeUtils.getInt(input, ibase, i); 150 case "long" -> UnsafeUtils.getLong(input, ibase, i); 151 case "float" -> UnsafeUtils.getFloat(input, ibase, i); 152 case "double" -> UnsafeUtils.getDouble(input, ibase, i); 153 default -> throw new AssertionError(); 154 }; 155 expected = switch (osp.elementType().getName()) { 156 case "byte" -> initial.byteValue(); 157 case "short" -> { 158 if (isUnsignedCast) { 159 yield (short) (initial.longValue() & ((1L << isp.elementSize()) - 1)); 160 } else { 161 yield initial.shortValue(); 162 } 163 } 164 case "int" -> { 165 if (isUnsignedCast) { 166 yield (int) (initial.longValue() & ((1L << isp.elementSize()) - 1)); 167 } else { 168 yield initial.intValue(); 169 } 170 } 171 case "long" -> { 172 if (isUnsignedCast) { 173 yield (long) (initial.longValue() & ((1L << isp.elementSize()) - 1)); 174 } else { 175 yield initial.longValue(); 176 } 177 } 178 case "float" -> initial.floatValue(); 179 case "double" -> initial.doubleValue(); 180 default -> throw new AssertionError(); 181 }; 182 } else { 183 expected = switch (osp.elementType().getName()) { 184 case "byte" -> (byte)0; 185 case "short" -> (short)0; 186 case "int" -> (int)0; 187 case "long" -> (long)0; 188 case "float" -> (float)0; 189 case "double" -> (double)0; 190 default -> throw new AssertionError(); 191 }; 192 } 193 actual = switch (osp.elementType().getName()) { 194 case "byte" -> UnsafeUtils.getByte(output, obase, i); 195 case "short" -> UnsafeUtils.getShort(output, obase, i); 196 case "int" -> UnsafeUtils.getInt(output, obase, i); 197 case "long" -> UnsafeUtils.getLong(output, obase, i); 198 case "float" -> UnsafeUtils.getFloat(output, obase, i); 199 case "double" -> UnsafeUtils.getDouble(output, obase, i); 200 default -> throw new AssertionError(); 201 }; 202 Asserts.assertEquals(expected, actual); 203 } 204 } 205 } 206 207 @ForceInline 208 public static void vectorExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, 209 MemorySegment input, MemorySegment output) { 210 isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder()) 211 .reinterpretShape(osp, 0) 212 .intoMemorySegment(output, 0, ByteOrder.nativeOrder()); 213 } 214 215 public static void runExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable { 216 var random = Utils.getRandomInstance(); 217 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); 218 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 219 var testMethod = MethodHandles.lookup().findStatic(caller, 220 testMethodName, 221 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class)); 222 byte[] input = new byte[isp.vectorByteSize()]; 223 byte[] output = new byte[osp.vectorByteSize()]; 224 MemorySegment msInput = MemorySegment.ofArray(input); 225 MemorySegment msOutput = MemorySegment.ofArray(output); 226 for (int iter = 0; iter < INVOCATIONS; iter++) { 227 random.nextBytes(input); 228 229 testMethod.invokeExact(msInput, msOutput); 230 231 for (int i = 0; i < osp.vectorByteSize(); i++) { 232 int expected = i < isp.vectorByteSize() ? input[i] : 0; 233 int actual = output[i]; 234 Asserts.assertEquals(expected, actual); 235 } 236 } 237 } 238 239 @ForceInline 240 public static void vectorDoubleExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, 241 MemorySegment input, MemorySegment output) { 242 isp.fromMemorySegment(input, 0, ByteOrder.nativeOrder()) 243 .reinterpretShape(osp, 0) 244 .reinterpretShape(isp, 0) 245 .intoMemorySegment(output, 0, ByteOrder.nativeOrder()); 246 } 247 248 public static void runDoubleExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable { 249 var random = Utils.getRandomInstance(); 250 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); 251 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 252 var testMethod = MethodHandles.lookup().findStatic(caller, 253 testMethodName, 254 MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class)); 255 byte[] input = new byte[isp.vectorByteSize()]; 256 byte[] output = new byte[isp.vectorByteSize()]; 257 MemorySegment msInput = MemorySegment.ofArray(input); 258 MemorySegment msOutput = MemorySegment.ofArray(output); 259 for (int iter = 0; iter < INVOCATIONS; iter++) { 260 random.nextBytes(input); 261 262 testMethod.invokeExact(msInput, msOutput); 263 264 for (int i = 0; i < isp.vectorByteSize(); i++) { 265 int expected = i < osp.vectorByteSize() ? input[i] : 0; 266 int actual = output[i]; 267 Asserts.assertEquals(expected, actual); 268 } 269 } 270 } 271 272 @ForceInline 273 public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) { 274 var outputVector = readVector(isp, input) 275 .reinterpretShape(osp, 0); 276 writeVector(osp, outputVector, output); 277 } 278 279 public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable { 280 var random = Utils.getRandomInstance(); 281 String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); 282 var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); 283 var testMethod = MethodHandles.lookup().findStatic(caller, 284 testMethodName, 285 MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType())) 286 .asType(MethodType.methodType(void.class, Object.class, Object.class)); 287 Object input = Array.newInstance(isp.elementType(), isp.length()); 288 Object output = Array.newInstance(osp.elementType(), osp.length()); 289 long ibase = UnsafeUtils.arrayBase(isp.elementType()); 290 long obase = UnsafeUtils.arrayBase(osp.elementType()); 291 for (int iter = 0; iter < INVOCATIONS; iter++) { 292 for (int i = 0; i < isp.vectorByteSize(); i++) { 293 UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt()); 294 } 295 296 testMethod.invokeExact(input, output); 297 298 for (int i = 0; i < osp.vectorByteSize(); i++) { 299 int expected = i < isp.vectorByteSize() ? UnsafeUtils.getByte(input, ibase, i) : 0; 300 int actual = UnsafeUtils.getByte(output, obase, i); 301 Asserts.assertEquals(expected, actual); 302 } 303 } 304 } 305 306 @ForceInline 307 private static <T> Vector<T> readVector(VectorSpecies<T> isp, Object input) { 308 return isp.fromArray(input, 0); 309 } 310 311 @ForceInline 312 private static <U> void writeVector(VectorSpecies<U> osp, Vector<U> vector, Object output) { 313 var otype = osp.elementType(); 314 if (otype == byte.class) { 315 ((ByteVector)vector).intoArray((byte[])output, 0); 316 } else if (otype == short.class) { 317 ((ShortVector)vector).intoArray((short[])output, 0); 318 } else if (otype == int.class) { 319 ((IntVector)vector).intoArray((int[])output, 0); 320 } else if (otype == long.class) { 321 ((LongVector)vector).intoArray((long[])output, 0); 322 } else if (otype == float.class) { 323 ((FloatVector)vector).intoArray((float[])output, 0); 324 } else if (otype == double.class) { 325 ((DoubleVector)vector).intoArray((double[])output, 0); 326 } else { 327 throw new AssertionError(); 328 } 329 } 330 }