1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @modules jdk.incubator.code
 27  * @run junit CoreBinaryOpsTest
 28  * @run junit/othervm -Dbabylon.ssa=cytron CoreBinaryOpsTest
 29  */
 30 
 31 import jdk.incubator.code.CodeReflection;
 32 import jdk.incubator.code.Op;
 33 import jdk.incubator.code.OpTransformer;
 34 import jdk.incubator.code.TypeElement;
 35 import jdk.incubator.code.analysis.SSA;
 36 import jdk.incubator.code.bytecode.BytecodeGenerator;
 37 import jdk.incubator.code.dialect.core.CoreOp;
 38 import jdk.incubator.code.dialect.core.CoreType;
 39 import jdk.incubator.code.dialect.core.FunctionType;
 40 import jdk.incubator.code.dialect.java.JavaType;
 41 import jdk.incubator.code.interpreter.Interpreter;
 42 import org.junit.jupiter.api.Named;
 43 import org.junit.jupiter.api.extension.ExtensionContext;
 44 import org.junit.jupiter.api.function.ThrowingSupplier;
 45 import org.junit.jupiter.params.ParameterizedTest;
 46 import org.junit.jupiter.params.provider.Arguments;
 47 import org.junit.jupiter.params.provider.ArgumentsProvider;
 48 import org.junit.jupiter.params.provider.ArgumentsSource;
 49 
 50 import java.lang.annotation.ElementType;
 51 import java.lang.annotation.Retention;
 52 import java.lang.annotation.RetentionPolicy;
 53 import java.lang.annotation.Target;
 54 import java.lang.invoke.MethodHandle;
 55 import java.lang.invoke.MethodHandles;
 56 import java.lang.invoke.MethodType;
 57 import java.lang.reflect.AccessFlag;
 58 import java.lang.reflect.Method;
 59 import java.lang.reflect.Parameter;
 60 import java.util.ArrayList;
 61 import java.util.Arrays;
 62 import java.util.List;
 63 import java.util.Map;
 64 import java.util.stream.Stream;
 65 
 66 import static org.junit.jupiter.api.Assertions.*;
 67 
 68 public class CoreBinaryOpsTest {
 69 
 70     @CodeReflection
 71     @SupportedTypes(TypeList.INTEGRAL_BOOLEAN)
 72     static int and(int left, int right) {
 73         return left & right;
 74     }
 75 
 76     @CodeReflection
 77     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
 78     static int add(int left, int right) {
 79         return left + right;
 80     }
 81 
 82     @CodeReflection
 83     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
 84     static int div(int left, int right) {
 85         return left / right;
 86     }
 87 
 88     @CodeReflection
 89     @SupportedTypes(TypeList.INT_LONG)
 90     static int leftShift(int left, int right) {
 91         return left << right;
 92     }
 93 
 94     @CodeReflection
 95     @Direct
 96     static int leftShiftIL(int left, long right) {
 97         return left << right;
 98     }
 99 
100     @CodeReflection
101     @Direct
102     static long leftShiftLI(long left, int right) {
103         return left << right;
104     }
105 
106     @CodeReflection
107     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
108     static int mod(int left, int right) {
109         return left % right;
110     }
111 
112     @CodeReflection
113     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
114     static int mul(int left, int right) {
115         return left * right;
116     }
117 
118     @CodeReflection
119     @SupportedTypes(TypeList.INTEGRAL_BOOLEAN)
120     static int or(int left, int right) {
121         return left | right;
122     }
123 
124     @CodeReflection
125     @SupportedTypes(TypeList.INT_LONG)
126     static int signedRightShift(int left, int right) {
127         return left >> right;
128     }
129 
130     @CodeReflection
131     @Direct
132     static int signedRightShiftIL(int left, long right) {
133         return left >> right;
134     }
135 
136     @CodeReflection
137     @Direct
138     static long signedRightShiftLI(long left, int right) {
139         return left >> right;
140     }
141 
142     @CodeReflection
143     @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
144     static int sub(int left, int right) {
145         return left - right;
146     }
147 
148     @CodeReflection
149     @SupportedTypes(TypeList.INT_LONG)
150     static int unsignedRightShift(int left, int right) {
151         return left >>> right;
152     }
153 
154     @CodeReflection
155     @Direct
156     static int unsignedRightShiftIL(int left, long right) {
157         return left >>> right;
158     }
159 
160     @CodeReflection
161     @Direct
162     static long unsignedRightShiftLI(long left, int right) {
163         return left >>> right;
164     }
165 
166     @CodeReflection
167     @SupportedTypes(TypeList.INTEGRAL_BOOLEAN)
168     static int xor(int left, int right) {
169         return left ^ right;
170     }
171 
172     @ParameterizedTest
173     @CodeReflectionExecutionSource
174     void test(CoreOp.FuncOp funcOp, Object left, Object right) {
175         Result interpret = runCatching(() -> interpret(left, right, funcOp));
176         Result bytecode = runCatching(() -> bytecode(left, right, funcOp));
177         assertResults(interpret, bytecode);
178     }
179 
180     @Retention(RetentionPolicy.RUNTIME)
181     @Target(ElementType.METHOD)
182     @interface SupportedTypes {
183         TypeList value();
184     }
185 
186     enum TypeList {
187         INT_LONG(int.class, long.class),
188         INTEGRAL_BOOLEAN(int.class, long.class, byte.class, short.class, char.class, boolean.class),
189         INTEGRAL_FLOATING_POINT(int.class, long.class, byte.class, short.class, char.class, float.class, double.class);
190 
191         private final Class<?>[] types;
192 
193         TypeList(Class<?>... types) {
194             this.types = types;
195         }
196 
197         public Class<?>[] types() {
198             return types;
199         }
200     }
201 
202     // mark as "do not transform"
203     @Retention(RetentionPolicy.RUNTIME)
204     @Target(ElementType.METHOD)
205     @interface Direct {
206     }
207 
208     @Retention(RetentionPolicy.RUNTIME)
209     @Target(ElementType.METHOD)
210     @ArgumentsSource(CodeReflectionSourceProvider.class)
211     @interface CodeReflectionExecutionSource {
212     }
213 
214     static class CodeReflectionSourceProvider implements ArgumentsProvider {
215         private static final Map<JavaType, List<?>> INTERESTING_INPUTS = Map.of(
216                 // explicit type parameters to ensure boxing results in the expected type
217                 JavaType.INT, List.<Integer>of(Integer.MIN_VALUE, Integer.MAX_VALUE, 1, 0, -1),
218                 JavaType.LONG, List.<Long>of(Long.MIN_VALUE, Long.MAX_VALUE, 1L, 0L, -1L),
219                 JavaType.BYTE, List.<Byte>of(Byte.MIN_VALUE, Byte.MAX_VALUE, (byte) 1, (byte) 0, (byte) -1),
220                 JavaType.SHORT, List.<Short>of(Short.MIN_VALUE, Short.MAX_VALUE, (short) 1, (short) 0, (short) -1),
221                 JavaType.CHAR, List.<Character>of(Character.MIN_VALUE, Character.MAX_VALUE, (char) 1, (char) 0, (char) -1),
222                 JavaType.DOUBLE, List.<Double>of(Double.MIN_VALUE, Double.MAX_VALUE, Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.MIN_NORMAL, 1d, 0d, -1d),
223                 JavaType.FLOAT, List.<Float>of(Float.MIN_VALUE, Float.MAX_VALUE, Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY, Float.MIN_NORMAL, 1f, 0f, -1f),
224                 JavaType.BOOLEAN, List.<Boolean>of(true, false)
225         );
226 
227         @Override
228         public Stream<? extends Arguments> provideArguments(ExtensionContext extensionContext) {
229             Method testMethod = extensionContext.getRequiredTestMethod();
230             return codeReflectionMethods(extensionContext.getRequiredTestClass())
231                     .flatMap(method -> {
232                         CoreOp.FuncOp funcOp = Op.ofMethod(method).orElseThrow(
233                                 () -> new IllegalStateException("Expected code model to be present for method " + method)
234                         );
235                         SupportedTypes supportedTypes = method.getAnnotation(SupportedTypes.class);
236                         if (method.isAnnotationPresent(Direct.class)) {
237                             if (supportedTypes != null) {
238                                 throw new IllegalArgumentException("Direct should not be combined with SupportedTypes");
239                             }
240                             return Stream.of(funcOp);
241                         }
242                         if (supportedTypes == null || supportedTypes.value().types().length == 0) {
243                             throw new IllegalArgumentException("Missing supported types");
244                         }
245                         return Arrays.stream(supportedTypes.value().types())
246                                 .map(type -> retype(funcOp, type));
247                     })
248                     .flatMap(transformedFunc -> argumentsForMethod(transformedFunc, testMethod));
249         }
250 
251         private static <T> Stream<List<T>> cartesianProduct(List<List<? extends T>> source) {
252             if (source.isEmpty()) {
253                 return Stream.of(new ArrayList<>());
254             }
255             return source.getFirst().stream()
256                     .flatMap(e -> cartesianProduct(source.subList(1, source.size())).map(l -> {
257                         ArrayList<T> newList = new ArrayList<>(l);
258                         newList.add(e);
259                         return newList;
260                     }));
261         }
262 
263         private static CoreOp.FuncOp retype(CoreOp.FuncOp original, Class<?> newType) {
264             JavaType type = JavaType.type(newType);
265             FunctionType functionType = original.invokableType();
266             if (functionType.parameterTypes().stream().allMatch(t -> t.equals(type))) {
267                 return original; // already expected type
268             }
269             if (functionType.parameterTypes().stream().distinct().count() != 1) {
270                 System.err.println(original.toText());
271                 throw new IllegalArgumentException("Only FuncOps with exactly one distinct parameter type are supported");
272             }
273             // if the return type does not match the input types, we keep it
274             TypeElement retType = functionType.returnType().equals(functionType.parameterTypes().getFirst())
275                     ? type
276                     : functionType.returnType();
277             return CoreOp.func(original.funcName(), CoreType.functionType(retType, type, type))
278                     .body(builder -> builder.body(original.body(), builder.parameters(), OpTransformer.COPYING_TRANSFORMER)
279                     );
280         }
281 
282         private static Stream<Arguments> argumentsForMethod(CoreOp.FuncOp funcOp, Method testMethod) {
283             Parameter[] testMethodParameters = testMethod.getParameters();
284             List<TypeElement> funcParameters = funcOp.invokableType().parameterTypes();
285             if (testMethodParameters.length - 1 != funcParameters.size()) {
286                 throw new IllegalArgumentException("method " + testMethod + " does not take the correct number of parameters");
287             }
288             if (testMethodParameters[0].getType() != CoreOp.FuncOp.class) {
289                 throw new IllegalArgumentException("method " + testMethod + " does not take a leading FuncOp argument");
290             }
291             Named<CoreOp.FuncOp> opNamed = Named.of(funcOp.funcName() + "{" + funcOp.invokableType() + "}", funcOp);
292             MethodHandles.Lookup lookup = MethodHandles.lookup();
293             for (int i = 1; i < testMethodParameters.length; i++) {
294                 Class<?> resolved = resolveParameter(funcParameters.get(i - 1), lookup);
295                 if (!isCompatible(resolved, testMethodParameters[i].getType())) {
296                     System.out.println(testMethod + " does not accept inputs of type " + resolved + " at index " + i);
297                     return Stream.empty();
298                 }
299             }
300             List<List<?>> allInputs = new ArrayList<>();
301             for (TypeElement parameterType : funcParameters) {
302                 allInputs.add(INTERESTING_INPUTS.get((JavaType) parameterType));
303             }
304             return cartesianProduct(allInputs)
305                     .map(objects -> {
306                         objects.add(opNamed);
307                         return objects.reversed().toArray(); // reverse so FuncOp is at the beginning
308                     })
309                     .map(Arguments::of);
310         }
311 
312         private static Class<?> resolveParameter(TypeElement typeElement, MethodHandles.Lookup lookup) {
313             try {
314                 return (Class<?>)((JavaType) typeElement).erasure().resolve(lookup);
315             } catch (ReflectiveOperationException e) {
316                 throw new RuntimeException(e);
317             }
318         }
319 
320         // check whether elements of type sourceType can be passed to a parameter of parameterType
321         private static boolean isCompatible(Class<?> sourceType, Class<?> parameterType) {
322             return wrapped(parameterType).isAssignableFrom(wrapped(sourceType));
323         }
324 
325         private static Class<?> wrapped(Class<?> target) {
326             return MethodType.methodType(target).wrap().returnType();
327         }
328 
329         private static Stream<Method> codeReflectionMethods(Class<?> testClass) {
330             return Arrays.stream(testClass.getDeclaredMethods())
331                     .filter(method -> method.accessFlags().contains(AccessFlag.STATIC))
332                     .filter(method -> method.isAnnotationPresent(CodeReflection.class));
333         }
334 
335     }
336 
337     private static Object interpret(Object left, Object right, CoreOp.FuncOp op) {
338         return Interpreter.invoke(MethodHandles.lookup(), op, left, right);
339     }
340 
341     private static Object bytecode(Object left, Object right, CoreOp.FuncOp op) throws Throwable {
342         CoreOp.FuncOp func = SSA.transform(op.transform(OpTransformer.LOWERING_TRANSFORMER));
343         MethodHandle handle = BytecodeGenerator.generate(MethodHandles.lookup(), func);
344         return handle.invoke(left, right);
345     }
346 
347     private static void assertResults(Result first, Result second) {
348         System.out.println("first: " + first);
349         System.out.println("second: " + second);
350         // either the same error occurred on both or no error occurred
351         if (first.throwable != null || second.throwable != null) {
352             assertNotNull(first.throwable, () -> "only second threw an exception: " + second.throwable);
353             assertNotNull(second.throwable, () -> "only first threw an exception: " + first.throwable);
354             if (first.throwable.getClass() != second.throwable.getClass()) {
355                 first.throwable.printStackTrace();
356                 second.throwable.printStackTrace();
357                 fail("Different exceptions were thrown");
358             }
359             return;
360         }
361         // otherwise, both results should be non-null and equals
362         assertNotNull(first.onSuccess);
363         assertEquals(first.onSuccess, second.onSuccess);
364     }
365 
366     private static <T> Result runCatching(ThrowingSupplier<T> supplier) {
367         Object value = null;
368         Throwable interpretThrowable = null;
369         try {
370             value = supplier.get();
371         } catch (Throwable t) {
372             interpretThrowable = t;
373         }
374         return new Result(value, interpretThrowable);
375     }
376 
377     record Result(Object onSuccess, Throwable throwable) {
378     }
379 
380 }