1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @run junit CoreBinaryOpsTest 27 */ 28 29 import org.junit.jupiter.api.Named; 30 import org.junit.jupiter.api.extension.ExtensionContext; 31 import org.junit.jupiter.api.function.ThrowingSupplier; 32 import org.junit.jupiter.params.ParameterizedTest; 33 import org.junit.jupiter.params.provider.Arguments; 34 import org.junit.jupiter.params.provider.ArgumentsProvider; 35 import org.junit.jupiter.params.provider.ArgumentsSource; 36 37 import java.lang.annotation.ElementType; 38 import java.lang.annotation.Retention; 39 import java.lang.annotation.RetentionPolicy; 40 import java.lang.annotation.Target; 41 import java.lang.invoke.MethodHandle; 42 import java.lang.invoke.MethodHandles; 43 import java.lang.invoke.MethodType; 44 import java.lang.reflect.AccessFlag; 45 import java.lang.reflect.Method; 46 import java.lang.reflect.Parameter; 47 import java.lang.reflect.code.CopyContext; 48 import java.lang.reflect.code.Op; 49 import java.lang.reflect.code.OpTransformer; 50 import java.lang.reflect.code.TypeElement; 51 import java.lang.reflect.code.analysis.SSA; 52 import java.lang.reflect.code.bytecode.BytecodeGenerator; 53 import java.lang.reflect.code.interpreter.Interpreter; 54 import java.lang.reflect.code.op.CoreOp; 55 import java.lang.reflect.code.type.FunctionType; 56 import java.lang.reflect.code.type.JavaType; 57 import java.lang.runtime.CodeReflection; 58 import java.util.*; 59 import java.util.stream.Stream; 60 61 import static org.junit.jupiter.api.Assertions.*; 62 63 public class CoreBinaryOpsTest { 64 65 @CodeReflection 66 @SupportedTypes(TypeList.INTEGRAL_BOOLEAN) 67 static int and(int left, int right) { 68 return left & right; 69 } 70 71 @CodeReflection 72 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 73 static int add(int left, int right) { 74 return left + right; 75 } 76 77 @CodeReflection 78 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 79 static int div(int left, int right) { 80 return left / right; 81 } 82 83 @CodeReflection 84 @SupportedTypes(TypeList.INT_LONG) 85 static int leftShift(int left, int right) { 86 return left << right; 87 } 88 89 @CodeReflection 90 @Direct 91 static int leftShiftIL(int left, long right) { 92 return left << right; 93 } 94 95 @CodeReflection 96 @Direct 97 static long leftShiftLI(long left, int right) { 98 return left << right; 99 } 100 101 @CodeReflection 102 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 103 static int mod(int left, int right) { 104 return left % right; 105 } 106 107 @CodeReflection 108 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 109 static int mul(int left, int right) { 110 return left * right; 111 } 112 113 @CodeReflection 114 @SupportedTypes(TypeList.INTEGRAL_BOOLEAN) 115 static int or(int left, int right) { 116 return left | right; 117 } 118 119 @CodeReflection 120 @SupportedTypes(TypeList.INT_LONG) 121 static int signedRightShift(int left, int right) { 122 return left >> right; 123 } 124 125 @CodeReflection 126 @Direct 127 static int signedRightShiftIL(int left, long right) { 128 return left >> right; 129 } 130 131 @CodeReflection 132 @Direct 133 static long signedRightShiftLI(long left, int right) { 134 return left >> right; 135 } 136 137 @CodeReflection 138 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 139 static int sub(int left, int right) { 140 return left - right; 141 } 142 143 @CodeReflection 144 @SupportedTypes(TypeList.INT_LONG) 145 static int unsignedRightShift(int left, int right) { 146 return left >>> right; 147 } 148 149 @CodeReflection 150 @Direct 151 static int unsignedRightShiftIL(int left, long right) { 152 return left >>> right; 153 } 154 155 @CodeReflection 156 @Direct 157 static long unsignedRightShiftLI(long left, int right) { 158 return left >>> right; 159 } 160 161 @CodeReflection 162 @SupportedTypes(TypeList.INTEGRAL_BOOLEAN) 163 static int xor(int left, int right) { 164 return left ^ right; 165 } 166 167 @ParameterizedTest 168 @CodeReflectionExecutionSource 169 void test(CoreOp.FuncOp funcOp, Object left, Object right) { 170 Result interpret = runCatching(() -> interpret(left, right, funcOp)); 171 Result bytecode = runCatching(() -> bytecode(left, right, funcOp)); 172 assertResults(interpret, bytecode); 173 } 174 175 @Retention(RetentionPolicy.RUNTIME) 176 @Target(ElementType.METHOD) 177 @interface SupportedTypes { 178 TypeList value(); 179 } 180 181 enum TypeList { 182 INT_LONG(int.class, long.class), 183 INTEGRAL_BOOLEAN(int.class, long.class, byte.class, short.class, char.class, boolean.class), 184 INTEGRAL_FLOATING_POINT(int.class, long.class, byte.class, short.class, char.class, float.class, double.class); 185 186 private final Class<?>[] types; 187 188 TypeList(Class<?>... types) { 189 this.types = types; 190 } 191 192 public Class<?>[] types() { 193 return types; 194 } 195 } 196 197 // mark as "do not transform" 198 @Retention(RetentionPolicy.RUNTIME) 199 @Target(ElementType.METHOD) 200 @interface Direct { 201 } 202 203 @Retention(RetentionPolicy.RUNTIME) 204 @Target(ElementType.METHOD) 205 @ArgumentsSource(CodeReflectionSourceProvider.class) 206 @interface CodeReflectionExecutionSource { 207 } 208 209 static class CodeReflectionSourceProvider implements ArgumentsProvider { 210 private static final Map<JavaType, List<?>> INTERESTING_INPUTS = Map.of( 211 // explicit type parameters to ensure boxing results in the expected type 212 JavaType.INT, List.<Integer>of(Integer.MIN_VALUE, Integer.MAX_VALUE, 1, 0, -1), 213 JavaType.LONG, List.<Long>of(Long.MIN_VALUE, Long.MAX_VALUE, 1L, 0L, -1L), 214 JavaType.BYTE, List.<Byte>of(Byte.MIN_VALUE, Byte.MAX_VALUE, (byte) 1, (byte) 0, (byte) -1), 215 JavaType.SHORT, List.<Short>of(Short.MIN_VALUE, Short.MAX_VALUE, (short) 1, (short) 0, (short) -1), 216 JavaType.CHAR, List.<Character>of(Character.MIN_VALUE, Character.MAX_VALUE, (char) 1, (char) 0, (char) -1), 217 JavaType.DOUBLE, List.<Double>of(Double.MIN_VALUE, Double.MAX_VALUE, Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.MIN_NORMAL, 1d, 0d, -1d), 218 JavaType.FLOAT, List.<Float>of(Float.MIN_VALUE, Float.MAX_VALUE, Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY, Float.MIN_NORMAL, 1f, 0f, -1f), 219 JavaType.BOOLEAN, List.<Boolean>of(true, false) 220 ); 221 222 @Override 223 public Stream<? extends Arguments> provideArguments(ExtensionContext extensionContext) { 224 Method testMethod = extensionContext.getRequiredTestMethod(); 225 return codeReflectionMethods(extensionContext.getRequiredTestClass()) 226 .flatMap(method -> { 227 CoreOp.FuncOp funcOp = method.getCodeModel().orElseThrow( 228 () -> new IllegalStateException("Expected code model to be present for method " + method) 229 ); 230 SupportedTypes supportedTypes = method.getAnnotation(SupportedTypes.class); 231 if (method.isAnnotationPresent(Direct.class)) { 232 if (supportedTypes != null) { 233 throw new IllegalArgumentException("Direct should not be combined with SupportedTypes"); 234 } 235 return Stream.of(funcOp); 236 } 237 if (supportedTypes == null || supportedTypes.value().types().length == 0) { 238 throw new IllegalArgumentException("Missing supported types"); 239 } 240 return Arrays.stream(supportedTypes.value().types()) 241 .map(type -> retype(funcOp, type)); 242 }) 243 .flatMap(transformedFunc -> argumentsForMethod(transformedFunc, testMethod)); 244 } 245 246 private static <T> Stream<List<T>> cartesianProduct(List<List<? extends T>> source) { 247 if (source.isEmpty()) { 248 return Stream.of(new ArrayList<>()); 249 } 250 return source.getFirst().stream() 251 .flatMap(e -> cartesianProduct(source.subList(1, source.size())).map(l -> { 252 ArrayList<T> newList = new ArrayList<>(l); 253 newList.add(e); 254 return newList; 255 })); 256 } 257 258 private static CoreOp.FuncOp retype(CoreOp.FuncOp original, Class<?> newType) { 259 JavaType type = JavaType.type(newType); 260 FunctionType functionType = original.invokableType(); 261 if (functionType.parameterTypes().stream().allMatch(t -> t.equals(type))) { 262 return original; // already expected type 263 } 264 if (functionType.parameterTypes().stream().distinct().count() != 1) { 265 original.writeTo(System.err); 266 throw new IllegalArgumentException("Only FuncOps with exactly one distinct parameter type are supported"); 267 } 268 // if the return type does not match the input types, we keep it 269 TypeElement retType = functionType.returnType().equals(functionType.parameterTypes().getFirst()) 270 ? type 271 : functionType.returnType(); 272 return CoreOp.func(original.funcName(), FunctionType.functionType(retType, type, type)) 273 .body(builder -> builder.transformBody(original.body(), builder.parameters(), (block, op) -> { 274 block.context().mapValue(op.result(), block.op(retype(block.context(), op))); 275 return block; 276 }) 277 ); 278 } 279 280 private static Op retype(CopyContext context, Op op) { 281 return switch (op) { 282 case CoreOp.VarOp varOp -> 283 CoreOp.var(varOp.varName(), context.getValueOrDefault(varOp.operands().getFirst(), varOp.operands().getFirst())); 284 default -> op; 285 }; 286 } 287 288 private static Stream<Arguments> argumentsForMethod(CoreOp.FuncOp funcOp, Method testMethod) { 289 Parameter[] testMethodParameters = testMethod.getParameters(); 290 List<TypeElement> funcParameters = funcOp.invokableType().parameterTypes(); 291 if (testMethodParameters.length - 1 != funcParameters.size()) { 292 throw new IllegalArgumentException("method " + testMethod + " does not take the correct number of parameters"); 293 } 294 if (testMethodParameters[0].getType() != CoreOp.FuncOp.class) { 295 throw new IllegalArgumentException("method " + testMethod + " does not take a leading FuncOp argument"); 296 } 297 Named<CoreOp.FuncOp> opNamed = Named.of(funcOp.funcName() + "{" + funcOp.invokableType() + "}", funcOp); 298 MethodHandles.Lookup lookup = MethodHandles.lookup(); 299 for (int i = 1; i < testMethodParameters.length; i++) { 300 Class<?> resolved = resolveParameter(funcParameters.get(i - 1), lookup); 301 if (!isCompatible(resolved, testMethodParameters[i].getType())) { 302 System.out.println(testMethod + " does not accept inputs of type " + resolved + " at index " + i); 303 return Stream.empty(); 304 } 305 } 306 List<List<?>> allInputs = new ArrayList<>(); 307 for (TypeElement parameterType : funcParameters) { 308 allInputs.add(INTERESTING_INPUTS.get((JavaType) parameterType)); 309 } 310 return cartesianProduct(allInputs) 311 .map(objects -> { 312 objects.add(opNamed); 313 return objects.reversed().toArray(); // reverse so FuncOp is at the beginning 314 }) 315 .map(Arguments::of); 316 } 317 318 private static Class<?> resolveParameter(TypeElement typeElement, MethodHandles.Lookup lookup) { 319 try { 320 return (Class<?>)((JavaType) typeElement).erasure().resolve(lookup); 321 } catch (ReflectiveOperationException e) { 322 throw new RuntimeException(e); 323 } 324 } 325 326 // check whether elements of type sourceType can be passed to a parameter of parameterType 327 private static boolean isCompatible(Class<?> sourceType, Class<?> parameterType) { 328 return wrapped(parameterType).isAssignableFrom(wrapped(sourceType)); 329 } 330 331 private static Class<?> wrapped(Class<?> target) { 332 return MethodType.methodType(target).wrap().returnType(); 333 } 334 335 private static Stream<Method> codeReflectionMethods(Class<?> testClass) { 336 return Arrays.stream(testClass.getDeclaredMethods()) 337 .filter(method -> method.accessFlags().contains(AccessFlag.STATIC)) 338 .filter(method -> method.isAnnotationPresent(CodeReflection.class)); 339 } 340 341 } 342 343 private static Object interpret(Object left, Object right, CoreOp.FuncOp op) { 344 return Interpreter.invoke(MethodHandles.lookup(), op, left, right); 345 } 346 347 private static Object bytecode(Object left, Object right, CoreOp.FuncOp op) throws Throwable { 348 CoreOp.FuncOp func = SSA.transform(op.transform(OpTransformer.LOWERING_TRANSFORMER)); 349 MethodHandle handle = BytecodeGenerator.generate(MethodHandles.lookup(), func); 350 return handle.invoke(left, right); 351 } 352 353 private static void assertResults(Result first, Result second) { 354 System.out.println("first: " + first); 355 System.out.println("second: " + second); 356 // either the same error occurred on both or no error occurred 357 if (first.throwable != null || second.throwable != null) { 358 assertNotNull(first.throwable, () -> "only second threw an exception: " + second.throwable); 359 assertNotNull(second.throwable, () -> "only first threw an exception: " + first.throwable); 360 if (first.throwable.getClass() != second.throwable.getClass()) { 361 first.throwable.printStackTrace(); 362 second.throwable.printStackTrace(); 363 fail("Different exceptions were thrown"); 364 } 365 return; 366 } 367 // otherwise, both results should be non-null and equals 368 assertNotNull(first.onSuccess); 369 assertEquals(first.onSuccess, second.onSuccess); 370 } 371 372 private static <T> Result runCatching(ThrowingSupplier<T> supplier) { 373 Object value = null; 374 Throwable interpretThrowable = null; 375 try { 376 value = supplier.get(); 377 } catch (Throwable t) { 378 interpretThrowable = t; 379 } 380 return new Result(value, interpretThrowable); 381 } 382 383 record Result(Object onSuccess, Throwable throwable) { 384 } 385 386 }