1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import java.io.IOException;
 25 import java.lang.classfile.ClassFile;
 26 import java.lang.classfile.ClassModel;
 27 import jdk.internal.classfile.components.ClassPrinter;
 28 import java.lang.constant.MethodTypeDesc;
 29 import java.lang.invoke.MethodHandle;
 30 import java.lang.invoke.MethodHandles;
 31 import java.lang.reflect.AccessFlag;
 32 import org.testng.Assert;
 33 import org.testng.SkipException;
 34 import org.testng.annotations.BeforeClass;
 35 import org.testng.annotations.DataProvider;
 36 import org.testng.annotations.Test;
 37 
 38 import jdk.incubator.code.*;
 39 import jdk.incubator.code.op.CoreOp;
 40 import jdk.incubator.code.bytecode.BytecodeLift;
 41 import jdk.incubator.code.interpreter.Interpreter;
 42 import java.lang.reflect.Method;
 43 import jdk.incubator.code.bytecode.BytecodeGenerator;
 44 import jdk.incubator.code.type.JavaType;
 45 import jdk.incubator.code.CodeReflection;
 46 import java.nio.file.Files;
 47 import java.nio.file.Path;
 48 import java.util.*;
 49 import java.util.function.Consumer;
 50 import java.util.function.Function;
 51 import java.util.stream.Collectors;
 52 import java.util.stream.Stream;
 53 
 54 /*
 55  * @test
 56  * @modules jdk.incubator.code
 57  * @modules java.base/jdk.internal.classfile.components
 58  * @enablePreview
 59  * @run testng/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode
 60  */
 61 
 62 public class TestBytecode {
 63 
 64     @CodeReflection
 65     static int intNumOps(int i, int j, int k) {
 66         k++;
 67         i = (i + j) / k - i % j;
 68         i--;
 69         return i;
 70     }
 71 
 72     @CodeReflection
 73     static byte byteNumOps(byte i, byte j, byte k) {
 74         k++;
 75         i = (byte) ((i + j) / k - i % j);
 76         i--;
 77         return i;
 78     }
 79 
 80     @CodeReflection
 81     static short shortNumOps(short i, short j, short k) {
 82         k++;
 83         i = (short) ((i + j) / k - i % j);
 84         i--;
 85         return i;
 86     }
 87 
 88     @CodeReflection
 89     static char charNumOps(char i, char j, char k) {
 90         k++;
 91         i = (char) ((i + j) / k - i % j);
 92         i--;
 93         return i;
 94     }
 95 
 96     @CodeReflection
 97     static long longNumOps(long i, long j, long k) {
 98         k++;
 99         i = (i + j) / k - i % j;
100         i--;
101         return i;
102     }
103 
104     @CodeReflection
105     static float floatNumOps(float i, float j, float k) {
106         k++;
107         i = (i + j) / k - i % j;
108         i--;
109         return i;
110     }
111 
112     @CodeReflection
113     static double doubleNumOps(double i, double j, double k) {
114         k++;
115         i = (i + j) / k - i % j;
116         i--;
117         return i;
118     }
119 
120     @CodeReflection
121     static int intBitOps(int i, int j, int k) {
122         return ~(i & j | k ^ j);
123     }
124 
125     @CodeReflection
126     static byte byteBitOps(byte i, byte j, byte k) {
127         return (byte) ~(i & j | k ^ j);
128     }
129 
130     @CodeReflection
131     static short shortBitOps(short i, short j, short k) {
132         return (short) ~(i & j | k ^ j);
133     }
134 
135     @CodeReflection
136     static char charBitOps(char i, char j, char k) {
137         return (char) ~(i & j | k ^ j);
138     }
139 
140     @CodeReflection
141     static long longBitOps(long i, long j, long k) {
142         return ~(i & j | k ^ j);
143     }
144 
145     @CodeReflection
146     static boolean boolBitOps(boolean i, boolean j, boolean k) {
147         return i & j | k ^ j;
148     }
149 
150     @CodeReflection
151     static int intShiftOps(int i, int j, int k) {
152         return ((-1 >> i) << (j << k)) >>> (k - j);
153     }
154 
155     @CodeReflection
156     static byte byteShiftOps(byte i, byte j, byte k) {
157         return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
158     }
159 
160     @CodeReflection
161     static short shortShiftOps(short i, short j, short k) {
162         return (short) (((-1 >> i) << (j << k)) >>> (k - j));
163     }
164 
165     @CodeReflection
166     static char charShiftOps(char i, char j, char k) {
167         return (char) (((-1 >> i) << (j << k)) >>> (k - j));
168     }
169 
170     @CodeReflection
171     static long longShiftOps(long i, long j, long k) {
172         return ((-1 >> i) << (j << k)) >>> (k - j);
173     }
174 
175     @CodeReflection
176     static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
177         ii += i; ii += b; ii += s; ii += c;
178         i += ii; i += bb; i += ss; i += cc;
179         b += ii; b += bb; b += ss; b += cc;
180         s += ii; s += bb; s += ss; s += cc;
181         c += ii; c += bb; c += ss; c += cc;
182         return new Object[]{i, b, s, c};
183     }
184 
185     @CodeReflection
186     static String constructor(String s, int i, int j) {
187         return new String(s.getBytes(), i, j);
188     }
189 
190     @CodeReflection
191     static Class<?> classArray(int i, int j) {
192         Class<?>[] ifaces = new Class[1 + i + j];
193         ifaces[0] = Function.class;
194         return ifaces[0];
195     }
196 
197     @CodeReflection
198     static String[] stringArray(int i, int j) {
199         return new String[i];
200     }
201 
202     @CodeReflection
203     static String[][] stringArray2(int i, int j) {
204         return new String[i][];
205     }
206 
207     @CodeReflection
208     static String[][] stringArrayMulti(int i, int j) {
209         return new String[i][j];
210     }
211 
212     @CodeReflection
213     static int[][] initializedIntArray(int i, int j) {
214         return new int[][]{{i, j}, {i + j}};
215     }
216 
217     @CodeReflection
218     static int ifElseCompare(int i, int j) {
219         if (i < 3) {
220             i += 1;
221         } else {
222             j += 2;
223         }
224         return i + j;
225     }
226 
227     @CodeReflection
228     static int ifElseEquality(int i, int j) {
229         if (j != 0) {
230             if (i != 0) {
231                 i += 1;
232             } else {
233                 i += 2;
234             }
235         } else {
236             if (j != 0) {
237                 i += 3;
238             } else {
239                 i += 4;
240             }
241         }
242         return i;
243     }
244 
245     @CodeReflection
246     static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) {
247         Object a = b1;
248         Object b = b2;
249         Object c = b3;
250         return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4);
251     }
252 
253     @CodeReflection
254     static int conditionalExpr(int i, int j) {
255         return ((i - 1 >= 0) ? i - 1 : j - 1);
256     }
257 
258     @CodeReflection
259     static int nestedConditionalExpr(int i, int j) {
260         return (i < 2) ? (j < 3) ? i : j : i + j;
261     }
262 
263     static final int[] MAP = {0, 1, 2, 3, 4};
264 
265     @CodeReflection
266     static int deepStackBranches(boolean a, boolean b) {
267         return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]];
268     }
269 
270     @CodeReflection
271     static int tryFinally(int i, int j) {
272         try {
273             i = i + j;
274         } finally {
275             i = i + j;
276         }
277         return i;
278     }
279 
280     public record A(String s) {}
281 
282     @CodeReflection
283     static A newWithArgs(int i, int j) {
284         return new A("hello world".substring(i, i + j));
285     }
286 
287     @CodeReflection
288     static int loop(int n, int j) {
289         int sum = 0;
290         for (int i = 0; i < n; i++) {
291             sum = sum + j;
292         }
293         return sum;
294     }
295 
296 
297     @CodeReflection
298     static int ifElseNested(int a, int b) {
299         int c = a + b;
300         int d = 10 - a + b;
301         if (b < 3) {
302             if (a < 3) {
303                 a += 1;
304             } else {
305                 b += 2;
306             }
307             c += 3;
308         } else {
309             if (a > 2) {
310                 a += 4;
311             } else {
312                 b += 5;
313             }
314             d += 6;
315         }
316         return a + b + c + d;
317     }
318 
319     @CodeReflection
320     static int nestedLoop(int m, int n) {
321         int sum = 0;
322         for (int i = 0; i < m; i++) {
323             for (int j = 0; j < n; j++) {
324                 sum = sum + i + j;
325             }
326         }
327         return sum;
328     }
329 
330     @CodeReflection
331     static int methodCall(int a, int b) {
332         int i = Math.max(a, b);
333         return Math.negateExact(i);
334     }
335 
336     @CodeReflection
337     static int[] primitiveArray(int i, int j) {
338         int[] ia = new int[i + 1];
339         ia[0] = j;
340         return ia;
341     }
342 
343     @CodeReflection
344     static boolean not(boolean b) {
345         return !b;
346     }
347 
348     @CodeReflection
349     static boolean notCompare(int i, int j) {
350         boolean b = i < j;
351         return !b;
352     }
353 
354     @CodeReflection
355     static int mod(int i, int j) {
356         return i % (j + 1);
357     }
358 
359     @CodeReflection
360     static int xor(int i, int j) {
361         return i ^ j;
362     }
363 
364     @CodeReflection
365     static int whileLoop(int i, int n) { int
366         counter = 0;
367         while (i < n && counter < 3) {
368             counter++;
369             if (counter == 4) {
370                 break;
371             }
372             i++;
373         }
374         return counter;
375     }
376 
377     public interface Func {
378         int apply(int a);
379     }
380 
381     public interface QuotableFunc extends Quotable {
382         int apply(int a);
383     }
384 
385     static int consume(int i, Func f) {
386         return f.apply(i + 1);
387     }
388 
389     static int consumeQuotable(int i, QuotableFunc f) {
390         Assert.assertNotNull(Op.ofQuotable(f).get());
391         Assert.assertNotNull(Op.ofQuotable(f).get().op());
392         Assert.assertTrue(Op.ofQuotable(f).get().op() instanceof CoreOp.LambdaOp);
393         return f.apply(i + 1);
394     }
395 
396     @CodeReflection
397     static int lambda(int i) {
398         return consume(i, a -> -a);
399     }
400 
401     @CodeReflection
402     static int quotableLambda(int i) {
403         return consumeQuotable(i, a -> -a);
404     }
405 
406     @CodeReflection
407     static int lambdaWithCapture(int i, String s) {
408         return consume(i, a -> a + s.length());
409     }
410 
411     @CodeReflection
412     static int quotableLambdaWithCapture(int i, String s) {
413         return consumeQuotable(i, a -> a + s.length());
414     }
415 
416     @CodeReflection
417     static int nestedLambdasWithCaptures(int i, int j, String s) {
418         return consume(i, a -> consume(a, b -> a + b + j - s.length()) + s.length());
419     }
420 
421     @CodeReflection
422     static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
423         return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length());
424     }
425 
426     @CodeReflection
427     static int methodHandle(int i) {
428         return consume(i, Math::negateExact);
429     }
430 
431     int instanceMethod(int i) {
432         return -i + 13;
433     }
434 
435     @CodeReflection
436     int instanceMethodHandle(int i) {
437         return consume(i, this::instanceMethod);
438     }
439 
440     static void consume(boolean b, Consumer<Object> requireNonNull) {
441         if (b) {
442             requireNonNull.accept(new Object());
443         } else try {
444             requireNonNull.accept(null);
445             throw new AssertionError("Expectend NPE");
446         } catch (NullPointerException expected) {
447         }
448     }
449 
450     @CodeReflection
451     static void nullReturningMethodHandle(boolean b) {
452         consume(b, Objects::requireNonNull);
453     }
454 
455     @CodeReflection
456     static boolean compareLong(long i, long j) {
457         return i > j;
458     }
459 
460     @CodeReflection
461     static boolean compareFloat(float i, float j) {
462         return i > j;
463     }
464 
465     @CodeReflection
466     static boolean compareDouble(double i, double j) {
467         return i > j;
468     }
469 
470     @CodeReflection
471     static int lookupSwitch(int i) {
472         return switch (1000 * i) {
473             case 1000 -> 1;
474             case 2000 -> 2;
475             case 3000 -> 3;
476             default -> 0;
477         };
478     }
479 
480     @CodeReflection
481     static int tableSwitch(int i) {
482         return switch (i) {
483             case 1 -> 1;
484             case 2 -> 2;
485             case 3 -> 3;
486             default -> 0;
487         };
488     }
489 
490     int instanceField = -1;
491 
492     @CodeReflection
493     int instanceFieldAccess(int i) {
494         int ret = instanceField;
495         instanceField = i;
496         return ret;
497     }
498 
499     @CodeReflection
500     static String stringConcat(String a, String b) {
501         return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd";
502     }
503 
504     @CodeReflection
505     static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) {
506         return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d;
507     }
508 
509     @CodeReflection
510     static int ifTrue(int i) {
511         if (true) {
512             return i;
513         }
514         return -i;
515     }
516 
517     @CodeReflection
518     static int excHandlerFollowingSplitTable(boolean b) {
519         try {
520             if (b) return 1;
521             else throw new Exception();
522         } catch (Exception ex) {}
523         return 2;
524     }
525 
526     @CodeReflection
527     static int varModifiedInTryBlock(boolean b) {
528         int i = 0;
529         try {
530             i++;
531             if (b) throw new Exception();
532             i++;
533             throw new Exception();
534         } catch (Exception ex) {
535             return i;
536         }
537     }
538 
539     @CodeReflection
540     static boolean finallyWithLoop(boolean b) {
541         try {
542             while (b) {
543                 if (b)
544                     return false;
545                 b = !b;
546             }
547             return true;
548         } finally {
549             b = false;
550         }
551     }
552 
553     @CodeReflection
554     static long doubleUseOfOperand(int x) {
555         long piece = x;
556         return piece * piece;
557     }
558 
559     record TestData(Method testMethod) {
560         @Override
561         public String toString() {
562             String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
563                     .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
564             if (s.length() > 30) s = s.substring(0, 27) + "...";
565             return s;
566         }
567     }
568 
569     @DataProvider(name = "testMethods")
570     public static TestData[]testMethods() {
571         return Stream.of(TestBytecode.class.getDeclaredMethods())
572                 .filter(m -> m.isAnnotationPresent(CodeReflection.class))
573                 .map(TestData::new).toArray(TestData[]::new);
574     }
575 
576     private static byte[] CLASS_DATA;
577     private static ClassModel CLASS_MODEL;
578 
579     @BeforeClass
580     public static void setup() throws Exception {
581         CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes();
582         CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
583     }
584 
585     private static MethodTypeDesc toMethodTypeDesc(Method m) {
586         return MethodTypeDesc.of(
587                 m.getReturnType().describeConstable().orElseThrow(),
588                 Arrays.stream(m.getParameterTypes())
589                         .map(cls -> cls.describeConstable().orElseThrow()).toList());
590     }
591 
592 
593     private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
594     private static Object[] values(Object... values) {
595         return values;
596     }
597     private static void initTestArgs(Object[] values, Class<?>... argTypes) {
598         for (var argType : argTypes) TEST_ARGS.put(argType, values);
599     }
600     static {
601         initTestArgs(values(1, 2, 4), int.class, Integer.class);
602         initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class);
603         initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class);
604         initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class);
605         initTestArgs(values(false, true), boolean.class, Boolean.class);
606         initTestArgs(values("Hello World"), String.class);
607         initTestArgs(values(1l, 2l, 4l), long.class, Long.class);
608         initTestArgs(values(1f, 3f, 4f), float.class, Float.class);
609         initTestArgs(values(1d, 2d, 3d), double.class, Double.class);
610     }
611 
612     interface Executor {
613         void execute(Object[] args) throws Throwable;
614     }
615 
616     private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
617         final int argn = argTypes.length;
618         Object[][] argValues = new Object[argn][];
619         for (int i = 0; i < argn; i++) {
620             argValues[i] = TEST_ARGS.get(argTypes[i]);
621         }
622         int[] argIndexes = new int[argn];
623         Object[] args = new Object[argn];
624         while (true) {
625             for (int i = 0; i < argn; i++) {
626                 args[i] = argValues[i][argIndexes[i]];
627             }
628             executor.execute(args);
629             int i = argn - 1;
630             while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
631             if (i < 0) return;
632             argIndexes[i++]++;
633             while (i < argn) argIndexes[i++] = 0;
634         }
635     }
636 
637     @Test(dataProvider = "testMethods")
638     public void testLift(TestData d) throws Throwable {
639         CoreOp.FuncOp flift;
640         try {
641             flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod));
642         } catch (Throwable e) {
643             ClassPrinter.toYaml(ClassFile.of().parse(TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes())
644                     .methods().stream().filter(m -> m.methodName().equalsString(d.testMethod().getName())).findAny().get(),
645                     ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print);
646             System.out.println("Lift failed, compiled model:");
647             Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out));
648             throw e;
649         }
650         try {
651             Object receiver1, receiver2;
652             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
653                 receiver1 = null;
654                 receiver2 = null;
655             } else {
656                 receiver1 = new TestBytecode();
657                 receiver2 = new TestBytecode();
658             }
659             permutateAllArgs(d.testMethod.getParameterTypes(), args ->
660                 Assert.assertEquals(invokeAndConvert(flift, receiver1, args), d.testMethod.invoke(receiver2, args)));
661         } catch (Throwable e) {
662             System.out.println("Compiled model:");
663             Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out));
664             System.out.println("Lifted model:");
665             flift.writeTo(System.out);
666             throw e;
667         }
668     }
669 
670     private static Object invokeAndConvert(CoreOp.FuncOp func, Object receiver, Object... args) {
671         List argl = new ArrayList(args.length + 1);
672         if (receiver != null) argl.add(receiver);
673         argl.addAll(Arrays.asList(args));
674         Object ret = Interpreter.invoke(MethodHandles.lookup(), func, argl);
675         if (ret instanceof Integer i) {
676             TypeElement rt = func.invokableType().returnType();
677             if (rt.equals(JavaType.BOOLEAN)) {
678                 return i != 0;
679             } else if (rt.equals(JavaType.BYTE)) {
680                 return i.byteValue();
681             } else if (rt.equals(JavaType.CHAR)) {
682                 return (short)i.intValue();
683             } else if (rt.equals(JavaType.SHORT)) {
684                 return i.shortValue();
685             }
686         }
687         return ret;
688     }
689 
690     @Test(dataProvider = "testMethods")
691     public void testGenerate(TestData d) throws Throwable {
692         CoreOp.FuncOp func = Op.ofMethod(d.testMethod).get();
693 
694         CoreOp.FuncOp lfunc;
695         try {
696             lfunc = func.transform(CopyContext.create(), OpTransformer.LOWERING_TRANSFORMER);
697         } catch (UnsupportedOperationException uoe) {
698             throw new SkipException("lowering caused:", uoe);
699         }
700 
701         try {
702             MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), lfunc);
703             Object receiver1, receiver2;
704             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
705                 receiver1 = null;
706                 receiver2 = null;
707             } else {
708                 receiver1 = new TestBytecode();
709                 receiver2 = new TestBytecode();
710             }
711             permutateAllArgs(d.testMethod.getParameterTypes(), args -> {
712                     List argl = new ArrayList(args.length + 1);
713                     if (receiver1 != null) argl.add(receiver1);
714                     argl.addAll(Arrays.asList(args));
715                     Assert.assertEquals(mh.invokeWithArguments(argl), d.testMethod.invoke(receiver2, args));
716             });
717         } catch (Throwable e) {
718             func.writeTo(System.out);
719             lfunc.writeTo(System.out);
720             String methodName = d.testMethod().getName();
721             for (var mm : CLASS_MODEL.methods()) {
722                 if (mm.methodName().equalsString(methodName)
723                         || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) {
724                     ClassPrinter.toYaml(mm,
725                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
726                                         System.out::print);
727                 }
728             }
729             Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> {
730                 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try {
731                     ClassPrinter.toYaml(ClassFile.of().parse(p),
732                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
733                                         System.out::print);
734                 } catch (IOException ignore) {}
735             });
736             throw e;
737         }
738     }
739 }