1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.Reflect;
 26 import jdk.incubator.code.dialect.core.CoreOp;
 27 import jdk.incubator.code.dialect.java.JavaOp;
 28 import jdk.incubator.code.dialect.java.JavaType;
 29 import jdk.incubator.code.interpreter.Interpreter;
 30 import jdk.internal.classfile.components.ClassPrinter;
 31 import org.junit.jupiter.api.Assertions;
 32 import org.junit.jupiter.api.BeforeAll;
 33 import org.junit.jupiter.params.ParameterizedTest;
 34 import org.junit.jupiter.params.provider.MethodSource;
 35 
 36 import java.lang.classfile.ClassFile;
 37 import java.lang.classfile.ClassModel;
 38 import java.lang.constant.MethodTypeDesc;
 39 import java.lang.invoke.MethodHandles;
 40 import java.lang.reflect.AccessFlag;
 41 import java.lang.reflect.Method;
 42 import java.util.*;
 43 import java.util.function.Consumer;
 44 import java.util.function.Function;
 45 import java.util.function.IntUnaryOperator;
 46 import java.util.stream.Collectors;
 47 import java.util.stream.Stream;
 48 
 49 /*
 50  * @test
 51  * @modules jdk.incubator.code/jdk.incubator.code.internal
 52  * @modules java.base/jdk.internal.classfile.components
 53  * @enablePreview
 54  * @run junit/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecodeLift
 55  */
 56 
 57 public class TestBytecodeLift {
 58 
 59     @Reflect
 60     static int intNumOps(int i, int j, int k) {
 61         k++;
 62         i = (i + j) / k - i % j;
 63         i--;
 64         return i;
 65     }
 66 
 67     @Reflect
 68     static byte byteNumOps(byte i, byte j, byte k) {
 69         k++;
 70         i = (byte) ((i + j) / k - i % j);
 71         i--;
 72         return i;
 73     }
 74 
 75     @Reflect
 76     static short shortNumOps(short i, short j, short k) {
 77         k++;
 78         i = (short) ((i + j) / k - i % j);
 79         i--;
 80         return i;
 81     }
 82 
 83     @Reflect
 84     static char charNumOps(char i, char j, char k) {
 85         k++;
 86         i = (char) ((i + j) / k - i % j);
 87         i--;
 88         return i;
 89     }
 90 
 91     @Reflect
 92     static long longNumOps(long i, long j, long k) {
 93         k++;
 94         i = (i + j) / k - i % j;
 95         i--;
 96         return i;
 97     }
 98 
 99     @Reflect
100     static float floatNumOps(float i, float j, float k) {
101         k++;
102         i = (i + j) / k - i % j;
103         i--;
104         return i;
105     }
106 
107     @Reflect
108     static double doubleNumOps(double i, double j, double k) {
109         k++;
110         i = (i + j) / k - i % j;
111         i--;
112         return i;
113     }
114 
115     @Reflect
116     static int intBitOps(int i, int j, int k) {
117         return ~(i & j | k ^ j);
118     }
119 
120     @Reflect
121     static byte byteBitOps(byte i, byte j, byte k) {
122         return (byte) ~(i & j | k ^ j);
123     }
124 
125     @Reflect
126     static short shortBitOps(short i, short j, short k) {
127         return (short) ~(i & j | k ^ j);
128     }
129 
130     @Reflect
131     static char charBitOps(char i, char j, char k) {
132         return (char) ~(i & j | k ^ j);
133     }
134 
135     @Reflect
136     static long longBitOps(long i, long j, long k) {
137         return ~(i & j | k ^ j);
138     }
139 
140     @Reflect
141     static boolean boolBitOps(boolean i, boolean j, boolean k) {
142         return i & j | k ^ j;
143     }
144 
145     @Reflect
146     static int intShiftOps(int i, int j, int k) {
147         return ((-1 >> i) << (j << k)) >>> (k - j);
148     }
149 
150     @Reflect
151     static byte byteShiftOps(byte i, byte j, byte k) {
152         return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
153     }
154 
155     @Reflect
156     static short shortShiftOps(short i, short j, short k) {
157         return (short) (((-1 >> i) << (j << k)) >>> (k - j));
158     }
159 
160     @Reflect
161     static char charShiftOps(char i, char j, char k) {
162         return (char) (((-1 >> i) << (j << k)) >>> (k - j));
163     }
164 
165     @Reflect
166     static long longShiftOps(long i, long j, long k) {
167         return ((-1 >> i) << (j << k)) >>> (k - j);
168     }
169 
170     @Reflect
171     static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
172         ii += i; ii += b; ii += s; ii += c;
173         i += ii; i += bb; i += ss; i += cc;
174         b += ii; b += bb; b += ss; b += cc;
175         s += ii; s += bb; s += ss; s += cc;
176         c += ii; c += bb; c += ss; c += cc;
177         return new Object[]{i, b, s, c};
178     }
179 
180     @Reflect
181     static String constructor(String s, int i, int j) {
182         return new String(s.getBytes(), i, j);
183     }
184 
185     @Reflect
186     static Class<?> classArray(int i, int j) {
187         Class<?>[] ifaces = new Class[1 + i + j];
188         ifaces[0] = Function.class;
189         return ifaces[0];
190     }
191 
192     @Reflect
193     static String[] stringArray(int i, int j) {
194         return new String[i];
195     }
196 
197     @Reflect
198     static String[][] stringArray2(int i, int j) {
199         return new String[i][];
200     }
201 
202     @Reflect
203     static String[][] stringArrayMulti(int i, int j) {
204         return new String[i][j];
205     }
206 
207     @Reflect
208     static int[][] initializedIntArray(int i, int j) {
209         return new int[][]{{i, j}, {i + j}};
210     }
211 
212     @Reflect
213     static int ifElseCompare(int i, int j) {
214         if (i < 3) {
215             i += 1;
216         } else {
217             j += 2;
218         }
219         return i + j;
220     }
221 
222     @Reflect
223     static int ifElseEquality(int i, int j) {
224         if (j != 0) {
225             if (i != 0) {
226                 i += 1;
227             } else {
228                 i += 2;
229             }
230         } else {
231             if (j != 0) {
232                 i += 3;
233             } else {
234                 i += 4;
235             }
236         }
237         return i;
238     }
239 
240     @Reflect
241     static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) {
242         Object a = b1;
243         Object b = b2;
244         Object c = b3;
245         return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4);
246     }
247 
248     @Reflect
249     static int conditionalExpr(int i, int j) {
250         return ((i - 1 >= 0) ? i - 1 : j - 1);
251     }
252 
253     @Reflect
254     static int nestedConditionalExpr(int i, int j) {
255         return (i < 2) ? (j < 3) ? i : j : i + j;
256     }
257 
258     static final int[] MAP = {0, 1, 2, 3, 4};
259 
260     @Reflect
261     static int deepStackBranches(boolean a, boolean b) {
262         return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]];
263     }
264 
265     @Reflect
266     static int tryFinally(int i, int j) {
267         try {
268             i = i + j;
269         } finally {
270             i = i + j;
271         }
272         return i;
273     }
274 
275     public record A(String s) {}
276 
277     @Reflect
278     static A newWithArgs(int i, int j) {
279         return new A("hello world".substring(i, i + j));
280     }
281 
282     @Reflect
283     static int loop(int n, int j) {
284         int sum = 0;
285         for (int i = 0; i < n; i++) {
286             sum = sum + j;
287         }
288         return sum;
289     }
290 
291 
292     @Reflect
293     static int ifElseNested(int a, int b) {
294         int c = a + b;
295         int d = 10 - a + b;
296         if (b < 3) {
297             if (a < 3) {
298                 a += 1;
299             } else {
300                 b += 2;
301             }
302             c += 3;
303         } else {
304             if (a > 2) {
305                 a += 4;
306             } else {
307                 b += 5;
308             }
309             d += 6;
310         }
311         return a + b + c + d;
312     }
313 
314     @Reflect
315     static int nestedLoop(int m, int n) {
316         int sum = 0;
317         for (int i = 0; i < m; i++) {
318             for (int j = 0; j < n; j++) {
319                 sum = sum + i + j;
320             }
321         }
322         return sum;
323     }
324 
325     @Reflect
326     static int methodCall(int a, int b) {
327         int i = Math.max(a, b);
328         return Math.negateExact(i);
329     }
330 
331     @Reflect
332     static int[] primitiveArray(int i, int j) {
333         int[] ia = new int[i + 1];
334         ia[0] = j;
335         return ia;
336     }
337 
338     @Reflect
339     static boolean not(boolean b) {
340         return !b;
341     }
342 
343     @Reflect
344     static boolean notCompare(int i, int j) {
345         boolean b = i < j;
346         return !b;
347     }
348 
349     @Reflect
350     static int mod(int i, int j) {
351         return i % (j + 1);
352     }
353 
354     @Reflect
355     static int xor(int i, int j) {
356         return i ^ j;
357     }
358 
359     @Reflect
360     static int whileLoop(int i, int n) { int
361         counter = 0;
362         while (i < n && counter < 3) {
363             counter++;
364             if (counter == 4) {
365                 break;
366             }
367             i++;
368         }
369         return counter;
370     }
371 
372     static int consumeQuotable(int i, IntUnaryOperator f) {
373         Assertions.assertNotNull(Op.ofQuotable(f).get());
374         Assertions.assertNotNull(Op.ofQuotable(f).get().op());
375         Assertions.assertTrue(Op.ofQuotable(f).get().op() instanceof JavaOp.LambdaOp);
376         return f.applyAsInt(i + 1);
377     }
378 
379     @Reflect
380     static int quotableLambda(int i) {
381         return consumeQuotable(i, a -> -a);
382     }
383 
384     @Reflect
385     static int quotableLambdaWithCapture(int i, String s) {
386         return consumeQuotable(i, a -> a + s.length());
387     }
388 
389     @Reflect
390     static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
391         return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length());
392     }
393 
394     @Reflect
395     static int methodHandle(int i) {
396         return consumeQuotable(i, Math::negateExact);
397     }
398 
399     int instanceMethod(int i) {
400         return -i + 13;
401     }
402 
403     @Reflect
404     int instanceMethodHandle(int i) {
405         return consumeQuotable(i, this::instanceMethod);
406     }
407 
408     static void consume(boolean b, Consumer<Object> requireNonNull) {
409         if (b) {
410             requireNonNull.accept(new Object());
411         } else try {
412             requireNonNull.accept(null);
413             throw new AssertionError("Expectend NPE");
414         } catch (NullPointerException expected) {
415         }
416     }
417 
418     @Reflect
419     static void nullReturningMethodHandle(boolean b) {
420         consume(b, Objects::requireNonNull);
421     }
422 
423     @Reflect
424     static boolean compareLong(long i, long j) {
425         return i > j;
426     }
427 
428     @Reflect
429     static boolean compareFloat(float i, float j) {
430         return i > j;
431     }
432 
433     @Reflect
434     static boolean compareDouble(double i, double j) {
435         return i > j;
436     }
437 
438     @Reflect
439     static int lookupSwitch(int i) {
440         return switch (1000 * i) {
441             case 1000 -> 1;
442             case 2000 -> 2;
443             case 3000 -> 3;
444             default -> 0;
445         };
446     }
447 
448     @Reflect
449     static int tableSwitch(int i) {
450         return switch (i) {
451             case 1 -> 1;
452             case 2 -> 2;
453             case 3 -> 3;
454             default -> 0;
455         };
456     }
457 
458     int instanceField = -1;
459 
460     @Reflect
461     int instanceFieldAccess(int i) {
462         int ret = instanceField;
463         instanceField = i;
464         return ret;
465     }
466 
467     @Reflect
468     static String stringConcat(String a, String b) {
469         return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd";
470     }
471 
472     @Reflect
473     static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) {
474         return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d;
475     }
476 
477     @Reflect
478     static int ifTrue(int i) {
479         if (true) {
480             return i;
481         }
482         return -i;
483     }
484 
485     @Reflect
486     static int excHandlerFollowingSplitTable(boolean b) {
487         try {
488             if (b) return 1;
489             else throw new Exception();
490         } catch (Exception ex) {}
491         return 2;
492     }
493 
494     @Reflect
495     static int varModifiedInTryBlock(boolean b) {
496         int i = 0;
497         try {
498             i++;
499             if (b) throw new Exception();
500             i++;
501             throw new Exception();
502         } catch (Exception ex) {
503             return i;
504         }
505     }
506 
507     @Reflect
508     static boolean finallyWithLoop(boolean b) {
509         try {
510             while (b) {
511                 if (b)
512                     return false;
513                 b = !b;
514             }
515             return true;
516         } finally {
517             b = false;
518         }
519     }
520 
521     @Reflect
522     static long doubleUseOfOperand(int x) {
523         long piece = x;
524         return piece * piece;
525     }
526 
527     record TestData(Method testMethod) {
528         @Override
529         public String toString() {
530             String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
531                     .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
532             if (s.length() > 30) s = s.substring(0, 27) + "...";
533             return s;
534         }
535     }
536 
537     public static Stream<TestData> testMethods() {
538         return Stream.of(TestBytecodeLift.class.getDeclaredMethods())
539                 .filter(m -> m.isAnnotationPresent(Reflect.class))
540                 .map(TestData::new);
541     }
542 
543     private static byte[] CLASS_DATA;
544     private static ClassModel CLASS_MODEL;
545 
546     @BeforeAll
547     public static void setup() throws Exception {
548         CLASS_DATA = TestBytecodeLift.class.getResourceAsStream("TestBytecodeLift.class").readAllBytes();
549         CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
550     }
551 
552     private static MethodTypeDesc toMethodTypeDesc(Method m) {
553         return MethodTypeDesc.of(
554                 m.getReturnType().describeConstable().orElseThrow(),
555                 Arrays.stream(m.getParameterTypes())
556                         .map(cls -> cls.describeConstable().orElseThrow()).toList());
557     }
558 
559 
560     private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
561     private static Object[] values(Object... values) {
562         return values;
563     }
564     private static void initTestArgs(Object[] values, Class<?>... argTypes) {
565         for (var argType : argTypes) TEST_ARGS.put(argType, values);
566     }
567     static {
568         initTestArgs(values(1, 2, 4), int.class, Integer.class);
569         initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class);
570         initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class);
571         initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class);
572         initTestArgs(values(false, true), boolean.class, Boolean.class);
573         initTestArgs(values("Hello World"), String.class);
574         initTestArgs(values(1l, 2l, 4l), long.class, Long.class);
575         initTestArgs(values(1f, 3f, 4f), float.class, Float.class);
576         initTestArgs(values(1d, 2d, 3d), double.class, Double.class);
577     }
578 
579     interface Executor {
580         void execute(Object[] args) throws Throwable;
581     }
582 
583     private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
584         final int argn = argTypes.length;
585         Object[][] argValues = new Object[argn][];
586         for (int i = 0; i < argn; i++) {
587             argValues[i] = TEST_ARGS.get(argTypes[i]);
588         }
589         int[] argIndexes = new int[argn];
590         Object[] args = new Object[argn];
591         while (true) {
592             for (int i = 0; i < argn; i++) {
593                 args[i] = argValues[i][argIndexes[i]];
594             }
595             executor.execute(args);
596             int i = argn - 1;
597             while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
598             if (i < 0) return;
599             argIndexes[i++]++;
600             while (i < argn) argIndexes[i++] = 0;
601         }
602     }
603 
604     @ParameterizedTest
605     @MethodSource("testMethods")
606     public void testLift(TestData d) throws Throwable {
607         CoreOp.FuncOp flift;
608         try {
609             flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod));
610         } catch (Throwable e) {
611             ClassPrinter.toYaml(ClassFile.of().parse(TestBytecodeLift.class.getResourceAsStream("TestBytecodeLift.class").readAllBytes())
612                     .methods().stream().filter(m -> m.methodName().equalsString(d.testMethod().getName())).findAny().get(),
613                     ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print);
614             System.out.println("Lift failed, compiled model:");
615             Op.ofMethod(d.testMethod).ifPresent(f -> System.out.println(f.toText()));
616             throw e;
617         }
618         try {
619             Object receiver1, receiver2;
620             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
621                 receiver1 = null;
622                 receiver2 = null;
623             } else {
624                 receiver1 = new TestBytecodeLift();
625                 receiver2 = new TestBytecodeLift();
626             }
627             permutateAllArgs(d.testMethod.getParameterTypes(), args ->
628                     assertEquals(d.testMethod.invoke(receiver2, args), invokeAndConvert(flift, receiver1, args)));
629         } catch (Throwable e) {
630             System.out.println("Compiled model:");
631             Op.ofMethod(d.testMethod).ifPresent(f -> System.out.println(f.toText()));
632             System.out.println("Lifted model:");
633             System.out.println(flift.toText());
634             throw e;
635         }
636     }
637 
638     private static Object invokeAndConvert(CoreOp.FuncOp func, Object receiver, Object... args) {
639         List argl = new ArrayList(args.length + 1);
640         if (receiver != null) argl.add(receiver);
641         argl.addAll(Arrays.asList(args));
642         Object ret = Interpreter.invoke(MethodHandles.lookup(), func, argl);
643         if (ret instanceof Integer i) {
644             TypeElement rt = func.invokableType().returnType();
645             if (rt.equals(JavaType.BOOLEAN)) {
646                 return i != 0;
647             } else if (rt.equals(JavaType.BYTE)) {
648                 return i.byteValue();
649             } else if (rt.equals(JavaType.CHAR)) {
650                 return (short)i.intValue();
651             } else if (rt.equals(JavaType.SHORT)) {
652                 return i.shortValue();
653             }
654         }
655         return ret;
656     }
657 
658     private static void assertEquals(Object expected, Object actual) {
659         switch (expected) {
660             case int[] expArr when actual instanceof int[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
661             case Object[] expArr when actual instanceof Object[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
662             case null, default -> Assertions.assertEquals(expected, actual);
663         }
664     }
665 }