1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import java.io.IOException;
 25 import java.lang.classfile.ClassFile;
 26 import java.lang.classfile.ClassModel;
 27 
 28 import jdk.incubator.code.dialect.java.JavaOp;
 29 import jdk.internal.classfile.components.ClassPrinter;
 30 import java.lang.constant.MethodTypeDesc;
 31 import java.lang.invoke.MethodHandle;
 32 import java.lang.invoke.MethodHandles;
 33 import java.lang.reflect.AccessFlag;
 34 import org.testng.Assert;
 35 import org.testng.SkipException;
 36 import org.testng.annotations.BeforeClass;
 37 import org.testng.annotations.DataProvider;
 38 import org.testng.annotations.Test;
 39 
 40 import jdk.incubator.code.*;
 41 import jdk.incubator.code.dialect.core.CoreOp;
 42 import jdk.incubator.code.bytecode.BytecodeLift;
 43 import jdk.incubator.code.interpreter.Interpreter;
 44 import java.lang.reflect.Method;
 45 import jdk.incubator.code.bytecode.BytecodeGenerator;
 46 import jdk.incubator.code.dialect.java.JavaType;
 47 import jdk.incubator.code.CodeReflection;
 48 import java.nio.file.Files;
 49 import java.nio.file.Path;
 50 import java.util.*;
 51 import java.util.function.Consumer;
 52 import java.util.function.Function;
 53 import java.util.stream.Collectors;
 54 import java.util.stream.Stream;
 55 
 56 /*
 57  * @test
 58  * @modules jdk.incubator.code
 59  * @modules java.base/jdk.internal.classfile.components
 60  * @enablePreview
 61  * @run testng/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode
 62  */
 63 
 64 public class TestBytecode {
 65 
 66     @CodeReflection
 67     static int intNumOps(int i, int j, int k) {
 68         k++;
 69         i = (i + j) / k - i % j;
 70         i--;
 71         return i;
 72     }
 73 
 74     @CodeReflection
 75     static byte byteNumOps(byte i, byte j, byte k) {
 76         k++;
 77         i = (byte) ((i + j) / k - i % j);
 78         i--;
 79         return i;
 80     }
 81 
 82     @CodeReflection
 83     static short shortNumOps(short i, short j, short k) {
 84         k++;
 85         i = (short) ((i + j) / k - i % j);
 86         i--;
 87         return i;
 88     }
 89 
 90     @CodeReflection
 91     static char charNumOps(char i, char j, char k) {
 92         k++;
 93         i = (char) ((i + j) / k - i % j);
 94         i--;
 95         return i;
 96     }
 97 
 98     @CodeReflection
 99     static long longNumOps(long i, long j, long k) {
100         k++;
101         i = (i + j) / k - i % j;
102         i--;
103         return i;
104     }
105 
106     @CodeReflection
107     static float floatNumOps(float i, float j, float k) {
108         k++;
109         i = (i + j) / k - i % j;
110         i--;
111         return i;
112     }
113 
114     @CodeReflection
115     static double doubleNumOps(double i, double j, double k) {
116         k++;
117         i = (i + j) / k - i % j;
118         i--;
119         return i;
120     }
121 
122     @CodeReflection
123     static int intBitOps(int i, int j, int k) {
124         return ~(i & j | k ^ j);
125     }
126 
127     @CodeReflection
128     static byte byteBitOps(byte i, byte j, byte k) {
129         return (byte) ~(i & j | k ^ j);
130     }
131 
132     @CodeReflection
133     static short shortBitOps(short i, short j, short k) {
134         return (short) ~(i & j | k ^ j);
135     }
136 
137     @CodeReflection
138     static char charBitOps(char i, char j, char k) {
139         return (char) ~(i & j | k ^ j);
140     }
141 
142     @CodeReflection
143     static long longBitOps(long i, long j, long k) {
144         return ~(i & j | k ^ j);
145     }
146 
147     @CodeReflection
148     static boolean boolBitOps(boolean i, boolean j, boolean k) {
149         return i & j | k ^ j;
150     }
151 
152     @CodeReflection
153     static int intShiftOps(int i, int j, int k) {
154         return ((-1 >> i) << (j << k)) >>> (k - j);
155     }
156 
157     @CodeReflection
158     static byte byteShiftOps(byte i, byte j, byte k) {
159         return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
160     }
161 
162     @CodeReflection
163     static short shortShiftOps(short i, short j, short k) {
164         return (short) (((-1 >> i) << (j << k)) >>> (k - j));
165     }
166 
167     @CodeReflection
168     static char charShiftOps(char i, char j, char k) {
169         return (char) (((-1 >> i) << (j << k)) >>> (k - j));
170     }
171 
172     @CodeReflection
173     static long longShiftOps(long i, long j, long k) {
174         return ((-1 >> i) << (j << k)) >>> (k - j);
175     }
176 
177     @CodeReflection
178     static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
179         ii += i; ii += b; ii += s; ii += c;
180         i += ii; i += bb; i += ss; i += cc;
181         b += ii; b += bb; b += ss; b += cc;
182         s += ii; s += bb; s += ss; s += cc;
183         c += ii; c += bb; c += ss; c += cc;
184         return new Object[]{i, b, s, c};
185     }
186 
187     @CodeReflection
188     static String constructor(String s, int i, int j) {
189         return new String(s.getBytes(), i, j);
190     }
191 
192     @CodeReflection
193     static Class<?> classArray(int i, int j) {
194         Class<?>[] ifaces = new Class[1 + i + j];
195         ifaces[0] = Function.class;
196         return ifaces[0];
197     }
198 
199     @CodeReflection
200     static String[] stringArray(int i, int j) {
201         return new String[i];
202     }
203 
204     @CodeReflection
205     static String[][] stringArray2(int i, int j) {
206         return new String[i][];
207     }
208 
209     @CodeReflection
210     static String[][] stringArrayMulti(int i, int j) {
211         return new String[i][j];
212     }
213 
214     @CodeReflection
215     static int[][] initializedIntArray(int i, int j) {
216         return new int[][]{{i, j}, {i + j}};
217     }
218 
219     @CodeReflection
220     static int ifElseCompare(int i, int j) {
221         if (i < 3) {
222             i += 1;
223         } else {
224             j += 2;
225         }
226         return i + j;
227     }
228 
229     @CodeReflection
230     static int ifElseEquality(int i, int j) {
231         if (j != 0) {
232             if (i != 0) {
233                 i += 1;
234             } else {
235                 i += 2;
236             }
237         } else {
238             if (j != 0) {
239                 i += 3;
240             } else {
241                 i += 4;
242             }
243         }
244         return i;
245     }
246 
247     @CodeReflection
248     static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) {
249         Object a = b1;
250         Object b = b2;
251         Object c = b3;
252         return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4);
253     }
254 
255     @CodeReflection
256     static int conditionalExpr(int i, int j) {
257         return ((i - 1 >= 0) ? i - 1 : j - 1);
258     }
259 
260     @CodeReflection
261     static int nestedConditionalExpr(int i, int j) {
262         return (i < 2) ? (j < 3) ? i : j : i + j;
263     }
264 
265     static final int[] MAP = {0, 1, 2, 3, 4};
266 
267     @CodeReflection
268     static int deepStackBranches(boolean a, boolean b) {
269         return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]];
270     }
271 
272     @CodeReflection
273     static int tryFinally(int i, int j) {
274         try {
275             i = i + j;
276         } finally {
277             i = i + j;
278         }
279         return i;
280     }
281 
282     public record A(String s) {}
283 
284     @CodeReflection
285     static A newWithArgs(int i, int j) {
286         return new A("hello world".substring(i, i + j));
287     }
288 
289     @CodeReflection
290     static int loop(int n, int j) {
291         int sum = 0;
292         for (int i = 0; i < n; i++) {
293             sum = sum + j;
294         }
295         return sum;
296     }
297 
298 
299     @CodeReflection
300     static int ifElseNested(int a, int b) {
301         int c = a + b;
302         int d = 10 - a + b;
303         if (b < 3) {
304             if (a < 3) {
305                 a += 1;
306             } else {
307                 b += 2;
308             }
309             c += 3;
310         } else {
311             if (a > 2) {
312                 a += 4;
313             } else {
314                 b += 5;
315             }
316             d += 6;
317         }
318         return a + b + c + d;
319     }
320 
321     @CodeReflection
322     static int nestedLoop(int m, int n) {
323         int sum = 0;
324         for (int i = 0; i < m; i++) {
325             for (int j = 0; j < n; j++) {
326                 sum = sum + i + j;
327             }
328         }
329         return sum;
330     }
331 
332     @CodeReflection
333     static int methodCall(int a, int b) {
334         int i = Math.max(a, b);
335         return Math.negateExact(i);
336     }
337 
338     @CodeReflection
339     static int[] primitiveArray(int i, int j) {
340         int[] ia = new int[i + 1];
341         ia[0] = j;
342         return ia;
343     }
344 
345     @CodeReflection
346     static boolean not(boolean b) {
347         return !b;
348     }
349 
350     @CodeReflection
351     static boolean notCompare(int i, int j) {
352         boolean b = i < j;
353         return !b;
354     }
355 
356     @CodeReflection
357     static int mod(int i, int j) {
358         return i % (j + 1);
359     }
360 
361     @CodeReflection
362     static int xor(int i, int j) {
363         return i ^ j;
364     }
365 
366     @CodeReflection
367     static int whileLoop(int i, int n) { int
368         counter = 0;
369         while (i < n && counter < 3) {
370             counter++;
371             if (counter == 4) {
372                 break;
373             }
374             i++;
375         }
376         return counter;
377     }
378 
379     public interface Func {
380         int apply(int a);
381     }
382 
383     public interface QuotableFunc extends Quotable {
384         int apply(int a);
385     }
386 
387     static int consume(int i, Func f) {
388         return f.apply(i + 1);
389     }
390 
391     static int consumeQuotable(int i, QuotableFunc f) {
392         Assert.assertNotNull(Op.ofQuotable(f).get());
393         Assert.assertNotNull(Op.ofQuotable(f).get().op());
394         Assert.assertTrue(Op.ofQuotable(f).get().op() instanceof JavaOp.LambdaOp);
395         return f.apply(i + 1);
396     }
397 
398     @CodeReflection
399     static int lambda(int i) {
400         return consume(i, a -> -a);
401     }
402 
403     @CodeReflection
404     static int quotableLambda(int i) {
405         return consumeQuotable(i, a -> -a);
406     }
407 
408     @CodeReflection
409     static int lambdaWithCapture(int i, String s) {
410         return consume(i, a -> a + s.length());
411     }
412 
413     @CodeReflection
414     static int quotableLambdaWithCapture(int i, String s) {
415         return consumeQuotable(i, a -> a + s.length());
416     }
417 
418     @CodeReflection
419     static int nestedLambdasWithCaptures(int i, int j, String s) {
420         return consume(i, a -> consume(a, b -> a + b + j - s.length()) + s.length());
421     }
422 
423     @CodeReflection
424     static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
425         return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length());
426     }
427 
428     @CodeReflection
429     static int methodHandle(int i) {
430         return consume(i, Math::negateExact);
431     }
432 
433     int instanceMethod(int i) {
434         return -i + 13;
435     }
436 
437     @CodeReflection
438     int instanceMethodHandle(int i) {
439         return consume(i, this::instanceMethod);
440     }
441 
442     static void consume(boolean b, Consumer<Object> requireNonNull) {
443         if (b) {
444             requireNonNull.accept(new Object());
445         } else try {
446             requireNonNull.accept(null);
447             throw new AssertionError("Expectend NPE");
448         } catch (NullPointerException expected) {
449         }
450     }
451 
452     @CodeReflection
453     static void nullReturningMethodHandle(boolean b) {
454         consume(b, Objects::requireNonNull);
455     }
456 
457     @CodeReflection
458     static boolean compareLong(long i, long j) {
459         return i > j;
460     }
461 
462     @CodeReflection
463     static boolean compareFloat(float i, float j) {
464         return i > j;
465     }
466 
467     @CodeReflection
468     static boolean compareDouble(double i, double j) {
469         return i > j;
470     }
471 
472     @CodeReflection
473     static int lookupSwitch(int i) {
474         return switch (1000 * i) {
475             case 1000 -> 1;
476             case 2000 -> 2;
477             case 3000 -> 3;
478             default -> 0;
479         };
480     }
481 
482     @CodeReflection
483     static int tableSwitch(int i) {
484         return switch (i) {
485             case 1 -> 1;
486             case 2 -> 2;
487             case 3 -> 3;
488             default -> 0;
489         };
490     }
491 
492     int instanceField = -1;
493 
494     @CodeReflection
495     int instanceFieldAccess(int i) {
496         int ret = instanceField;
497         instanceField = i;
498         return ret;
499     }
500 
501     @CodeReflection
502     static String stringConcat(String a, String b) {
503         return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd";
504     }
505 
506     @CodeReflection
507     static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) {
508         return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d;
509     }
510 
511     @CodeReflection
512     static int ifTrue(int i) {
513         if (true) {
514             return i;
515         }
516         return -i;
517     }
518 
519     @CodeReflection
520     static int excHandlerFollowingSplitTable(boolean b) {
521         try {
522             if (b) return 1;
523             else throw new Exception();
524         } catch (Exception ex) {}
525         return 2;
526     }
527 
528     @CodeReflection
529     static int varModifiedInTryBlock(boolean b) {
530         int i = 0;
531         try {
532             i++;
533             if (b) throw new Exception();
534             i++;
535             throw new Exception();
536         } catch (Exception ex) {
537             return i;
538         }
539     }
540 
541     @CodeReflection
542     static boolean finallyWithLoop(boolean b) {
543         try {
544             while (b) {
545                 if (b)
546                     return false;
547                 b = !b;
548             }
549             return true;
550         } finally {
551             b = false;
552         }
553     }
554 
555     @CodeReflection
556     static long doubleUseOfOperand(int x) {
557         long piece = x;
558         return piece * piece;
559     }
560 
561     record TestData(Method testMethod) {
562         @Override
563         public String toString() {
564             String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
565                     .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
566             if (s.length() > 30) s = s.substring(0, 27) + "...";
567             return s;
568         }
569     }
570 
571     @DataProvider(name = "testMethods")
572     public static TestData[]testMethods() {
573         return Stream.of(TestBytecode.class.getDeclaredMethods())
574                 .filter(m -> m.isAnnotationPresent(CodeReflection.class))
575                 .map(TestData::new).toArray(TestData[]::new);
576     }
577 
578     private static byte[] CLASS_DATA;
579     private static ClassModel CLASS_MODEL;
580 
581     @BeforeClass
582     public static void setup() throws Exception {
583         CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes();
584         CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
585     }
586 
587     private static MethodTypeDesc toMethodTypeDesc(Method m) {
588         return MethodTypeDesc.of(
589                 m.getReturnType().describeConstable().orElseThrow(),
590                 Arrays.stream(m.getParameterTypes())
591                         .map(cls -> cls.describeConstable().orElseThrow()).toList());
592     }
593 
594 
595     private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
596     private static Object[] values(Object... values) {
597         return values;
598     }
599     private static void initTestArgs(Object[] values, Class<?>... argTypes) {
600         for (var argType : argTypes) TEST_ARGS.put(argType, values);
601     }
602     static {
603         initTestArgs(values(1, 2, 4), int.class, Integer.class);
604         initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class);
605         initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class);
606         initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class);
607         initTestArgs(values(false, true), boolean.class, Boolean.class);
608         initTestArgs(values("Hello World"), String.class);
609         initTestArgs(values(1l, 2l, 4l), long.class, Long.class);
610         initTestArgs(values(1f, 3f, 4f), float.class, Float.class);
611         initTestArgs(values(1d, 2d, 3d), double.class, Double.class);
612     }
613 
614     interface Executor {
615         void execute(Object[] args) throws Throwable;
616     }
617 
618     private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
619         final int argn = argTypes.length;
620         Object[][] argValues = new Object[argn][];
621         for (int i = 0; i < argn; i++) {
622             argValues[i] = TEST_ARGS.get(argTypes[i]);
623         }
624         int[] argIndexes = new int[argn];
625         Object[] args = new Object[argn];
626         while (true) {
627             for (int i = 0; i < argn; i++) {
628                 args[i] = argValues[i][argIndexes[i]];
629             }
630             executor.execute(args);
631             int i = argn - 1;
632             while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
633             if (i < 0) return;
634             argIndexes[i++]++;
635             while (i < argn) argIndexes[i++] = 0;
636         }
637     }
638 
639     @Test(dataProvider = "testMethods")
640     public void testLift(TestData d) throws Throwable {
641         CoreOp.FuncOp flift;
642         try {
643             flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod));
644         } catch (Throwable e) {
645             ClassPrinter.toYaml(ClassFile.of().parse(TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes())
646                     .methods().stream().filter(m -> m.methodName().equalsString(d.testMethod().getName())).findAny().get(),
647                     ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print);
648             System.out.println("Lift failed, compiled model:");
649             Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out));
650             throw e;
651         }
652         try {
653             Object receiver1, receiver2;
654             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
655                 receiver1 = null;
656                 receiver2 = null;
657             } else {
658                 receiver1 = new TestBytecode();
659                 receiver2 = new TestBytecode();
660             }
661             permutateAllArgs(d.testMethod.getParameterTypes(), args ->
662                 Assert.assertEquals(invokeAndConvert(flift, receiver1, args), d.testMethod.invoke(receiver2, args)));
663         } catch (Throwable e) {
664             System.out.println("Compiled model:");
665             Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out));
666             System.out.println("Lifted model:");
667             flift.writeTo(System.out);
668             throw e;
669         }
670     }
671 
672     private static Object invokeAndConvert(CoreOp.FuncOp func, Object receiver, Object... args) {
673         List argl = new ArrayList(args.length + 1);
674         if (receiver != null) argl.add(receiver);
675         argl.addAll(Arrays.asList(args));
676         Object ret = Interpreter.invoke(MethodHandles.lookup(), func, argl);
677         if (ret instanceof Integer i) {
678             TypeElement rt = func.invokableType().returnType();
679             if (rt.equals(JavaType.BOOLEAN)) {
680                 return i != 0;
681             } else if (rt.equals(JavaType.BYTE)) {
682                 return i.byteValue();
683             } else if (rt.equals(JavaType.CHAR)) {
684                 return (short)i.intValue();
685             } else if (rt.equals(JavaType.SHORT)) {
686                 return i.shortValue();
687             }
688         }
689         return ret;
690     }
691 
692     @Test(dataProvider = "testMethods")
693     public void testGenerate(TestData d) throws Throwable {
694         CoreOp.FuncOp func = Op.ofMethod(d.testMethod).get();
695 
696         CoreOp.FuncOp lfunc;
697         try {
698             lfunc = func.transform(CopyContext.create(), OpTransformer.LOWERING_TRANSFORMER);
699         } catch (UnsupportedOperationException uoe) {
700             throw new SkipException("lowering caused:", uoe);
701         }
702 
703         try {
704             MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), lfunc);
705             Object receiver1, receiver2;
706             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
707                 receiver1 = null;
708                 receiver2 = null;
709             } else {
710                 receiver1 = new TestBytecode();
711                 receiver2 = new TestBytecode();
712             }
713             permutateAllArgs(d.testMethod.getParameterTypes(), args -> {
714                     List argl = new ArrayList(args.length + 1);
715                     if (receiver1 != null) argl.add(receiver1);
716                     argl.addAll(Arrays.asList(args));
717                     Assert.assertEquals(mh.invokeWithArguments(argl), d.testMethod.invoke(receiver2, args));
718             });
719         } catch (Throwable e) {
720             func.writeTo(System.out);
721             lfunc.writeTo(System.out);
722             String methodName = d.testMethod().getName();
723             for (var mm : CLASS_MODEL.methods()) {
724                 if (mm.methodName().equalsString(methodName)
725                         || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) {
726                     ClassPrinter.toYaml(mm,
727                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
728                                         System.out::print);
729                 }
730             }
731             Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> {
732                 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try {
733                     ClassPrinter.toYaml(ClassFile.of().parse(p),
734                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
735                                         System.out::print);
736                 } catch (IOException ignore) {}
737             });
738             throw e;
739         }
740     }
741 }