1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @run junit CoreBinaryOpsTest
 27  */
 28 
 29 import org.junit.jupiter.api.Named;
 30 import org.junit.jupiter.api.extension.ExtensionContext;
 31 import org.junit.jupiter.api.function.ThrowingSupplier;
 32 import org.junit.jupiter.params.ParameterizedTest;
 33 import org.junit.jupiter.params.provider.Arguments;
 34 import org.junit.jupiter.params.provider.ArgumentsProvider;
 35 import org.junit.jupiter.params.provider.ArgumentsSource;
 36 
 37 import java.lang.annotation.ElementType;
 38 import java.lang.annotation.Retention;
 39 import java.lang.annotation.RetentionPolicy;
 40 import java.lang.annotation.Target;
 41 import java.lang.invoke.MethodHandle;
 42 import java.lang.invoke.MethodHandles;
 43 import java.lang.invoke.MethodType;
 44 import java.lang.reflect.AccessFlag;
 45 import java.lang.reflect.Method;
 46 import java.lang.reflect.Parameter;
 47 import java.lang.reflect.code.CopyContext;
 48 import java.lang.reflect.code.Op;
 49 import java.lang.reflect.code.OpTransformer;
 50 import java.lang.reflect.code.TypeElement;
 51 import java.lang.reflect.code.analysis.SSA;
 52 import java.lang.reflect.code.bytecode.BytecodeGenerator;
 53 import java.lang.reflect.code.interpreter.Interpreter;
 54 import java.lang.reflect.code.op.CoreOp;
 55 import java.lang.reflect.code.type.FunctionType;
 56 import java.lang.reflect.code.type.JavaType;
 57 import java.lang.runtime.CodeReflection;
 58 import java.util.*;
 59 import java.util.stream.Stream;
 60 
 61 import static org.junit.jupiter.api.Assertions.*;
 62 
 63 public class CoreBinaryOpsTest {
 64 
 65     @CodeReflection
 66     @SupportedTypes(TypeList.INTEGRAL_BOOLEAN)
 67     static int and(int left, int right) {
 68         return left & right;
 69     }
 70 
 71     @CodeReflection
 72     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
 73     static int add(int left, int right) {
 74         return left + right;
 75     }
 76 
 77     @CodeReflection
 78     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
 79     static int div(int left, int right) {
 80         return left / right;
 81     }
 82 
 83     @CodeReflection
 84     @SupportedTypes(TypeList.INT_LONG)
 85     static int leftShift(int left, int right) {
 86         return left << right;
 87     }
 88 
 89     @CodeReflection
 90     @Direct
 91     static int leftShiftIL(int left, long right) {
 92         return left << right;
 93     }
 94 
 95     @CodeReflection
 96     @Direct
 97     static long leftShiftLI(long left, int right) {
 98         return left << right;
 99     }
100 
101     @CodeReflection
102     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
103     static int mod(int left, int right) {
104         return left % right;
105     }
106 
107     @CodeReflection
108     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
109     static int mul(int left, int right) {
110         return left * right;
111     }
112 
113     @CodeReflection
114     @SupportedTypes(TypeList.INTEGRAL_BOOLEAN)
115     static int or(int left, int right) {
116         return left | right;
117     }
118 
119     @CodeReflection
120     @SupportedTypes(TypeList.INT_LONG)
121     static int signedRightShift(int left, int right) {
122         return left >> right;
123     }
124 
125     @CodeReflection
126     @Direct
127     static int signedRightShiftIL(int left, long right) {
128         return left >> right;
129     }
130 
131     @CodeReflection
132     @Direct
133     static long signedRightShiftLI(long left, int right) {
134         return left >> right;
135     }
136 
137     @CodeReflection
138     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
139     static int sub(int left, int right) {
140         return left - right;
141     }
142 
143     @CodeReflection
144     @SupportedTypes(TypeList.INT_LONG)
145     static int unsignedRightShift(int left, int right) {
146         return left >>> right;
147     }
148 
149     @CodeReflection
150     @Direct
151     static int unsignedRightShiftIL(int left, long right) {
152         return left >>> right;
153     }
154 
155     @CodeReflection
156     @Direct
157     static long unsignedRightShiftLI(long left, int right) {
158         return left >>> right;
159     }
160 
161     @CodeReflection
162     @SupportedTypes(TypeList.INTEGRAL_BOOLEAN)
163     static int xor(int left, int right) {
164         return left ^ right;
165     }
166 
167     @ParameterizedTest
168     @CodeReflectionExecutionSource
169     void test(CoreOp.FuncOp funcOp, Object left, Object right) {
170         Result interpret = runCatching(() -> interpret(left, right, funcOp));
171         Result bytecode = runCatching(() -> bytecode(left, right, funcOp));
172         assertResults(interpret, bytecode);
173     }
174 
175     @Retention(RetentionPolicy.RUNTIME)
176     @Target(ElementType.METHOD)
177     @interface SupportedTypes {
178         TypeList value();
179     }
180 
181     enum TypeList {
182         INT_LONG(int.class, long.class),
183         INTEGRAL_BOOLEAN(int.class, long.class, byte.class, short.class, char.class, boolean.class),
184         INTEGRAL_FLOATING_POINT(int.class, long.class, byte.class, short.class, char.class, float.class, double.class);
185 
186         private final Class<?>[] types;
187 
188         TypeList(Class<?>... types) {
189             this.types = types;
190         }
191 
192         public Class<?>[] types() {
193             return types;
194         }
195     }
196 
197     // mark as "do not transform"
198     @Retention(RetentionPolicy.RUNTIME)
199     @Target(ElementType.METHOD)
200     @interface Direct {
201     }
202 
203     @Retention(RetentionPolicy.RUNTIME)
204     @Target(ElementType.METHOD)
205     @ArgumentsSource(CodeReflectionSourceProvider.class)
206     @interface CodeReflectionExecutionSource {
207     }
208 
209     static class CodeReflectionSourceProvider implements ArgumentsProvider {
210         private static final Map<JavaType, List<?>> INTERESTING_INPUTS = Map.of(
211                 // explicit type parameters to ensure boxing results in the expected type
212                 JavaType.INT, List.<Integer>of(Integer.MIN_VALUE, Integer.MAX_VALUE, 1, 0, -1),
213                 JavaType.LONG, List.<Long>of(Long.MIN_VALUE, Long.MAX_VALUE, 1L, 0L, -1L),
214                 JavaType.BYTE, List.<Byte>of(Byte.MIN_VALUE, Byte.MAX_VALUE, (byte) 1, (byte) 0, (byte) -1),
215                 JavaType.SHORT, List.<Short>of(Short.MIN_VALUE, Short.MAX_VALUE, (short) 1, (short) 0, (short) -1),
216                 JavaType.CHAR, List.<Character>of(Character.MIN_VALUE, Character.MAX_VALUE, (char) 1, (char) 0, (char) -1),
217                 JavaType.DOUBLE, List.<Double>of(Double.MIN_VALUE, Double.MAX_VALUE, Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.MIN_NORMAL, 1d, 0d, -1d),
218                 JavaType.FLOAT, List.<Float>of(Float.MIN_VALUE, Float.MAX_VALUE, Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY, Float.MIN_NORMAL, 1f, 0f, -1f),
219                 JavaType.BOOLEAN, List.<Boolean>of(true, false)
220         );
221 
222         @Override
223         public Stream<? extends Arguments> provideArguments(ExtensionContext extensionContext) {
224             Method testMethod = extensionContext.getRequiredTestMethod();
225             return codeReflectionMethods(extensionContext.getRequiredTestClass())
226                     .flatMap(method -> {
227                         CoreOp.FuncOp funcOp = method.getCodeModel().orElseThrow(
228                                 () -> new IllegalStateException("Expected code model to be present for method " + method)
229                         );
230                         SupportedTypes supportedTypes = method.getAnnotation(SupportedTypes.class);
231                         if (method.isAnnotationPresent(Direct.class)) {
232                             if (supportedTypes != null) {
233                                 throw new IllegalArgumentException("Direct should not be combined with SupportedTypes");
234                             }
235                             return Stream.of(funcOp);
236                         }
237                         if (supportedTypes == null || supportedTypes.value().types().length == 0) {
238                             throw new IllegalArgumentException("Missing supported types");
239                         }
240                         return Arrays.stream(supportedTypes.value().types())
241                                 .map(type -> retype(funcOp, type));
242                     })
243                     .flatMap(transformedFunc -> argumentsForMethod(transformedFunc, testMethod));
244         }
245 
246         private static <T> Stream<List<T>> cartesianProduct(List<List<? extends T>> source) {
247             if (source.isEmpty()) {
248                 return Stream.of(new ArrayList<>());
249             }
250             return source.getFirst().stream()
251                     .flatMap(e -> cartesianProduct(source.subList(1, source.size())).map(l -> {
252                         ArrayList<T> newList = new ArrayList<>(l);
253                         newList.add(e);
254                         return newList;
255                     }));
256         }
257 
258         private static CoreOp.FuncOp retype(CoreOp.FuncOp original, Class<?> newType) {
259             JavaType type = JavaType.type(newType);
260             FunctionType functionType = original.invokableType();
261             if (functionType.parameterTypes().stream().allMatch(t -> t.equals(type))) {
262                 return original; // already expected type
263             }
264             if (functionType.parameterTypes().stream().distinct().count() != 1) {
265                 original.writeTo(System.err);
266                 throw new IllegalArgumentException("Only FuncOps with exactly one distinct parameter type are supported");
267             }
268             // if the return type does not match the input types, we keep it
269             TypeElement retType = functionType.returnType().equals(functionType.parameterTypes().getFirst())
270                     ? type
271                     : functionType.returnType();
272             return CoreOp.func(original.funcName(), FunctionType.functionType(retType, type, type))
273                     .body(builder -> builder.transformBody(original.body(), builder.parameters(), (block, op) -> {
274                                 block.context().mapValue(op.result(), block.op(retype(block.context(), op)));
275                                 return block;
276                             })
277                     );
278         }
279 
280         private static Op retype(CopyContext context, Op op) {
281             return switch (op) {
282                 case CoreOp.VarOp varOp ->
283                         CoreOp.var(varOp.varName(), context.getValueOrDefault(varOp.operands().getFirst(), varOp.operands().getFirst()));
284                 default -> op;
285             };
286         }
287 
288         private static Stream<Arguments> argumentsForMethod(CoreOp.FuncOp funcOp, Method testMethod) {
289             Parameter[] testMethodParameters = testMethod.getParameters();
290             List<TypeElement> funcParameters = funcOp.invokableType().parameterTypes();
291             if (testMethodParameters.length - 1 != funcParameters.size()) {
292                 throw new IllegalArgumentException("method " + testMethod + " does not take the correct number of parameters");
293             }
294             if (testMethodParameters[0].getType() != CoreOp.FuncOp.class) {
295                 throw new IllegalArgumentException("method " + testMethod + " does not take a leading FuncOp argument");
296             }
297             Named<CoreOp.FuncOp> opNamed = Named.of(funcOp.funcName() + "{" + funcOp.invokableType() + "}", funcOp);
298             MethodHandles.Lookup lookup = MethodHandles.lookup();
299             for (int i = 1; i < testMethodParameters.length; i++) {
300                 Class<?> resolved = resolveParameter(funcParameters.get(i - 1), lookup);
301                 if (!isCompatible(resolved, testMethodParameters[i].getType())) {
302                     System.out.println(testMethod + " does not accept inputs of type " + resolved + " at index " + i);
303                     return Stream.empty();
304                 }
305             }
306             List<List<?>> allInputs = new ArrayList<>();
307             for (TypeElement parameterType : funcParameters) {
308                 allInputs.add(INTERESTING_INPUTS.get((JavaType) parameterType));
309             }
310             return cartesianProduct(allInputs)
311                     .map(objects -> {
312                         objects.add(opNamed);
313                         return objects.reversed().toArray(); // reverse so FuncOp is at the beginning
314                     })
315                     .map(Arguments::of);
316         }
317 
318         private static Class<?> resolveParameter(TypeElement typeElement, MethodHandles.Lookup lookup) {
319             try {
320                 return (Class<?>)((JavaType) typeElement).erasure().resolve(lookup);
321             } catch (ReflectiveOperationException e) {
322                 throw new RuntimeException(e);
323             }
324         }
325 
326         // check whether elements of type sourceType can be passed to a parameter of parameterType
327         private static boolean isCompatible(Class<?> sourceType, Class<?> parameterType) {
328             return wrapped(parameterType).isAssignableFrom(wrapped(sourceType));
329         }
330 
331         private static Class<?> wrapped(Class<?> target) {
332             return MethodType.methodType(target).wrap().returnType();
333         }
334 
335         private static Stream<Method> codeReflectionMethods(Class<?> testClass) {
336             return Arrays.stream(testClass.getDeclaredMethods())
337                     .filter(method -> method.accessFlags().contains(AccessFlag.STATIC))
338                     .filter(method -> method.isAnnotationPresent(CodeReflection.class));
339         }
340 
341     }
342 
343     private static Object interpret(Object left, Object right, CoreOp.FuncOp op) {
344         return Interpreter.invoke(MethodHandles.lookup(), op, left, right);
345     }
346 
347     private static Object bytecode(Object left, Object right, CoreOp.FuncOp op) throws Throwable {
348         CoreOp.FuncOp func = SSA.transform(op.transform(OpTransformer.LOWERING_TRANSFORMER));
349         MethodHandle handle = BytecodeGenerator.generate(MethodHandles.lookup(), func);
350         return handle.invoke(left, right);
351     }
352 
353     private static void assertResults(Result first, Result second) {
354         System.out.println("first: " + first);
355         System.out.println("second: " + second);
356         // either the same error occurred on both or no error occurred
357         if (first.throwable != null || second.throwable != null) {
358             assertNotNull(first.throwable, () -> "only second threw an exception: " + second.throwable);
359             assertNotNull(second.throwable, () -> "only first threw an exception: " + first.throwable);
360             if (first.throwable.getClass() != second.throwable.getClass()) {
361                 first.throwable.printStackTrace();
362                 second.throwable.printStackTrace();
363                 fail("Different exceptions were thrown");
364             }
365             return;
366         }
367         // otherwise, both results should be non-null and equals
368         assertNotNull(first.onSuccess);
369         assertEquals(first.onSuccess, second.onSuccess);
370     }
371 
372     private static <T> Result runCatching(ThrowingSupplier<T> supplier) {
373         Object value = null;
374         Throwable interpretThrowable = null;
375         try {
376             value = supplier.get();
377         } catch (Throwable t) {
378             interpretThrowable = t;
379         }
380         return new Result(value, interpretThrowable);
381     }
382 
383     record Result(Object onSuccess, Throwable throwable) {
384     }
385 
386 }