1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 /*
25 * @test
26 * @modules jdk.incubator.code
27 * @run junit CoreBinaryOpsTest
28 * @run junit/othervm -Dbabylon.ssa=cytron CoreBinaryOpsTest
29 */
30
31 import jdk.incubator.code.CodeReflection;
32 import jdk.incubator.code.Op;
33 import jdk.incubator.code.OpTransformer;
34 import jdk.incubator.code.TypeElement;
35 import jdk.incubator.code.analysis.SSA;
36 import jdk.incubator.code.bytecode.BytecodeGenerator;
37 import jdk.incubator.code.dialect.core.CoreOp;
38 import jdk.incubator.code.dialect.core.CoreType;
39 import jdk.incubator.code.dialect.core.FunctionType;
40 import jdk.incubator.code.dialect.java.JavaType;
41 import jdk.incubator.code.interpreter.Interpreter;
42 import org.junit.jupiter.api.Named;
43 import org.junit.jupiter.api.extension.ExtensionContext;
44 import org.junit.jupiter.api.function.ThrowingSupplier;
45 import org.junit.jupiter.params.ParameterizedTest;
46 import org.junit.jupiter.params.provider.Arguments;
47 import org.junit.jupiter.params.provider.ArgumentsProvider;
48 import org.junit.jupiter.params.provider.ArgumentsSource;
49
50 import java.lang.annotation.ElementType;
51 import java.lang.annotation.Retention;
52 import java.lang.annotation.RetentionPolicy;
53 import java.lang.annotation.Target;
54 import java.lang.invoke.MethodHandle;
55 import java.lang.invoke.MethodHandles;
56 import java.lang.invoke.MethodType;
57 import java.lang.reflect.AccessFlag;
58 import java.lang.reflect.Method;
59 import java.lang.reflect.Parameter;
60 import java.util.ArrayList;
61 import java.util.Arrays;
62 import java.util.List;
63 import java.util.Map;
64 import java.util.stream.Stream;
65
66 import static org.junit.jupiter.api.Assertions.*;
67
68 public class CoreBinaryOpsTest {
69
70 @CodeReflection
71 @SupportedTypes(TypeList.INTEGRAL_BOOLEAN)
72 static int and(int left, int right) {
73 return left & right;
74 }
75
76 @CodeReflection
77 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
78 static int add(int left, int right) {
79 return left + right;
80 }
81
82 @CodeReflection
83 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
84 static int div(int left, int right) {
85 return left / right;
86 }
87
88 @CodeReflection
89 @SupportedTypes(TypeList.INT_LONG)
90 static int leftShift(int left, int right) {
91 return left << right;
92 }
93
94 @CodeReflection
95 @Direct
96 static int leftShiftIL(int left, long right) {
97 return left << right;
98 }
99
100 @CodeReflection
101 @Direct
102 static long leftShiftLI(long left, int right) {
103 return left << right;
104 }
105
106 @CodeReflection
107 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
108 static int mod(int left, int right) {
109 return left % right;
110 }
111
112 @CodeReflection
113 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
114 static int mul(int left, int right) {
115 return left * right;
116 }
117
118 @CodeReflection
119 @SupportedTypes(TypeList.INTEGRAL_BOOLEAN)
120 static int or(int left, int right) {
121 return left | right;
122 }
123
124 @CodeReflection
125 @SupportedTypes(TypeList.INT_LONG)
126 static int signedRightShift(int left, int right) {
127 return left >> right;
128 }
129
130 @CodeReflection
131 @Direct
132 static int signedRightShiftIL(int left, long right) {
133 return left >> right;
134 }
135
136 @CodeReflection
137 @Direct
138 static long signedRightShiftLI(long left, int right) {
139 return left >> right;
140 }
141
142 @CodeReflection
143 @SupportedTypes(TypeList.INTEGRAL_FLOATING_POINT)
144 static int sub(int left, int right) {
145 return left - right;
146 }
147
148 @CodeReflection
149 @SupportedTypes(TypeList.INT_LONG)
150 static int unsignedRightShift(int left, int right) {
151 return left >>> right;
152 }
153
154 @CodeReflection
155 @Direct
156 static int unsignedRightShiftIL(int left, long right) {
157 return left >>> right;
158 }
159
160 @CodeReflection
161 @Direct
162 static long unsignedRightShiftLI(long left, int right) {
163 return left >>> right;
164 }
165
166 @CodeReflection
167 @SupportedTypes(TypeList.INTEGRAL_BOOLEAN)
168 static int xor(int left, int right) {
169 return left ^ right;
170 }
171
172 @ParameterizedTest
173 @CodeReflectionExecutionSource
174 void test(CoreOp.FuncOp funcOp, Object left, Object right) {
175 Result interpret = runCatching(() -> interpret(left, right, funcOp));
176 Result bytecode = runCatching(() -> bytecode(left, right, funcOp));
177 assertResults(interpret, bytecode);
178 }
179
180 @Retention(RetentionPolicy.RUNTIME)
181 @Target(ElementType.METHOD)
182 @interface SupportedTypes {
183 TypeList value();
184 }
185
186 enum TypeList {
187 INT_LONG(int.class, long.class),
188 INTEGRAL_BOOLEAN(int.class, long.class, byte.class, short.class, char.class, boolean.class),
189 INTEGRAL_FLOATING_POINT(int.class, long.class, byte.class, short.class, char.class, float.class, double.class);
190
191 private final Class<?>[] types;
192
193 TypeList(Class<?>... types) {
194 this.types = types;
195 }
196
197 public Class<?>[] types() {
198 return types;
199 }
200 }
201
202 // mark as "do not transform"
203 @Retention(RetentionPolicy.RUNTIME)
204 @Target(ElementType.METHOD)
205 @interface Direct {
206 }
207
208 @Retention(RetentionPolicy.RUNTIME)
209 @Target(ElementType.METHOD)
210 @ArgumentsSource(CodeReflectionSourceProvider.class)
211 @interface CodeReflectionExecutionSource {
212 }
213
214 static class CodeReflectionSourceProvider implements ArgumentsProvider {
215 private static final Map<JavaType, List<?>> INTERESTING_INPUTS = Map.of(
216 // explicit type parameters to ensure boxing results in the expected type
217 JavaType.INT, List.<Integer>of(Integer.MIN_VALUE, Integer.MAX_VALUE, 1, 0, -1),
218 JavaType.LONG, List.<Long>of(Long.MIN_VALUE, Long.MAX_VALUE, 1L, 0L, -1L),
219 JavaType.BYTE, List.<Byte>of(Byte.MIN_VALUE, Byte.MAX_VALUE, (byte) 1, (byte) 0, (byte) -1),
220 JavaType.SHORT, List.<Short>of(Short.MIN_VALUE, Short.MAX_VALUE, (short) 1, (short) 0, (short) -1),
221 JavaType.CHAR, List.<Character>of(Character.MIN_VALUE, Character.MAX_VALUE, (char) 1, (char) 0, (char) -1),
222 JavaType.DOUBLE, List.<Double>of(Double.MIN_VALUE, Double.MAX_VALUE, Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.MIN_NORMAL, 1d, 0d, -1d),
223 JavaType.FLOAT, List.<Float>of(Float.MIN_VALUE, Float.MAX_VALUE, Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY, Float.MIN_NORMAL, 1f, 0f, -1f),
224 JavaType.BOOLEAN, List.<Boolean>of(true, false)
225 );
226
227 @Override
228 public Stream<? extends Arguments> provideArguments(ExtensionContext extensionContext) {
229 Method testMethod = extensionContext.getRequiredTestMethod();
230 return codeReflectionMethods(extensionContext.getRequiredTestClass())
231 .flatMap(method -> {
232 CoreOp.FuncOp funcOp = Op.ofMethod(method).orElseThrow(
233 () -> new IllegalStateException("Expected code model to be present for method " + method)
234 );
235 SupportedTypes supportedTypes = method.getAnnotation(SupportedTypes.class);
236 if (method.isAnnotationPresent(Direct.class)) {
237 if (supportedTypes != null) {
238 throw new IllegalArgumentException("Direct should not be combined with SupportedTypes");
239 }
240 return Stream.of(funcOp);
241 }
242 if (supportedTypes == null || supportedTypes.value().types().length == 0) {
243 throw new IllegalArgumentException("Missing supported types");
244 }
245 return Arrays.stream(supportedTypes.value().types())
246 .map(type -> retype(funcOp, type));
247 })
248 .flatMap(transformedFunc -> argumentsForMethod(transformedFunc, testMethod));
249 }
250
251 private static <T> Stream<List<T>> cartesianProduct(List<List<? extends T>> source) {
252 if (source.isEmpty()) {
253 return Stream.of(new ArrayList<>());
254 }
255 return source.getFirst().stream()
256 .flatMap(e -> cartesianProduct(source.subList(1, source.size())).map(l -> {
257 ArrayList<T> newList = new ArrayList<>(l);
258 newList.add(e);
259 return newList;
260 }));
261 }
262
263 private static CoreOp.FuncOp retype(CoreOp.FuncOp original, Class<?> newType) {
264 JavaType type = JavaType.type(newType);
265 FunctionType functionType = original.invokableType();
266 if (functionType.parameterTypes().stream().allMatch(t -> t.equals(type))) {
267 return original; // already expected type
268 }
269 if (functionType.parameterTypes().stream().distinct().count() != 1) {
270 System.err.println(original.toText());
271 throw new IllegalArgumentException("Only FuncOps with exactly one distinct parameter type are supported");
272 }
273 // if the return type does not match the input types, we keep it
274 TypeElement retType = functionType.returnType().equals(functionType.parameterTypes().getFirst())
275 ? type
276 : functionType.returnType();
277 return CoreOp.func(original.funcName(), CoreType.functionType(retType, type, type))
278 .body(builder -> builder.body(original.body(), builder.parameters(), OpTransformer.COPYING_TRANSFORMER)
279 );
280 }
281
282 private static Stream<Arguments> argumentsForMethod(CoreOp.FuncOp funcOp, Method testMethod) {
283 Parameter[] testMethodParameters = testMethod.getParameters();
284 List<TypeElement> funcParameters = funcOp.invokableType().parameterTypes();
285 if (testMethodParameters.length - 1 != funcParameters.size()) {
286 throw new IllegalArgumentException("method " + testMethod + " does not take the correct number of parameters");
287 }
288 if (testMethodParameters[0].getType() != CoreOp.FuncOp.class) {
289 throw new IllegalArgumentException("method " + testMethod + " does not take a leading FuncOp argument");
290 }
291 Named<CoreOp.FuncOp> opNamed = Named.of(funcOp.funcName() + "{" + funcOp.invokableType() + "}", funcOp);
292 MethodHandles.Lookup lookup = MethodHandles.lookup();
293 for (int i = 1; i < testMethodParameters.length; i++) {
294 Class<?> resolved = resolveParameter(funcParameters.get(i - 1), lookup);
295 if (!isCompatible(resolved, testMethodParameters[i].getType())) {
296 System.out.println(testMethod + " does not accept inputs of type " + resolved + " at index " + i);
297 return Stream.empty();
298 }
299 }
300 List<List<?>> allInputs = new ArrayList<>();
301 for (TypeElement parameterType : funcParameters) {
302 allInputs.add(INTERESTING_INPUTS.get((JavaType) parameterType));
303 }
304 return cartesianProduct(allInputs)
305 .map(objects -> {
306 objects.add(opNamed);
307 return objects.reversed().toArray(); // reverse so FuncOp is at the beginning
308 })
309 .map(Arguments::of);
310 }
311
312 private static Class<?> resolveParameter(TypeElement typeElement, MethodHandles.Lookup lookup) {
313 try {
314 return (Class<?>)((JavaType) typeElement).erasure().resolve(lookup);
315 } catch (ReflectiveOperationException e) {
316 throw new RuntimeException(e);
317 }
318 }
319
320 // check whether elements of type sourceType can be passed to a parameter of parameterType
321 private static boolean isCompatible(Class<?> sourceType, Class<?> parameterType) {
322 return wrapped(parameterType).isAssignableFrom(wrapped(sourceType));
323 }
324
325 private static Class<?> wrapped(Class<?> target) {
326 return MethodType.methodType(target).wrap().returnType();
327 }
328
329 private static Stream<Method> codeReflectionMethods(Class<?> testClass) {
330 return Arrays.stream(testClass.getDeclaredMethods())
331 .filter(method -> method.accessFlags().contains(AccessFlag.STATIC))
332 .filter(method -> method.isAnnotationPresent(CodeReflection.class));
333 }
334
335 }
336
337 private static Object interpret(Object left, Object right, CoreOp.FuncOp op) {
338 return Interpreter.invoke(MethodHandles.lookup(), op, left, right);
339 }
340
341 private static Object bytecode(Object left, Object right, CoreOp.FuncOp op) throws Throwable {
342 CoreOp.FuncOp func = SSA.transform(op.transform(OpTransformer.LOWERING_TRANSFORMER));
343 MethodHandle handle = BytecodeGenerator.generate(MethodHandles.lookup(), func);
344 return handle.invoke(left, right);
345 }
346
347 private static void assertResults(Result first, Result second) {
348 System.out.println("first: " + first);
349 System.out.println("second: " + second);
350 // either the same error occurred on both or no error occurred
351 if (first.throwable != null || second.throwable != null) {
352 assertNotNull(first.throwable, () -> "only second threw an exception: " + second.throwable);
353 assertNotNull(second.throwable, () -> "only first threw an exception: " + first.throwable);
354 if (first.throwable.getClass() != second.throwable.getClass()) {
355 first.throwable.printStackTrace();
356 second.throwable.printStackTrace();
357 fail("Different exceptions were thrown");
358 }
359 return;
360 }
361 // otherwise, both results should be non-null and equals
362 assertNotNull(first.onSuccess);
363 assertEquals(first.onSuccess, second.onSuccess);
364 }
365
366 private static <T> Result runCatching(ThrowingSupplier<T> supplier) {
367 Object value = null;
368 Throwable interpretThrowable = null;
369 try {
370 value = supplier.get();
371 } catch (Throwable t) {
372 interpretThrowable = t;
373 }
374 return new Result(value, interpretThrowable);
375 }
376
377 record Result(Object onSuccess, Throwable throwable) {
378 }
379
380 }