1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @modules jdk.incubator.code 27 * @run junit CoreBinaryOpsTest 28 * @run junit/othervm -Dbabylon.ssa=cytron CoreBinaryOpsTest 29 */ 30 31 import jdk.incubator.code.CodeReflection; 32 import jdk.incubator.code.Op; 33 import jdk.incubator.code.OpTransformer; 34 import jdk.incubator.code.TypeElement; 35 import jdk.incubator.code.analysis.SSA; 36 import jdk.incubator.code.bytecode.BytecodeGenerator; 37 import jdk.incubator.code.dialect.core.CoreOp; 38 import jdk.incubator.code.dialect.core.CoreType; 39 import jdk.incubator.code.dialect.core.FunctionType; 40 import jdk.incubator.code.dialect.java.JavaType; 41 import jdk.incubator.code.interpreter.Interpreter; 42 import org.junit.jupiter.api.Named; 43 import org.junit.jupiter.api.extension.ExtensionContext; 44 import org.junit.jupiter.api.function.ThrowingSupplier; 45 import org.junit.jupiter.params.ParameterizedTest; 46 import org.junit.jupiter.params.provider.Arguments; 47 import org.junit.jupiter.params.provider.ArgumentsProvider; 48 import org.junit.jupiter.params.provider.ArgumentsSource; 49 50 import java.lang.annotation.ElementType; 51 import java.lang.annotation.Retention; 52 import java.lang.annotation.RetentionPolicy; 53 import java.lang.annotation.Target; 54 import java.lang.invoke.MethodHandle; 55 import java.lang.invoke.MethodHandles; 56 import java.lang.invoke.MethodType; 57 import java.lang.reflect.AccessFlag; 58 import java.lang.reflect.Method; 59 import java.lang.reflect.Parameter; 60 import java.util.ArrayList; 61 import java.util.Arrays; 62 import java.util.List; 63 import java.util.Map; 64 import java.util.stream.Stream; 65 66 import static org.junit.jupiter.api.Assertions.*; 67 68 public class CoreBinaryOpsTest { 69 70 @CodeReflection 71 @SupportedTypes(TypeList.INTEGRAL_BOOLEAN) 72 static int and(int left, int right) { 73 return left & right; 74 } 75 76 @CodeReflection 77 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 78 static int add(int left, int right) { 79 return left + right; 80 } 81 82 @CodeReflection 83 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 84 static int div(int left, int right) { 85 return left / right; 86 } 87 88 @CodeReflection 89 @SupportedTypes(TypeList.INT_LONG) 90 static int leftShift(int left, int right) { 91 return left << right; 92 } 93 94 @CodeReflection 95 @Direct 96 static int leftShiftIL(int left, long right) { 97 return left << right; 98 } 99 100 @CodeReflection 101 @Direct 102 static long leftShiftLI(long left, int right) { 103 return left << right; 104 } 105 106 @CodeReflection 107 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 108 static int mod(int left, int right) { 109 return left % right; 110 } 111 112 @CodeReflection 113 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 114 static int mul(int left, int right) { 115 return left * right; 116 } 117 118 @CodeReflection 119 @SupportedTypes(TypeList.INTEGRAL_BOOLEAN) 120 static int or(int left, int right) { 121 return left | right; 122 } 123 124 @CodeReflection 125 @SupportedTypes(TypeList.INT_LONG) 126 static int signedRightShift(int left, int right) { 127 return left >> right; 128 } 129 130 @CodeReflection 131 @Direct 132 static int signedRightShiftIL(int left, long right) { 133 return left >> right; 134 } 135 136 @CodeReflection 137 @Direct 138 static long signedRightShiftLI(long left, int right) { 139 return left >> right; 140 } 141 142 @CodeReflection 143 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT) 144 static int sub(int left, int right) { 145 return left - right; 146 } 147 148 @CodeReflection 149 @SupportedTypes(TypeList.INT_LONG) 150 static int unsignedRightShift(int left, int right) { 151 return left >>> right; 152 } 153 154 @CodeReflection 155 @Direct 156 static int unsignedRightShiftIL(int left, long right) { 157 return left >>> right; 158 } 159 160 @CodeReflection 161 @Direct 162 static long unsignedRightShiftLI(long left, int right) { 163 return left >>> right; 164 } 165 166 @CodeReflection 167 @SupportedTypes(TypeList.INTEGRAL_BOOLEAN) 168 static int xor(int left, int right) { 169 return left ^ right; 170 } 171 172 @ParameterizedTest 173 @CodeReflectionExecutionSource 174 void test(CoreOp.FuncOp funcOp, Object left, Object right) { 175 Result interpret = runCatching(() -> interpret(left, right, funcOp)); 176 Result bytecode = runCatching(() -> bytecode(left, right, funcOp)); 177 assertResults(interpret, bytecode); 178 } 179 180 @Retention(RetentionPolicy.RUNTIME) 181 @Target(ElementType.METHOD) 182 @interface SupportedTypes { 183 TypeList value(); 184 } 185 186 enum TypeList { 187 INT_LONG(int.class, long.class), 188 INTEGRAL_BOOLEAN(int.class, long.class, byte.class, short.class, char.class, boolean.class), 189 INTEGRAL_FLOATING_POINT(int.class, long.class, byte.class, short.class, char.class, float.class, double.class); 190 191 private final Class<?>[] types; 192 193 TypeList(Class<?>... types) { 194 this.types = types; 195 } 196 197 public Class<?>[] types() { 198 return types; 199 } 200 } 201 202 // mark as "do not transform" 203 @Retention(RetentionPolicy.RUNTIME) 204 @Target(ElementType.METHOD) 205 @interface Direct { 206 } 207 208 @Retention(RetentionPolicy.RUNTIME) 209 @Target(ElementType.METHOD) 210 @ArgumentsSource(CodeReflectionSourceProvider.class) 211 @interface CodeReflectionExecutionSource { 212 } 213 214 static class CodeReflectionSourceProvider implements ArgumentsProvider { 215 private static final Map<JavaType, List<?>> INTERESTING_INPUTS = Map.of( 216 // explicit type parameters to ensure boxing results in the expected type 217 JavaType.INT, List.<Integer>of(Integer.MIN_VALUE, Integer.MAX_VALUE, 1, 0, -1), 218 JavaType.LONG, List.<Long>of(Long.MIN_VALUE, Long.MAX_VALUE, 1L, 0L, -1L), 219 JavaType.BYTE, List.<Byte>of(Byte.MIN_VALUE, Byte.MAX_VALUE, (byte) 1, (byte) 0, (byte) -1), 220 JavaType.SHORT, List.<Short>of(Short.MIN_VALUE, Short.MAX_VALUE, (short) 1, (short) 0, (short) -1), 221 JavaType.CHAR, List.<Character>of(Character.MIN_VALUE, Character.MAX_VALUE, (char) 1, (char) 0, (char) -1), 222 JavaType.DOUBLE, List.<Double>of(Double.MIN_VALUE, Double.MAX_VALUE, Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.MIN_NORMAL, 1d, 0d, -1d), 223 JavaType.FLOAT, List.<Float>of(Float.MIN_VALUE, Float.MAX_VALUE, Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY, Float.MIN_NORMAL, 1f, 0f, -1f), 224 JavaType.BOOLEAN, List.<Boolean>of(true, false) 225 ); 226 227 @Override 228 public Stream<? extends Arguments> provideArguments(ExtensionContext extensionContext) { 229 Method testMethod = extensionContext.getRequiredTestMethod(); 230 return codeReflectionMethods(extensionContext.getRequiredTestClass()) 231 .flatMap(method -> { 232 CoreOp.FuncOp funcOp = Op.ofMethod(method).orElseThrow( 233 () -> new IllegalStateException("Expected code model to be present for method " + method) 234 ); 235 SupportedTypes supportedTypes = method.getAnnotation(SupportedTypes.class); 236 if (method.isAnnotationPresent(Direct.class)) { 237 if (supportedTypes != null) { 238 throw new IllegalArgumentException("Direct should not be combined with SupportedTypes"); 239 } 240 return Stream.of(funcOp); 241 } 242 if (supportedTypes == null || supportedTypes.value().types().length == 0) { 243 throw new IllegalArgumentException("Missing supported types"); 244 } 245 return Arrays.stream(supportedTypes.value().types()) 246 .map(type -> retype(funcOp, type)); 247 }) 248 .flatMap(transformedFunc -> argumentsForMethod(transformedFunc, testMethod)); 249 } 250 251 private static <T> Stream<List<T>> cartesianProduct(List<List<? extends T>> source) { 252 if (source.isEmpty()) { 253 return Stream.of(new ArrayList<>()); 254 } 255 return source.getFirst().stream() 256 .flatMap(e -> cartesianProduct(source.subList(1, source.size())).map(l -> { 257 ArrayList<T> newList = new ArrayList<>(l); 258 newList.add(e); 259 return newList; 260 })); 261 } 262 263 private static CoreOp.FuncOp retype(CoreOp.FuncOp original, Class<?> newType) { 264 JavaType type = JavaType.type(newType); 265 FunctionType functionType = original.invokableType(); 266 if (functionType.parameterTypes().stream().allMatch(t -> t.equals(type))) { 267 return original; // already expected type 268 } 269 if (functionType.parameterTypes().stream().distinct().count() != 1) { 270 System.err.println(original.toText()); 271 throw new IllegalArgumentException("Only FuncOps with exactly one distinct parameter type are supported"); 272 } 273 // if the return type does not match the input types, we keep it 274 TypeElement retType = functionType.returnType().equals(functionType.parameterTypes().getFirst()) 275 ? type 276 : functionType.returnType(); 277 return CoreOp.func(original.funcName(), CoreType.functionType(retType, type, type)) 278 .body(builder -> builder.body(original.body(), builder.parameters(), OpTransformer.COPYING_TRANSFORMER) 279 ); 280 } 281 282 private static Stream<Arguments> argumentsForMethod(CoreOp.FuncOp funcOp, Method testMethod) { 283 Parameter[] testMethodParameters = testMethod.getParameters(); 284 List<TypeElement> funcParameters = funcOp.invokableType().parameterTypes(); 285 if (testMethodParameters.length - 1 != funcParameters.size()) { 286 throw new IllegalArgumentException("method " + testMethod + " does not take the correct number of parameters"); 287 } 288 if (testMethodParameters[0].getType() != CoreOp.FuncOp.class) { 289 throw new IllegalArgumentException("method " + testMethod + " does not take a leading FuncOp argument"); 290 } 291 Named<CoreOp.FuncOp> opNamed = Named.of(funcOp.funcName() + "{" + funcOp.invokableType() + "}", funcOp); 292 MethodHandles.Lookup lookup = MethodHandles.lookup(); 293 for (int i = 1; i < testMethodParameters.length; i++) { 294 Class<?> resolved = resolveParameter(funcParameters.get(i - 1), lookup); 295 if (!isCompatible(resolved, testMethodParameters[i].getType())) { 296 System.out.println(testMethod + " does not accept inputs of type " + resolved + " at index " + i); 297 return Stream.empty(); 298 } 299 } 300 List<List<?>> allInputs = new ArrayList<>(); 301 for (TypeElement parameterType : funcParameters) { 302 allInputs.add(INTERESTING_INPUTS.get((JavaType) parameterType)); 303 } 304 return cartesianProduct(allInputs) 305 .map(objects -> { 306 objects.add(opNamed); 307 return objects.reversed().toArray(); // reverse so FuncOp is at the beginning 308 }) 309 .map(Arguments::of); 310 } 311 312 private static Class<?> resolveParameter(TypeElement typeElement, MethodHandles.Lookup lookup) { 313 try { 314 return (Class<?>)((JavaType) typeElement).erasure().resolve(lookup); 315 } catch (ReflectiveOperationException e) { 316 throw new RuntimeException(e); 317 } 318 } 319 320 // check whether elements of type sourceType can be passed to a parameter of parameterType 321 private static boolean isCompatible(Class<?> sourceType, Class<?> parameterType) { 322 return wrapped(parameterType).isAssignableFrom(wrapped(sourceType)); 323 } 324 325 private static Class<?> wrapped(Class<?> target) { 326 return MethodType.methodType(target).wrap().returnType(); 327 } 328 329 private static Stream<Method> codeReflectionMethods(Class<?> testClass) { 330 return Arrays.stream(testClass.getDeclaredMethods()) 331 .filter(method -> method.accessFlags().contains(AccessFlag.STATIC)) 332 .filter(method -> method.isAnnotationPresent(CodeReflection.class)); 333 } 334 335 } 336 337 private static Object interpret(Object left, Object right, CoreOp.FuncOp op) { 338 return Interpreter.invoke(MethodHandles.lookup(), op, left, right); 339 } 340 341 private static Object bytecode(Object left, Object right, CoreOp.FuncOp op) throws Throwable { 342 CoreOp.FuncOp func = SSA.transform(op.transform(OpTransformer.LOWERING_TRANSFORMER)); 343 MethodHandle handle = BytecodeGenerator.generate(MethodHandles.lookup(), func); 344 return handle.invoke(left, right); 345 } 346 347 private static void assertResults(Result first, Result second) { 348 System.out.println("first: " + first); 349 System.out.println("second: " + second); 350 // either the same error occurred on both or no error occurred 351 if (first.throwable != null || second.throwable != null) { 352 assertNotNull(first.throwable, () -> "only second threw an exception: " + second.throwable); 353 assertNotNull(second.throwable, () -> "only first threw an exception: " + first.throwable); 354 if (first.throwable.getClass() != second.throwable.getClass()) { 355 first.throwable.printStackTrace(); 356 second.throwable.printStackTrace(); 357 fail("Different exceptions were thrown"); 358 } 359 return; 360 } 361 // otherwise, both results should be non-null and equals 362 assertNotNull(first.onSuccess); 363 assertEquals(first.onSuccess, second.onSuccess); 364 } 365 366 private static <T> Result runCatching(ThrowingSupplier<T> supplier) { 367 Object value = null; 368 Throwable interpretThrowable = null; 369 try { 370 value = supplier.get(); 371 } catch (Throwable t) { 372 interpretThrowable = t; 373 } 374 return new Result(value, interpretThrowable); 375 } 376 377 record Result(Object onSuccess, Throwable throwable) { 378 } 379 380 }