1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import java.io.IOException;
 25 import java.lang.classfile.ClassFile;
 26 import java.lang.classfile.ClassModel;
 27 import java.lang.classfile.components.ClassPrinter;
 28 import java.lang.constant.MethodTypeDesc;
 29 import java.lang.invoke.MethodHandle;
 30 import java.lang.invoke.MethodHandles;
 31 import java.lang.reflect.AccessFlag;
 32 import org.testng.Assert;
 33 import org.testng.SkipException;
 34 import org.testng.annotations.BeforeClass;
 35 import org.testng.annotations.DataProvider;
 36 import org.testng.annotations.Test;
 37 
 38 import jdk.incubator.code.*;
 39 import jdk.incubator.code.op.CoreOp;
 40 import jdk.incubator.code.bytecode.BytecodeLift;
 41 import jdk.incubator.code.interpreter.Interpreter;
 42 import java.lang.reflect.Method;
 43 import jdk.incubator.code.bytecode.BytecodeGenerator;
 44 import jdk.incubator.code.type.JavaType;
 45 import jdk.incubator.code.CodeReflection;
 46 import java.nio.file.Files;
 47 import java.nio.file.Path;
 48 import java.util.*;
 49 import java.util.function.Consumer;
 50 import java.util.function.Function;
 51 import java.util.stream.Collectors;
 52 import java.util.stream.Stream;
 53 
 54 /*
 55  * @test
 56  * @modules jdk.incubator.code
 57  * @enablePreview
 58  * @run testng/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode
 59  */
 60 
 61 public class TestBytecode {
 62 
 63     @CodeReflection
 64     static int intNumOps(int i, int j, int k) {
 65         k++;
 66         i = (i + j) / k - i % j;
 67         i--;
 68         return i;
 69     }
 70 
 71     @CodeReflection
 72     static byte byteNumOps(byte i, byte j, byte k) {
 73         k++;
 74         i = (byte) ((i + j) / k - i % j);
 75         i--;
 76         return i;
 77     }
 78 
 79     @CodeReflection
 80     static short shortNumOps(short i, short j, short k) {
 81         k++;
 82         i = (short) ((i + j) / k - i % j);
 83         i--;
 84         return i;
 85     }
 86 
 87     @CodeReflection
 88     static char charNumOps(char i, char j, char k) {
 89         k++;
 90         i = (char) ((i + j) / k - i % j);
 91         i--;
 92         return i;
 93     }
 94 
 95     @CodeReflection
 96     static long longNumOps(long i, long j, long k) {
 97         k++;
 98         i = (i + j) / k - i % j;
 99         i--;
100         return i;
101     }
102 
103     @CodeReflection
104     static float floatNumOps(float i, float j, float k) {
105         k++;
106         i = (i + j) / k - i % j;
107         i--;
108         return i;
109     }
110 
111     @CodeReflection
112     static double doubleNumOps(double i, double j, double k) {
113         k++;
114         i = (i + j) / k - i % j;
115         i--;
116         return i;
117     }
118 
119     @CodeReflection
120     static int intBitOps(int i, int j, int k) {
121         return ~(i & j | k ^ j);
122     }
123 
124     @CodeReflection
125     static byte byteBitOps(byte i, byte j, byte k) {
126         return (byte) ~(i & j | k ^ j);
127     }
128 
129     @CodeReflection
130     static short shortBitOps(short i, short j, short k) {
131         return (short) ~(i & j | k ^ j);
132     }
133 
134     @CodeReflection
135     static char charBitOps(char i, char j, char k) {
136         return (char) ~(i & j | k ^ j);
137     }
138 
139     @CodeReflection
140     static long longBitOps(long i, long j, long k) {
141         return ~(i & j | k ^ j);
142     }
143 
144     @CodeReflection
145     static boolean boolBitOps(boolean i, boolean j, boolean k) {
146         return i & j | k ^ j;
147     }
148 
149     @CodeReflection
150     static int intShiftOps(int i, int j, int k) {
151         return ((-1 >> i) << (j << k)) >>> (k - j);
152     }
153 
154     @CodeReflection
155     static byte byteShiftOps(byte i, byte j, byte k) {
156         return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
157     }
158 
159     @CodeReflection
160     static short shortShiftOps(short i, short j, short k) {
161         return (short) (((-1 >> i) << (j << k)) >>> (k - j));
162     }
163 
164     @CodeReflection
165     static char charShiftOps(char i, char j, char k) {
166         return (char) (((-1 >> i) << (j << k)) >>> (k - j));
167     }
168 
169     @CodeReflection
170     static long longShiftOps(long i, long j, long k) {
171         return ((-1 >> i) << (j << k)) >>> (k - j);
172     }
173 
174     @CodeReflection
175     static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
176         ii += i; ii += b; ii += s; ii += c;
177         i += ii; i += bb; i += ss; i += cc;
178         b += ii; b += bb; b += ss; b += cc;
179         s += ii; s += bb; s += ss; s += cc;
180         c += ii; c += bb; c += ss; c += cc;
181         return new Object[]{i, b, s, c};
182     }
183 
184     @CodeReflection
185     static String constructor(String s, int i, int j) {
186         return new String(s.getBytes(), i, j);
187     }
188 
189     @CodeReflection
190     static Class<?> classArray(int i, int j) {
191         Class<?>[] ifaces = new Class[1 + i + j];
192         ifaces[0] = Function.class;
193         return ifaces[0];
194     }
195 
196     @CodeReflection
197     static String[] stringArray(int i, int j) {
198         return new String[i];
199     }
200 
201     @CodeReflection
202     static String[][] stringArray2(int i, int j) {
203         return new String[i][];
204     }
205 
206     @CodeReflection
207     static String[][] stringArrayMulti(int i, int j) {
208         return new String[i][j];
209     }
210 
211     @CodeReflection
212     static int[][] initializedIntArray(int i, int j) {
213         return new int[][]{{i, j}, {i + j}};
214     }
215 
216     @CodeReflection
217     static int ifElseCompare(int i, int j) {
218         if (i < 3) {
219             i += 1;
220         } else {
221             j += 2;
222         }
223         return i + j;
224     }
225 
226     @CodeReflection
227     static int ifElseEquality(int i, int j) {
228         if (j != 0) {
229             if (i != 0) {
230                 i += 1;
231             } else {
232                 i += 2;
233             }
234         } else {
235             if (j != 0) {
236                 i += 3;
237             } else {
238                 i += 4;
239             }
240         }
241         return i;
242     }
243 
244     @CodeReflection
245     static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) {
246         Object a = b1;
247         Object b = b2;
248         Object c = b3;
249         return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4);
250     }
251 
252     @CodeReflection
253     static int conditionalExpr(int i, int j) {
254         return ((i - 1 >= 0) ? i - 1 : j - 1);
255     }
256 
257     @CodeReflection
258     static int nestedConditionalExpr(int i, int j) {
259         return (i < 2) ? (j < 3) ? i : j : i + j;
260     }
261 
262     static final int[] MAP = {0, 1, 2, 3, 4};
263 
264     @CodeReflection
265     static int deepStackBranches(boolean a, boolean b) {
266         return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]];
267     }
268 
269     @CodeReflection
270     static int tryFinally(int i, int j) {
271         try {
272             i = i + j;
273         } finally {
274             i = i + j;
275         }
276         return i;
277     }
278 
279     public record A(String s) {}
280 
281     @CodeReflection
282     static A newWithArgs(int i, int j) {
283         return new A("hello world".substring(i, i + j));
284     }
285 
286     @CodeReflection
287     static int loop(int n, int j) {
288         int sum = 0;
289         for (int i = 0; i < n; i++) {
290             sum = sum + j;
291         }
292         return sum;
293     }
294 
295 
296     @CodeReflection
297     static int ifElseNested(int a, int b) {
298         int c = a + b;
299         int d = 10 - a + b;
300         if (b < 3) {
301             if (a < 3) {
302                 a += 1;
303             } else {
304                 b += 2;
305             }
306             c += 3;
307         } else {
308             if (a > 2) {
309                 a += 4;
310             } else {
311                 b += 5;
312             }
313             d += 6;
314         }
315         return a + b + c + d;
316     }
317 
318     @CodeReflection
319     static int nestedLoop(int m, int n) {
320         int sum = 0;
321         for (int i = 0; i < m; i++) {
322             for (int j = 0; j < n; j++) {
323                 sum = sum + i + j;
324             }
325         }
326         return sum;
327     }
328 
329     @CodeReflection
330     static int methodCall(int a, int b) {
331         int i = Math.max(a, b);
332         return Math.negateExact(i);
333     }
334 
335     @CodeReflection
336     static int[] primitiveArray(int i, int j) {
337         int[] ia = new int[i + 1];
338         ia[0] = j;
339         return ia;
340     }
341 
342     @CodeReflection
343     static boolean not(boolean b) {
344         return !b;
345     }
346 
347     @CodeReflection
348     static boolean notCompare(int i, int j) {
349         boolean b = i < j;
350         return !b;
351     }
352 
353     @CodeReflection
354     static int mod(int i, int j) {
355         return i % (j + 1);
356     }
357 
358     @CodeReflection
359     static int xor(int i, int j) {
360         return i ^ j;
361     }
362 
363     @CodeReflection
364     static int whileLoop(int i, int n) { int
365         counter = 0;
366         while (i < n && counter < 3) {
367             counter++;
368             if (counter == 4) {
369                 break;
370             }
371             i++;
372         }
373         return counter;
374     }
375 
376     public interface Func {
377         int apply(int a);
378     }
379 
380     public interface QuotableFunc extends Quotable {
381         int apply(int a);
382     }
383 
384     static int consume(int i, Func f) {
385         return f.apply(i + 1);
386     }
387 
388     static int consumeQuotable(int i, QuotableFunc f) {
389         Assert.assertNotNull(f.quoted());
390         Assert.assertNotNull(f.quoted().op());
391         Assert.assertTrue(f.quoted().op() instanceof CoreOp.LambdaOp);
392         return f.apply(i + 1);
393     }
394 
395     @CodeReflection
396     static int lambda(int i) {
397         return consume(i, a -> -a);
398     }
399 
400     @CodeReflection
401     static int quotableLambda(int i) {
402         return consumeQuotable(i, a -> -a);
403     }
404 
405     @CodeReflection
406     static int lambdaWithCapture(int i, String s) {
407         return consume(i, a -> a + s.length());
408     }
409 
410     @CodeReflection
411     static int quotableLambdaWithCapture(int i, String s) {
412         return consumeQuotable(i, a -> a + s.length());
413     }
414 
415     @CodeReflection
416     static int nestedLambdasWithCaptures(int i, int j, String s) {
417         return consume(i, a -> consume(a, b -> a + b + j - s.length()) + s.length());
418     }
419 
420     @CodeReflection
421     static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
422         return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length());
423     }
424 
425     @CodeReflection
426     static int methodHandle(int i) {
427         return consume(i, Math::negateExact);
428     }
429 
430     int instanceMethod(int i) {
431         return -i + 13;
432     }
433 
434     @CodeReflection
435     int instanceMethodHandle(int i) {
436         return consume(i, this::instanceMethod);
437     }
438 
439     static void consume(boolean b, Consumer<Object> requireNonNull) {
440         if (b) {
441             requireNonNull.accept(new Object());
442         } else try {
443             requireNonNull.accept(null);
444             throw new AssertionError("Expectend NPE");
445         } catch (NullPointerException expected) {
446         }
447     }
448 
449     @CodeReflection
450     static void nullReturningMethodHandle(boolean b) {
451         consume(b, Objects::requireNonNull);
452     }
453 
454     @CodeReflection
455     static boolean compareLong(long i, long j) {
456         return i > j;
457     }
458 
459     @CodeReflection
460     static boolean compareFloat(float i, float j) {
461         return i > j;
462     }
463 
464     @CodeReflection
465     static boolean compareDouble(double i, double j) {
466         return i > j;
467     }
468 
469     @CodeReflection
470     static int lookupSwitch(int i) {
471         return switch (1000 * i) {
472             case 1000 -> 1;
473             case 2000 -> 2;
474             case 3000 -> 3;
475             default -> 0;
476         };
477     }
478 
479     @CodeReflection
480     static int tableSwitch(int i) {
481         return switch (i) {
482             case 1 -> 1;
483             case 2 -> 2;
484             case 3 -> 3;
485             default -> 0;
486         };
487     }
488 
489     int instanceField = -1;
490 
491     @CodeReflection
492     int instanceFieldAccess(int i) {
493         int ret = instanceField;
494         instanceField = i;
495         return ret;
496     }
497 
498     @CodeReflection
499     static String stringConcat(String a, String b) {
500         return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd";
501     }
502 
503     @CodeReflection
504     static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) {
505         return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d;
506     }
507 
508     @CodeReflection
509     static int ifTrue(int i) {
510         if (true) {
511             return i;
512         }
513         return -i;
514     }
515 
516     @CodeReflection
517     static int excHandlerFollowingSplitTable(boolean b) {
518         try {
519             if (b) return 1;
520             else throw new Exception();
521         } catch (Exception ex) {}
522         return 2;
523     }
524 
525     @CodeReflection
526     static int varModifiedInTryBlock(boolean b) {
527         int i = 0;
528         try {
529             i++;
530             if (b) throw new Exception();
531             i++;
532             throw new Exception();
533         } catch (Exception ex) {
534             return i;
535         }
536     }
537 
538     @CodeReflection
539     static boolean finallyWithLoop(boolean b) {
540         try {
541             while (b) {
542                 if (b)
543                     return false;
544                 b = !b;
545             }
546             return true;
547         } finally {
548             b = false;
549         }
550     }
551 
552     @CodeReflection
553     static long doubleUseOfOperand(int x) {
554         long piece = x;
555         return piece * piece;
556     }
557 
558     record TestData(Method testMethod) {
559         @Override
560         public String toString() {
561             String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
562                     .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
563             if (s.length() > 30) s = s.substring(0, 27) + "...";
564             return s;
565         }
566     }
567 
568     @DataProvider(name = "testMethods")
569     public static TestData[]testMethods() {
570         return Stream.of(TestBytecode.class.getDeclaredMethods())
571                 .filter(m -> m.isAnnotationPresent(CodeReflection.class))
572                 .map(TestData::new).toArray(TestData[]::new);
573     }
574 
575     private static byte[] CLASS_DATA;
576     private static ClassModel CLASS_MODEL;
577 
578     @BeforeClass
579     public static void setup() throws Exception {
580         CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes();
581         CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
582     }
583 
584     private static MethodTypeDesc toMethodTypeDesc(Method m) {
585         return MethodTypeDesc.of(
586                 m.getReturnType().describeConstable().orElseThrow(),
587                 Arrays.stream(m.getParameterTypes())
588                         .map(cls -> cls.describeConstable().orElseThrow()).toList());
589     }
590 
591 
592     private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
593     private static Object[] values(Object... values) {
594         return values;
595     }
596     private static void initTestArgs(Object[] values, Class<?>... argTypes) {
597         for (var argType : argTypes) TEST_ARGS.put(argType, values);
598     }
599     static {
600         initTestArgs(values(1, 2, 4), int.class, Integer.class);
601         initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class);
602         initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class);
603         initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class);
604         initTestArgs(values(false, true), boolean.class, Boolean.class);
605         initTestArgs(values("Hello World"), String.class);
606         initTestArgs(values(1l, 2l, 4l), long.class, Long.class);
607         initTestArgs(values(1f, 3f, 4f), float.class, Float.class);
608         initTestArgs(values(1d, 2d, 3d), double.class, Double.class);
609     }
610 
611     interface Executor {
612         void execute(Object[] args) throws Throwable;
613     }
614 
615     private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
616         final int argn = argTypes.length;
617         Object[][] argValues = new Object[argn][];
618         for (int i = 0; i < argn; i++) {
619             argValues[i] = TEST_ARGS.get(argTypes[i]);
620         }
621         int[] argIndexes = new int[argn];
622         Object[] args = new Object[argn];
623         while (true) {
624             for (int i = 0; i < argn; i++) {
625                 args[i] = argValues[i][argIndexes[i]];
626             }
627             executor.execute(args);
628             int i = argn - 1;
629             while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
630             if (i < 0) return;
631             argIndexes[i++]++;
632             while (i < argn) argIndexes[i++] = 0;
633         }
634     }
635 
636     @Test(dataProvider = "testMethods")
637     public void testLift(TestData d) throws Throwable {
638         CoreOp.FuncOp flift;
639         try {
640             flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod));
641         } catch (Throwable e) {
642             ClassPrinter.toYaml(ClassFile.of().parse(TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes())
643                     .methods().stream().filter(m -> m.methodName().equalsString(d.testMethod().getName())).findAny().get(),
644                     ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print);
645             System.out.println("Lift failed, compiled model:");
646             Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out));
647             throw e;
648         }
649         try {
650             Object receiver1, receiver2;
651             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
652                 receiver1 = null;
653                 receiver2 = null;
654             } else {
655                 receiver1 = new TestBytecode();
656                 receiver2 = new TestBytecode();
657             }
658             permutateAllArgs(d.testMethod.getParameterTypes(), args ->
659                 Assert.assertEquals(invokeAndConvert(flift, receiver1, args), d.testMethod.invoke(receiver2, args)));
660         } catch (Throwable e) {
661             System.out.println("Compiled model:");
662             Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out));
663             System.out.println("Lifted model:");
664             flift.writeTo(System.out);
665             throw e;
666         }
667     }
668 
669     private static Object invokeAndConvert(CoreOp.FuncOp func, Object receiver, Object... args) {
670         List argl = new ArrayList(args.length + 1);
671         if (receiver != null) argl.add(receiver);
672         argl.addAll(Arrays.asList(args));
673         Object ret = Interpreter.invoke(MethodHandles.lookup(), func, argl);
674         if (ret instanceof Integer i) {
675             TypeElement rt = func.invokableType().returnType();
676             if (rt.equals(JavaType.BOOLEAN)) {
677                 return i != 0;
678             } else if (rt.equals(JavaType.BYTE)) {
679                 return i.byteValue();
680             } else if (rt.equals(JavaType.CHAR)) {
681                 return (short)i.intValue();
682             } else if (rt.equals(JavaType.SHORT)) {
683                 return i.shortValue();
684             }
685         }
686         return ret;
687     }
688 
689     @Test(dataProvider = "testMethods")
690     public void testGenerate(TestData d) throws Throwable {
691         CoreOp.FuncOp func = Op.ofMethod(d.testMethod).get();
692 
693         CoreOp.FuncOp lfunc;
694         try {
695             lfunc = func.transform(CopyContext.create(), OpTransformer.LOWERING_TRANSFORMER);
696         } catch (UnsupportedOperationException uoe) {
697             throw new SkipException("lowering caused:", uoe);
698         }
699 
700         try {
701             MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), lfunc);
702             Object receiver1, receiver2;
703             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
704                 receiver1 = null;
705                 receiver2 = null;
706             } else {
707                 receiver1 = new TestBytecode();
708                 receiver2 = new TestBytecode();
709             }
710             permutateAllArgs(d.testMethod.getParameterTypes(), args -> {
711                     List argl = new ArrayList(args.length + 1);
712                     if (receiver1 != null) argl.add(receiver1);
713                     argl.addAll(Arrays.asList(args));
714                     Assert.assertEquals(mh.invokeWithArguments(argl), d.testMethod.invoke(receiver2, args));
715             });
716         } catch (Throwable e) {
717             func.writeTo(System.out);
718             lfunc.writeTo(System.out);
719             String methodName = d.testMethod().getName();
720             for (var mm : CLASS_MODEL.methods()) {
721                 if (mm.methodName().equalsString(methodName)
722                         || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) {
723                     ClassPrinter.toYaml(mm,
724                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
725                                         System.out::print);
726                 }
727             }
728             Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> {
729                 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try {
730                     ClassPrinter.toYaml(ClassFile.of().parse(p),
731                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
732                                         System.out::print);
733                 } catch (IOException ignore) {}
734             });
735             throw e;
736         }
737     }
738 }