1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.dialect.core.CoreOp;
 26 import jdk.incubator.code.dialect.java.JavaOp;
 27 import jdk.incubator.code.dialect.java.JavaType;
 28 import jdk.incubator.code.interpreter.Interpreter;
 29 import jdk.internal.classfile.components.ClassPrinter;
 30 import org.junit.jupiter.api.Assertions;
 31 import org.junit.jupiter.api.BeforeAll;
 32 import org.junit.jupiter.params.ParameterizedTest;
 33 import org.junit.jupiter.params.provider.MethodSource;
 34 import org.opentest4j.TestSkippedException;
 35 
 36 import java.io.IOException;
 37 import java.lang.classfile.ClassFile;
 38 import java.lang.classfile.ClassModel;
 39 import java.lang.constant.MethodTypeDesc;
 40 import java.lang.invoke.MethodHandle;
 41 import java.lang.invoke.MethodHandles;
 42 import java.lang.reflect.AccessFlag;
 43 import java.lang.reflect.Method;
 44 import java.nio.file.Files;
 45 import java.nio.file.Path;
 46 import java.util.*;
 47 import java.util.function.Consumer;
 48 import java.util.function.Function;
 49 import java.util.stream.Collectors;
 50 import java.util.stream.Stream;
 51 
 52 /*
 53  * @test
 54  * @modules jdk.incubator.code/jdk.incubator.code.internal
 55  * @modules java.base/jdk.internal.classfile.components
 56  * @enablePreview
 57  * @run junit/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecodeLift
 58  */
 59 
 60 public class TestBytecodeLift {
 61 
 62     @CodeReflection
 63     static int intNumOps(int i, int j, int k) {
 64         k++;
 65         i = (i + j) / k - i % j;
 66         i--;
 67         return i;
 68     }
 69 
 70     @CodeReflection
 71     static byte byteNumOps(byte i, byte j, byte k) {
 72         k++;
 73         i = (byte) ((i + j) / k - i % j);
 74         i--;
 75         return i;
 76     }
 77 
 78     @CodeReflection
 79     static short shortNumOps(short i, short j, short k) {
 80         k++;
 81         i = (short) ((i + j) / k - i % j);
 82         i--;
 83         return i;
 84     }
 85 
 86     @CodeReflection
 87     static char charNumOps(char i, char j, char k) {
 88         k++;
 89         i = (char) ((i + j) / k - i % j);
 90         i--;
 91         return i;
 92     }
 93 
 94     @CodeReflection
 95     static long longNumOps(long i, long j, long k) {
 96         k++;
 97         i = (i + j) / k - i % j;
 98         i--;
 99         return i;
100     }
101 
102     @CodeReflection
103     static float floatNumOps(float i, float j, float k) {
104         k++;
105         i = (i + j) / k - i % j;
106         i--;
107         return i;
108     }
109 
110     @CodeReflection
111     static double doubleNumOps(double i, double j, double k) {
112         k++;
113         i = (i + j) / k - i % j;
114         i--;
115         return i;
116     }
117 
118     @CodeReflection
119     static int intBitOps(int i, int j, int k) {
120         return ~(i & j | k ^ j);
121     }
122 
123     @CodeReflection
124     static byte byteBitOps(byte i, byte j, byte k) {
125         return (byte) ~(i & j | k ^ j);
126     }
127 
128     @CodeReflection
129     static short shortBitOps(short i, short j, short k) {
130         return (short) ~(i & j | k ^ j);
131     }
132 
133     @CodeReflection
134     static char charBitOps(char i, char j, char k) {
135         return (char) ~(i & j | k ^ j);
136     }
137 
138     @CodeReflection
139     static long longBitOps(long i, long j, long k) {
140         return ~(i & j | k ^ j);
141     }
142 
143     @CodeReflection
144     static boolean boolBitOps(boolean i, boolean j, boolean k) {
145         return i & j | k ^ j;
146     }
147 
148     @CodeReflection
149     static int intShiftOps(int i, int j, int k) {
150         return ((-1 >> i) << (j << k)) >>> (k - j);
151     }
152 
153     @CodeReflection
154     static byte byteShiftOps(byte i, byte j, byte k) {
155         return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
156     }
157 
158     @CodeReflection
159     static short shortShiftOps(short i, short j, short k) {
160         return (short) (((-1 >> i) << (j << k)) >>> (k - j));
161     }
162 
163     @CodeReflection
164     static char charShiftOps(char i, char j, char k) {
165         return (char) (((-1 >> i) << (j << k)) >>> (k - j));
166     }
167 
168     @CodeReflection
169     static long longShiftOps(long i, long j, long k) {
170         return ((-1 >> i) << (j << k)) >>> (k - j);
171     }
172 
173     @CodeReflection
174     static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
175         ii += i; ii += b; ii += s; ii += c;
176         i += ii; i += bb; i += ss; i += cc;
177         b += ii; b += bb; b += ss; b += cc;
178         s += ii; s += bb; s += ss; s += cc;
179         c += ii; c += bb; c += ss; c += cc;
180         return new Object[]{i, b, s, c};
181     }
182 
183     @CodeReflection
184     static String constructor(String s, int i, int j) {
185         return new String(s.getBytes(), i, j);
186     }
187 
188     @CodeReflection
189     static Class<?> classArray(int i, int j) {
190         Class<?>[] ifaces = new Class[1 + i + j];
191         ifaces[0] = Function.class;
192         return ifaces[0];
193     }
194 
195     @CodeReflection
196     static String[] stringArray(int i, int j) {
197         return new String[i];
198     }
199 
200     @CodeReflection
201     static String[][] stringArray2(int i, int j) {
202         return new String[i][];
203     }
204 
205     @CodeReflection
206     static String[][] stringArrayMulti(int i, int j) {
207         return new String[i][j];
208     }
209 
210     @CodeReflection
211     static int[][] initializedIntArray(int i, int j) {
212         return new int[][]{{i, j}, {i + j}};
213     }
214 
215     @CodeReflection
216     static int ifElseCompare(int i, int j) {
217         if (i < 3) {
218             i += 1;
219         } else {
220             j += 2;
221         }
222         return i + j;
223     }
224 
225     @CodeReflection
226     static int ifElseEquality(int i, int j) {
227         if (j != 0) {
228             if (i != 0) {
229                 i += 1;
230             } else {
231                 i += 2;
232             }
233         } else {
234             if (j != 0) {
235                 i += 3;
236             } else {
237                 i += 4;
238             }
239         }
240         return i;
241     }
242 
243     @CodeReflection
244     static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) {
245         Object a = b1;
246         Object b = b2;
247         Object c = b3;
248         return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4);
249     }
250 
251     @CodeReflection
252     static int conditionalExpr(int i, int j) {
253         return ((i - 1 >= 0) ? i - 1 : j - 1);
254     }
255 
256     @CodeReflection
257     static int nestedConditionalExpr(int i, int j) {
258         return (i < 2) ? (j < 3) ? i : j : i + j;
259     }
260 
261     static final int[] MAP = {0, 1, 2, 3, 4};
262 
263     @CodeReflection
264     static int deepStackBranches(boolean a, boolean b) {
265         return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]];
266     }
267 
268     @CodeReflection
269     static int tryFinally(int i, int j) {
270         try {
271             i = i + j;
272         } finally {
273             i = i + j;
274         }
275         return i;
276     }
277 
278     public record A(String s) {}
279 
280     @CodeReflection
281     static A newWithArgs(int i, int j) {
282         return new A("hello world".substring(i, i + j));
283     }
284 
285     @CodeReflection
286     static int loop(int n, int j) {
287         int sum = 0;
288         for (int i = 0; i < n; i++) {
289             sum = sum + j;
290         }
291         return sum;
292     }
293 
294 
295     @CodeReflection
296     static int ifElseNested(int a, int b) {
297         int c = a + b;
298         int d = 10 - a + b;
299         if (b < 3) {
300             if (a < 3) {
301                 a += 1;
302             } else {
303                 b += 2;
304             }
305             c += 3;
306         } else {
307             if (a > 2) {
308                 a += 4;
309             } else {
310                 b += 5;
311             }
312             d += 6;
313         }
314         return a + b + c + d;
315     }
316 
317     @CodeReflection
318     static int nestedLoop(int m, int n) {
319         int sum = 0;
320         for (int i = 0; i < m; i++) {
321             for (int j = 0; j < n; j++) {
322                 sum = sum + i + j;
323             }
324         }
325         return sum;
326     }
327 
328     @CodeReflection
329     static int methodCall(int a, int b) {
330         int i = Math.max(a, b);
331         return Math.negateExact(i);
332     }
333 
334     @CodeReflection
335     static int[] primitiveArray(int i, int j) {
336         int[] ia = new int[i + 1];
337         ia[0] = j;
338         return ia;
339     }
340 
341     @CodeReflection
342     static boolean not(boolean b) {
343         return !b;
344     }
345 
346     @CodeReflection
347     static boolean notCompare(int i, int j) {
348         boolean b = i < j;
349         return !b;
350     }
351 
352     @CodeReflection
353     static int mod(int i, int j) {
354         return i % (j + 1);
355     }
356 
357     @CodeReflection
358     static int xor(int i, int j) {
359         return i ^ j;
360     }
361 
362     @CodeReflection
363     static int whileLoop(int i, int n) { int
364         counter = 0;
365         while (i < n && counter < 3) {
366             counter++;
367             if (counter == 4) {
368                 break;
369             }
370             i++;
371         }
372         return counter;
373     }
374 
375     public interface Func {
376         int apply(int a);
377     }
378 
379     public interface QuotableFunc extends Quotable {
380         int apply(int a);
381     }
382 
383     static int consume(int i, Func f) {
384         return f.apply(i + 1);
385     }
386 
387     static int consumeQuotable(int i, QuotableFunc f) {
388         Assertions.assertNotNull(Op.ofQuotable(f).get());
389         Assertions.assertNotNull(Op.ofQuotable(f).get().op());
390         Assertions.assertTrue(Op.ofQuotable(f).get().op() instanceof JavaOp.LambdaOp);
391         return f.apply(i + 1);
392     }
393 
394     @CodeReflection
395     static int lambda(int i) {
396         return consume(i, a -> -a);
397     }
398 
399     @CodeReflection
400     static int quotableLambda(int i) {
401         return consumeQuotable(i, a -> -a);
402     }
403 
404     @CodeReflection
405     static int lambdaWithCapture(int i, String s) {
406         return consume(i, a -> a + s.length());
407     }
408 
409     @CodeReflection
410     static int quotableLambdaWithCapture(int i, String s) {
411         return consumeQuotable(i, a -> a + s.length());
412     }
413 
414     @CodeReflection
415     static int nestedLambdasWithCaptures(int i, int j, String s) {
416         return consume(i, a -> consume(a, b -> a + b + j - s.length()) + s.length());
417     }
418 
419     @CodeReflection
420     static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
421         return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length());
422     }
423 
424     @CodeReflection
425     static int methodHandle(int i) {
426         return consume(i, Math::negateExact);
427     }
428 
429     int instanceMethod(int i) {
430         return -i + 13;
431     }
432 
433     @CodeReflection
434     int instanceMethodHandle(int i) {
435         return consume(i, this::instanceMethod);
436     }
437 
438     static void consume(boolean b, Consumer<Object> requireNonNull) {
439         if (b) {
440             requireNonNull.accept(new Object());
441         } else try {
442             requireNonNull.accept(null);
443             throw new AssertionError("Expectend NPE");
444         } catch (NullPointerException expected) {
445         }
446     }
447 
448     @CodeReflection
449     static void nullReturningMethodHandle(boolean b) {
450         consume(b, Objects::requireNonNull);
451     }
452 
453     @CodeReflection
454     static boolean compareLong(long i, long j) {
455         return i > j;
456     }
457 
458     @CodeReflection
459     static boolean compareFloat(float i, float j) {
460         return i > j;
461     }
462 
463     @CodeReflection
464     static boolean compareDouble(double i, double j) {
465         return i > j;
466     }
467 
468     @CodeReflection
469     static int lookupSwitch(int i) {
470         return switch (1000 * i) {
471             case 1000 -> 1;
472             case 2000 -> 2;
473             case 3000 -> 3;
474             default -> 0;
475         };
476     }
477 
478     @CodeReflection
479     static int tableSwitch(int i) {
480         return switch (i) {
481             case 1 -> 1;
482             case 2 -> 2;
483             case 3 -> 3;
484             default -> 0;
485         };
486     }
487 
488     int instanceField = -1;
489 
490     @CodeReflection
491     int instanceFieldAccess(int i) {
492         int ret = instanceField;
493         instanceField = i;
494         return ret;
495     }
496 
497     @CodeReflection
498     static String stringConcat(String a, String b) {
499         return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd";
500     }
501 
502     @CodeReflection
503     static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) {
504         return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d;
505     }
506 
507     @CodeReflection
508     static int ifTrue(int i) {
509         if (true) {
510             return i;
511         }
512         return -i;
513     }
514 
515     @CodeReflection
516     static int excHandlerFollowingSplitTable(boolean b) {
517         try {
518             if (b) return 1;
519             else throw new Exception();
520         } catch (Exception ex) {}
521         return 2;
522     }
523 
524     @CodeReflection
525     static int varModifiedInTryBlock(boolean b) {
526         int i = 0;
527         try {
528             i++;
529             if (b) throw new Exception();
530             i++;
531             throw new Exception();
532         } catch (Exception ex) {
533             return i;
534         }
535     }
536 
537     @CodeReflection
538     static boolean finallyWithLoop(boolean b) {
539         try {
540             while (b) {
541                 if (b)
542                     return false;
543                 b = !b;
544             }
545             return true;
546         } finally {
547             b = false;
548         }
549     }
550 
551     @CodeReflection
552     static long doubleUseOfOperand(int x) {
553         long piece = x;
554         return piece * piece;
555     }
556 
557     record TestData(Method testMethod) {
558         @Override
559         public String toString() {
560             String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
561                     .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
562             if (s.length() > 30) s = s.substring(0, 27) + "...";
563             return s;
564         }
565     }
566 
567     public static Stream<TestData> testMethods() {
568         return Stream.of(TestBytecodeLift.class.getDeclaredMethods())
569                 .filter(m -> m.isAnnotationPresent(CodeReflection.class))
570                 .map(TestData::new);
571     }
572 
573     private static byte[] CLASS_DATA;
574     private static ClassModel CLASS_MODEL;
575 
576     @BeforeAll
577     public static void setup() throws Exception {
578         CLASS_DATA = TestBytecodeLift.class.getResourceAsStream("TestBytecodeLift.class").readAllBytes();
579         CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
580     }
581 
582     private static MethodTypeDesc toMethodTypeDesc(Method m) {
583         return MethodTypeDesc.of(
584                 m.getReturnType().describeConstable().orElseThrow(),
585                 Arrays.stream(m.getParameterTypes())
586                         .map(cls -> cls.describeConstable().orElseThrow()).toList());
587     }
588 
589 
590     private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
591     private static Object[] values(Object... values) {
592         return values;
593     }
594     private static void initTestArgs(Object[] values, Class<?>... argTypes) {
595         for (var argType : argTypes) TEST_ARGS.put(argType, values);
596     }
597     static {
598         initTestArgs(values(1, 2, 4), int.class, Integer.class);
599         initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class);
600         initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class);
601         initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class);
602         initTestArgs(values(false, true), boolean.class, Boolean.class);
603         initTestArgs(values("Hello World"), String.class);
604         initTestArgs(values(1l, 2l, 4l), long.class, Long.class);
605         initTestArgs(values(1f, 3f, 4f), float.class, Float.class);
606         initTestArgs(values(1d, 2d, 3d), double.class, Double.class);
607     }
608 
609     interface Executor {
610         void execute(Object[] args) throws Throwable;
611     }
612 
613     private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
614         final int argn = argTypes.length;
615         Object[][] argValues = new Object[argn][];
616         for (int i = 0; i < argn; i++) {
617             argValues[i] = TEST_ARGS.get(argTypes[i]);
618         }
619         int[] argIndexes = new int[argn];
620         Object[] args = new Object[argn];
621         while (true) {
622             for (int i = 0; i < argn; i++) {
623                 args[i] = argValues[i][argIndexes[i]];
624             }
625             executor.execute(args);
626             int i = argn - 1;
627             while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
628             if (i < 0) return;
629             argIndexes[i++]++;
630             while (i < argn) argIndexes[i++] = 0;
631         }
632     }
633 
634     @ParameterizedTest
635     @MethodSource("testMethods")
636     public void testLift(TestData d) throws Throwable {
637         CoreOp.FuncOp flift;
638         try {
639             flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod));
640         } catch (Throwable e) {
641             ClassPrinter.toYaml(ClassFile.of().parse(TestBytecodeLift.class.getResourceAsStream("TestBytecodeLift.class").readAllBytes())
642                     .methods().stream().filter(m -> m.methodName().equalsString(d.testMethod().getName())).findAny().get(),
643                     ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print);
644             System.out.println("Lift failed, compiled model:");
645             Op.ofMethod(d.testMethod).ifPresent(f -> System.out.println(f.toText()));
646             throw e;
647         }
648         try {
649             Object receiver1, receiver2;
650             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
651                 receiver1 = null;
652                 receiver2 = null;
653             } else {
654                 receiver1 = new TestBytecodeLift();
655                 receiver2 = new TestBytecodeLift();
656             }
657             permutateAllArgs(d.testMethod.getParameterTypes(), args ->
658                     assertEquals(d.testMethod.invoke(receiver2, args), invokeAndConvert(flift, receiver1, args)));
659         } catch (Throwable e) {
660             System.out.println("Compiled model:");
661             Op.ofMethod(d.testMethod).ifPresent(f -> System.out.println(f.toText()));
662             System.out.println("Lifted model:");
663             System.out.println(flift.toText());
664             throw e;
665         }
666     }
667 
668     private static Object invokeAndConvert(CoreOp.FuncOp func, Object receiver, Object... args) {
669         List argl = new ArrayList(args.length + 1);
670         if (receiver != null) argl.add(receiver);
671         argl.addAll(Arrays.asList(args));
672         Object ret = Interpreter.invoke(MethodHandles.lookup(), func, argl);
673         if (ret instanceof Integer i) {
674             TypeElement rt = func.invokableType().returnType();
675             if (rt.equals(JavaType.BOOLEAN)) {
676                 return i != 0;
677             } else if (rt.equals(JavaType.BYTE)) {
678                 return i.byteValue();
679             } else if (rt.equals(JavaType.CHAR)) {
680                 return (short)i.intValue();
681             } else if (rt.equals(JavaType.SHORT)) {
682                 return i.shortValue();
683             }
684         }
685         return ret;
686     }
687 
688     private static void assertEquals(Object expected, Object actual) {
689         switch (expected) {
690             case int[] expArr when actual instanceof int[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
691             case Object[] expArr when actual instanceof Object[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
692             case null, default -> Assertions.assertEquals(expected, actual);
693         }
694     }
695 }