1 /*
2 * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.Reflect;
26 import jdk.incubator.code.dialect.core.CoreOp;
27 import jdk.incubator.code.dialect.java.JavaOp;
28 import jdk.incubator.code.dialect.java.JavaType;
29 import jdk.incubator.code.interpreter.Interpreter;
30 import jdk.internal.classfile.components.ClassPrinter;
31 import org.junit.jupiter.api.Assertions;
32 import org.junit.jupiter.api.BeforeAll;
33 import org.junit.jupiter.params.ParameterizedTest;
34 import org.junit.jupiter.params.provider.MethodSource;
35
36 import java.lang.classfile.ClassFile;
37 import java.lang.classfile.ClassModel;
38 import java.lang.constant.MethodTypeDesc;
39 import java.lang.invoke.MethodHandles;
40 import java.lang.reflect.AccessFlag;
41 import java.lang.reflect.Method;
42 import java.util.*;
43 import java.util.function.Consumer;
44 import java.util.function.Function;
45 import java.util.function.IntUnaryOperator;
46 import java.util.stream.Collectors;
47 import java.util.stream.Stream;
48
49 /*
50 * @test
51 * @modules jdk.incubator.code/jdk.incubator.code.internal
52 * @modules java.base/jdk.internal.classfile.components
53 * @enablePreview
54 * @run junit/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecodeLift
55 */
56
57 public class TestBytecodeLift {
58
59 @Reflect
60 static int intNumOps(int i, int j, int k) {
61 k++;
62 i = (i + j) / k - i % j;
63 i--;
64 return i;
65 }
66
67 @Reflect
68 static byte byteNumOps(byte i, byte j, byte k) {
69 k++;
70 i = (byte) ((i + j) / k - i % j);
71 i--;
72 return i;
73 }
74
75 @Reflect
76 static short shortNumOps(short i, short j, short k) {
77 k++;
78 i = (short) ((i + j) / k - i % j);
79 i--;
80 return i;
81 }
82
83 @Reflect
84 static char charNumOps(char i, char j, char k) {
85 k++;
86 i = (char) ((i + j) / k - i % j);
87 i--;
88 return i;
89 }
90
91 @Reflect
92 static long longNumOps(long i, long j, long k) {
93 k++;
94 i = (i + j) / k - i % j;
95 i--;
96 return i;
97 }
98
99 @Reflect
100 static float floatNumOps(float i, float j, float k) {
101 k++;
102 i = (i + j) / k - i % j;
103 i--;
104 return i;
105 }
106
107 @Reflect
108 static double doubleNumOps(double i, double j, double k) {
109 k++;
110 i = (i + j) / k - i % j;
111 i--;
112 return i;
113 }
114
115 @Reflect
116 static int intBitOps(int i, int j, int k) {
117 return ~(i & j | k ^ j);
118 }
119
120 @Reflect
121 static byte byteBitOps(byte i, byte j, byte k) {
122 return (byte) ~(i & j | k ^ j);
123 }
124
125 @Reflect
126 static short shortBitOps(short i, short j, short k) {
127 return (short) ~(i & j | k ^ j);
128 }
129
130 @Reflect
131 static char charBitOps(char i, char j, char k) {
132 return (char) ~(i & j | k ^ j);
133 }
134
135 @Reflect
136 static long longBitOps(long i, long j, long k) {
137 return ~(i & j | k ^ j);
138 }
139
140 @Reflect
141 static boolean boolBitOps(boolean i, boolean j, boolean k) {
142 return i & j | k ^ j;
143 }
144
145 @Reflect
146 static int intShiftOps(int i, int j, int k) {
147 return ((-1 >> i) << (j << k)) >>> (k - j);
148 }
149
150 @Reflect
151 static byte byteShiftOps(byte i, byte j, byte k) {
152 return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
153 }
154
155 @Reflect
156 static short shortShiftOps(short i, short j, short k) {
157 return (short) (((-1 >> i) << (j << k)) >>> (k - j));
158 }
159
160 @Reflect
161 static char charShiftOps(char i, char j, char k) {
162 return (char) (((-1 >> i) << (j << k)) >>> (k - j));
163 }
164
165 @Reflect
166 static long longShiftOps(long i, long j, long k) {
167 return ((-1 >> i) << (j << k)) >>> (k - j);
168 }
169
170 @Reflect
171 static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
172 ii += i; ii += b; ii += s; ii += c;
173 i += ii; i += bb; i += ss; i += cc;
174 b += ii; b += bb; b += ss; b += cc;
175 s += ii; s += bb; s += ss; s += cc;
176 c += ii; c += bb; c += ss; c += cc;
177 return new Object[]{i, b, s, c};
178 }
179
180 @Reflect
181 static String constructor(String s, int i, int j) {
182 return new String(s.getBytes(), i, j);
183 }
184
185 @Reflect
186 static Class<?> classArray(int i, int j) {
187 Class<?>[] ifaces = new Class[1 + i + j];
188 ifaces[0] = Function.class;
189 return ifaces[0];
190 }
191
192 @Reflect
193 static String[] stringArray(int i, int j) {
194 return new String[i];
195 }
196
197 @Reflect
198 static String[][] stringArray2(int i, int j) {
199 return new String[i][];
200 }
201
202 @Reflect
203 static String[][] stringArrayMulti(int i, int j) {
204 return new String[i][j];
205 }
206
207 @Reflect
208 static int[][] initializedIntArray(int i, int j) {
209 return new int[][]{{i, j}, {i + j}};
210 }
211
212 @Reflect
213 static int ifElseCompare(int i, int j) {
214 if (i < 3) {
215 i += 1;
216 } else {
217 j += 2;
218 }
219 return i + j;
220 }
221
222 @Reflect
223 static int ifElseEquality(int i, int j) {
224 if (j != 0) {
225 if (i != 0) {
226 i += 1;
227 } else {
228 i += 2;
229 }
230 } else {
231 if (j != 0) {
232 i += 3;
233 } else {
234 i += 4;
235 }
236 }
237 return i;
238 }
239
240 @Reflect
241 static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) {
242 Object a = b1;
243 Object b = b2;
244 Object c = b3;
245 return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4);
246 }
247
248 @Reflect
249 static int conditionalExpr(int i, int j) {
250 return ((i - 1 >= 0) ? i - 1 : j - 1);
251 }
252
253 @Reflect
254 static int nestedConditionalExpr(int i, int j) {
255 return (i < 2) ? (j < 3) ? i : j : i + j;
256 }
257
258 static final int[] MAP = {0, 1, 2, 3, 4};
259
260 @Reflect
261 static int deepStackBranches(boolean a, boolean b) {
262 return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]];
263 }
264
265 @Reflect
266 static int tryFinally(int i, int j) {
267 try {
268 i = i + j;
269 } finally {
270 i = i + j;
271 }
272 return i;
273 }
274
275 public record A(String s) {}
276
277 @Reflect
278 static A newWithArgs(int i, int j) {
279 return new A("hello world".substring(i, i + j));
280 }
281
282 @Reflect
283 static int loop(int n, int j) {
284 int sum = 0;
285 for (int i = 0; i < n; i++) {
286 sum = sum + j;
287 }
288 return sum;
289 }
290
291
292 @Reflect
293 static int ifElseNested(int a, int b) {
294 int c = a + b;
295 int d = 10 - a + b;
296 if (b < 3) {
297 if (a < 3) {
298 a += 1;
299 } else {
300 b += 2;
301 }
302 c += 3;
303 } else {
304 if (a > 2) {
305 a += 4;
306 } else {
307 b += 5;
308 }
309 d += 6;
310 }
311 return a + b + c + d;
312 }
313
314 @Reflect
315 static int nestedLoop(int m, int n) {
316 int sum = 0;
317 for (int i = 0; i < m; i++) {
318 for (int j = 0; j < n; j++) {
319 sum = sum + i + j;
320 }
321 }
322 return sum;
323 }
324
325 @Reflect
326 static int methodCall(int a, int b) {
327 int i = Math.max(a, b);
328 return Math.negateExact(i);
329 }
330
331 @Reflect
332 static int[] primitiveArray(int i, int j) {
333 int[] ia = new int[i + 1];
334 ia[0] = j;
335 return ia;
336 }
337
338 @Reflect
339 static boolean not(boolean b) {
340 return !b;
341 }
342
343 @Reflect
344 static boolean notCompare(int i, int j) {
345 boolean b = i < j;
346 return !b;
347 }
348
349 @Reflect
350 static int mod(int i, int j) {
351 return i % (j + 1);
352 }
353
354 @Reflect
355 static int xor(int i, int j) {
356 return i ^ j;
357 }
358
359 @Reflect
360 static int whileLoop(int i, int n) { int
361 counter = 0;
362 while (i < n && counter < 3) {
363 counter++;
364 if (counter == 4) {
365 break;
366 }
367 i++;
368 }
369 return counter;
370 }
371
372 static int consumeQuotable(int i, IntUnaryOperator f) {
373 Assertions.assertNotNull(Op.ofQuotable(f).get());
374 Assertions.assertNotNull(Op.ofQuotable(f).get().op());
375 Assertions.assertTrue(Op.ofQuotable(f).get().op() instanceof JavaOp.LambdaOp);
376 return f.applyAsInt(i + 1);
377 }
378
379 @Reflect
380 static int quotableLambda(int i) {
381 return consumeQuotable(i, a -> -a);
382 }
383
384 @Reflect
385 static int quotableLambdaWithCapture(int i, String s) {
386 return consumeQuotable(i, a -> a + s.length());
387 }
388
389 @Reflect
390 static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
391 return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length());
392 }
393
394 @Reflect
395 static int methodHandle(int i) {
396 return consumeQuotable(i, Math::negateExact);
397 }
398
399 int instanceMethod(int i) {
400 return -i + 13;
401 }
402
403 @Reflect
404 int instanceMethodHandle(int i) {
405 return consumeQuotable(i, this::instanceMethod);
406 }
407
408 static void consume(boolean b, Consumer<Object> requireNonNull) {
409 if (b) {
410 requireNonNull.accept(new Object());
411 } else try {
412 requireNonNull.accept(null);
413 throw new AssertionError("Expectend NPE");
414 } catch (NullPointerException expected) {
415 }
416 }
417
418 @Reflect
419 static void nullReturningMethodHandle(boolean b) {
420 consume(b, Objects::requireNonNull);
421 }
422
423 @Reflect
424 static boolean compareLong(long i, long j) {
425 return i > j;
426 }
427
428 @Reflect
429 static boolean compareFloat(float i, float j) {
430 return i > j;
431 }
432
433 @Reflect
434 static boolean compareDouble(double i, double j) {
435 return i > j;
436 }
437
438 @Reflect
439 static int lookupSwitch(int i) {
440 return switch (1000 * i) {
441 case 1000 -> 1;
442 case 2000 -> 2;
443 case 3000 -> 3;
444 default -> 0;
445 };
446 }
447
448 @Reflect
449 static int tableSwitch(int i) {
450 return switch (i) {
451 case 1 -> 1;
452 case 2 -> 2;
453 case 3 -> 3;
454 default -> 0;
455 };
456 }
457
458 int instanceField = -1;
459
460 @Reflect
461 int instanceFieldAccess(int i) {
462 int ret = instanceField;
463 instanceField = i;
464 return ret;
465 }
466
467 @Reflect
468 static String stringConcat(String a, String b) {
469 return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd";
470 }
471
472 @Reflect
473 static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) {
474 return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d;
475 }
476
477 @Reflect
478 static int ifTrue(int i) {
479 if (true) {
480 return i;
481 }
482 return -i;
483 }
484
485 @Reflect
486 static int excHandlerFollowingSplitTable(boolean b) {
487 try {
488 if (b) return 1;
489 else throw new Exception();
490 } catch (Exception ex) {}
491 return 2;
492 }
493
494 @Reflect
495 static int varModifiedInTryBlock(boolean b) {
496 int i = 0;
497 try {
498 i++;
499 if (b) throw new Exception();
500 i++;
501 throw new Exception();
502 } catch (Exception ex) {
503 return i;
504 }
505 }
506
507 @Reflect
508 static boolean finallyWithLoop(boolean b) {
509 try {
510 while (b) {
511 if (b)
512 return false;
513 b = !b;
514 }
515 return true;
516 } finally {
517 b = false;
518 }
519 }
520
521 @Reflect
522 static long doubleUseOfOperand(int x) {
523 long piece = x;
524 return piece * piece;
525 }
526
527 record TestData(Method testMethod) {
528 @Override
529 public String toString() {
530 String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
531 .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
532 if (s.length() > 30) s = s.substring(0, 27) + "...";
533 return s;
534 }
535 }
536
537 public static Stream<TestData> testMethods() {
538 return Stream.of(TestBytecodeLift.class.getDeclaredMethods())
539 .filter(m -> m.isAnnotationPresent(Reflect.class))
540 .map(TestData::new);
541 }
542
543 private static byte[] CLASS_DATA;
544 private static ClassModel CLASS_MODEL;
545
546 @BeforeAll
547 public static void setup() throws Exception {
548 CLASS_DATA = TestBytecodeLift.class.getResourceAsStream("TestBytecodeLift.class").readAllBytes();
549 CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
550 }
551
552 private static MethodTypeDesc toMethodTypeDesc(Method m) {
553 return MethodTypeDesc.of(
554 m.getReturnType().describeConstable().orElseThrow(),
555 Arrays.stream(m.getParameterTypes())
556 .map(cls -> cls.describeConstable().orElseThrow()).toList());
557 }
558
559
560 private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
561 private static Object[] values(Object... values) {
562 return values;
563 }
564 private static void initTestArgs(Object[] values, Class<?>... argTypes) {
565 for (var argType : argTypes) TEST_ARGS.put(argType, values);
566 }
567 static {
568 initTestArgs(values(1, 2, 4), int.class, Integer.class);
569 initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class);
570 initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class);
571 initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class);
572 initTestArgs(values(false, true), boolean.class, Boolean.class);
573 initTestArgs(values("Hello World"), String.class);
574 initTestArgs(values(1l, 2l, 4l), long.class, Long.class);
575 initTestArgs(values(1f, 3f, 4f), float.class, Float.class);
576 initTestArgs(values(1d, 2d, 3d), double.class, Double.class);
577 }
578
579 interface Executor {
580 void execute(Object[] args) throws Throwable;
581 }
582
583 private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
584 final int argn = argTypes.length;
585 Object[][] argValues = new Object[argn][];
586 for (int i = 0; i < argn; i++) {
587 argValues[i] = TEST_ARGS.get(argTypes[i]);
588 }
589 int[] argIndexes = new int[argn];
590 Object[] args = new Object[argn];
591 while (true) {
592 for (int i = 0; i < argn; i++) {
593 args[i] = argValues[i][argIndexes[i]];
594 }
595 executor.execute(args);
596 int i = argn - 1;
597 while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
598 if (i < 0) return;
599 argIndexes[i++]++;
600 while (i < argn) argIndexes[i++] = 0;
601 }
602 }
603
604 @ParameterizedTest
605 @MethodSource("testMethods")
606 public void testLift(TestData d) throws Throwable {
607 CoreOp.FuncOp flift;
608 try {
609 flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod));
610 } catch (Throwable e) {
611 ClassPrinter.toYaml(ClassFile.of().parse(TestBytecodeLift.class.getResourceAsStream("TestBytecodeLift.class").readAllBytes())
612 .methods().stream().filter(m -> m.methodName().equalsString(d.testMethod().getName())).findAny().get(),
613 ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print);
614 System.out.println("Lift failed, compiled model:");
615 Op.ofMethod(d.testMethod).ifPresent(f -> System.out.println(f.toText()));
616 throw e;
617 }
618 try {
619 Object receiver1, receiver2;
620 if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
621 receiver1 = null;
622 receiver2 = null;
623 } else {
624 receiver1 = new TestBytecodeLift();
625 receiver2 = new TestBytecodeLift();
626 }
627 permutateAllArgs(d.testMethod.getParameterTypes(), args ->
628 assertEquals(d.testMethod.invoke(receiver2, args), invokeAndConvert(flift, receiver1, args)));
629 } catch (Throwable e) {
630 System.out.println("Compiled model:");
631 Op.ofMethod(d.testMethod).ifPresent(f -> System.out.println(f.toText()));
632 System.out.println("Lifted model:");
633 System.out.println(flift.toText());
634 throw e;
635 }
636 }
637
638 private static Object invokeAndConvert(CoreOp.FuncOp func, Object receiver, Object... args) {
639 List argl = new ArrayList(args.length + 1);
640 if (receiver != null) argl.add(receiver);
641 argl.addAll(Arrays.asList(args));
642 Object ret = Interpreter.invoke(MethodHandles.lookup(), func, argl);
643 if (ret instanceof Integer i) {
644 TypeElement rt = func.invokableType().returnType();
645 if (rt.equals(JavaType.BOOLEAN)) {
646 return i != 0;
647 } else if (rt.equals(JavaType.BYTE)) {
648 return i.byteValue();
649 } else if (rt.equals(JavaType.CHAR)) {
650 return (short)i.intValue();
651 } else if (rt.equals(JavaType.SHORT)) {
652 return i.shortValue();
653 }
654 }
655 return ret;
656 }
657
658 private static void assertEquals(Object expected, Object actual) {
659 switch (expected) {
660 case int[] expArr when actual instanceof int[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
661 case Object[] expArr when actual instanceof Object[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
662 case null, default -> Assertions.assertEquals(expected, actual);
663 }
664 }
665 }