1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
import jdk.incubator.code.*;
import jdk.incubator.code.Reflect;
import jdk.incubator.code.bytecode.BytecodeGenerator;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.JavaOp;
import jdk.internal.classfile.components.ClassPrinter;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.opentest4j.TestSkippedException;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.classfile.ClassFile;
import java.lang.classfile.ClassModel;
import java.lang.constant.MethodTypeDesc;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.AccessFlag;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntUnaryOperator;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
 52 
 53 /*
 54  * @test
 55  * @modules jdk.incubator.code
 56  * @modules java.base/jdk.internal.classfile.components
 57  * @enablePreview
 58  * @run junit/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode
 59  */
 60 
 61 public class TestBytecode {
 62 
 63     @Reflect
 64     static int intNumOps(int i, int j, int k) {
 65         k++;
 66         i = (i + j) / k - i % j;
 67         i--;
 68         return i;
 69     }
 70 
 71     @Reflect
 72     static byte byteNumOps(byte i, byte j, byte k) {
 73         k++;
 74         i = (byte) ((i + j) / k - i % j);
 75         i--;
 76         return i;
 77     }
 78 
 79     @Reflect
 80     static short shortNumOps(short i, short j, short k) {
 81         k++;
 82         i = (short) ((i + j) / k - i % j);
 83         i--;
 84         return i;
 85     }
 86 
 87     @Reflect
 88     static char charNumOps(char i, char j, char k) {
 89         k++;
 90         i = (char) ((i + j) / k - i % j);
 91         i--;
 92         return i;
 93     }
 94 
 95     @Reflect
 96     static long longNumOps(long i, long j, long k) {
 97         k++;
 98         i = (i + j) / k - i % j;
 99         i--;
100         return i;
101     }
102 
103     @Reflect
104     static float floatNumOps(float i, float j, float k) {
105         k++;
106         i = (i + j) / k - i % j;
107         i--;
108         return i;
109     }
110 
111     @Reflect
112     static double doubleNumOps(double i, double j, double k) {
113         k++;
114         i = (i + j) / k - i % j;
115         i--;
116         return i;
117     }
118 
119     @Reflect
120     static int intBitOps(int i, int j, int k) {
121         return ~(i & j | k ^ j);
122     }
123 
124     @Reflect
125     static byte byteBitOps(byte i, byte j, byte k) {
126         return (byte) ~(i & j | k ^ j);
127     }
128 
129     @Reflect
130     static short shortBitOps(short i, short j, short k) {
131         return (short) ~(i & j | k ^ j);
132     }
133 
134     @Reflect
135     static char charBitOps(char i, char j, char k) {
136         return (char) ~(i & j | k ^ j);
137     }
138 
139     @Reflect
140     static long longBitOps(long i, long j, long k) {
141         return ~(i & j | k ^ j);
142     }
143 
144     @Reflect
145     static boolean boolBitOps(boolean i, boolean j, boolean k) {
146         return i & j | k ^ j;
147     }
148 
149     @Reflect
150     static int intShiftOps(int i, int j, int k) {
151         return ((-1 >> i) << (j << k)) >>> (k - j);
152     }
153 
154     @Reflect
155     static byte byteShiftOps(byte i, byte j, byte k) {
156         return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
157     }
158 
159     @Reflect
160     static short shortShiftOps(short i, short j, short k) {
161         return (short) (((-1 >> i) << (j << k)) >>> (k - j));
162     }
163 
164     @Reflect
165     static char charShiftOps(char i, char j, char k) {
166         return (char) (((-1 >> i) << (j << k)) >>> (k - j));
167     }
168 
169     @Reflect
170     static long longShiftOps(long i, long j, long k) {
171         return ((-1 >> i) << (j << k)) >>> (k - j);
172     }
173 
174     @Reflect
175     static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
176         ii += i; ii += b; ii += s; ii += c;
177         i += ii; i += bb; i += ss; i += cc;
178         b += ii; b += bb; b += ss; b += cc;
179         s += ii; s += bb; s += ss; s += cc;
180         c += ii; c += bb; c += ss; c += cc;
181         return new Object[]{i, b, s, c};
182     }
183 
184     @Reflect
185     static String constructor(String s, int i, int j) {
186         return new String(s.getBytes(), i, j);
187     }
188 
189     @Reflect
190     static Class<?> classArray(int i, int j) {
191         Class<?>[] ifaces = new Class[1 + i + j];
192         ifaces[0] = Function.class;
193         return ifaces[0];
194     }
195 
196     @Reflect
197     static String[] stringArray(int i, int j) {
198         return new String[i];
199     }
200 
201     @Reflect
202     static String[][] stringArray2(int i, int j) {
203         return new String[i][];
204     }
205 
206     @Reflect
207     static String[][] stringArrayMulti(int i, int j) {
208         return new String[i][j];
209     }
210 
211     @Reflect
212     static int[][] initializedIntArray(int i, int j) {
213         return new int[][]{{i, j}, {i + j}};
214     }
215 
216     @Reflect
217     static int ifElseCompare(int i, int j) {
218         if (i < 3) {
219             i += 1;
220         } else {
221             j += 2;
222         }
223         return i + j;
224     }
225 
226     @Reflect
227     static int ifElseEquality(int i, int j) {
228         if (j != 0) {
229             if (i != 0) {
230                 i += 1;
231             } else {
232                 i += 2;
233             }
234         } else {
235             if (j != 0) {
236                 i += 3;
237             } else {
238                 i += 4;
239             }
240         }
241         return i;
242     }
243 
244     @Reflect
245     static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) {
246         Object a = b1;
247         Object b = b2;
248         Object c = b3;
249         return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4);
250     }
251 
252     @Reflect
253     static int conditionalExpr(int i, int j) {
254         return ((i - 1 >= 0) ? i - 1 : j - 1);
255     }
256 
257     @Reflect
258     static int nestedConditionalExpr(int i, int j) {
259         return (i < 2) ? (j < 3) ? i : j : i + j;
260     }
261 
262     static final int[] MAP = {0, 1, 2, 3, 4};
263 
264     @Reflect
265     static int deepStackBranches(boolean a, boolean b) {
266         return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]];
267     }
268 
269     @Reflect
270     static int tryFinally(int i, int j) {
271         try {
272             i = i + j;
273         } finally {
274             i = i + j;
275         }
276         return i;
277     }
278 
279     public record A(String s) {}
280 
281     @Reflect
282     static A newWithArgs(int i, int j) {
283         return new A("hello world".substring(i, i + j));
284     }
285 
286     @Reflect
287     static int loop(int n, int j) {
288         int sum = 0;
289         for (int i = 0; i < n; i++) {
290             sum = sum + j;
291         }
292         return sum;
293     }
294 
295 
296     @Reflect
297     static int ifElseNested(int a, int b) {
298         int c = a + b;
299         int d = 10 - a + b;
300         if (b < 3) {
301             if (a < 3) {
302                 a += 1;
303             } else {
304                 b += 2;
305             }
306             c += 3;
307         } else {
308             if (a > 2) {
309                 a += 4;
310             } else {
311                 b += 5;
312             }
313             d += 6;
314         }
315         return a + b + c + d;
316     }
317 
318     @Reflect
319     static int nestedLoop(int m, int n) {
320         int sum = 0;
321         for (int i = 0; i < m; i++) {
322             for (int j = 0; j < n; j++) {
323                 sum = sum + i + j;
324             }
325         }
326         return sum;
327     }
328 
329     @Reflect
330     static int methodCall(int a, int b) {
331         int i = Math.max(a, b);
332         return Math.negateExact(i);
333     }
334 
335     @Reflect
336     static int[] primitiveArray(int i, int j) {
337         int[] ia = new int[i + 1];
338         ia[0] = j;
339         return ia;
340     }
341 
342     @Reflect
343     static boolean not(boolean b) {
344         return !b;
345     }
346 
347     @Reflect
348     static boolean notCompare(int i, int j) {
349         boolean b = i < j;
350         return !b;
351     }
352 
353     @Reflect
354     static int mod(int i, int j) {
355         return i % (j + 1);
356     }
357 
358     @Reflect
359     static int xor(int i, int j) {
360         return i ^ j;
361     }
362 
363     @Reflect
364     static int whileLoop(int i, int n) { int
365         counter = 0;
366         while (i < n && counter < 3) {
367             counter++;
368             if (counter == 4) {
369                 break;
370             }
371             i++;
372         }
373         return counter;
374     }
375 
376     static int consumeQuotable(int i, IntUnaryOperator f) {
377         Assertions.assertNotNull(Op.ofQuotable(f).get());
378         Assertions.assertNotNull(Op.ofQuotable(f).get().op());
379         Assertions.assertTrue(Op.ofQuotable(f).get().op() instanceof JavaOp.LambdaOp);
380         return f.applyAsInt(i + 1);
381     }
382 
383     @Reflect
384     static int quotableLambda(int i) {
385         return consumeQuotable(i, a -> -a);
386     }
387 
388     @Reflect
389     static int quotableLambdaWithCapture(int i, String s) {
390         return consumeQuotable(i, a -> a + s.length());
391     }
392 
393     @Reflect
394     static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
395         return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length());
396     }
397 
398     @Reflect
399     static int methodHandle(int i) {
400         return consumeQuotable(i, Math::negateExact);
401     }
402 
403     int instanceMethod(int i) {
404         return -i + 13;
405     }
406 
407     @Reflect
408     int instanceMethodHandle(int i) {
409         return consumeQuotable(i, this::instanceMethod);
410     }
411 
412     static void consume(boolean b, Consumer<Object> requireNonNull) {
413         if (b) {
414             requireNonNull.accept(new Object());
415         } else try {
416             requireNonNull.accept(null);
417             throw new AssertionError("Expectend NPE");
418         } catch (NullPointerException expected) {
419         }
420     }
421 
422     @Reflect
423     static void nullReturningMethodHandle(boolean b) {
424         consume(b, Objects::requireNonNull);
425     }
426 
427     @Reflect
428     static boolean compareLong(long i, long j) {
429         return i > j;
430     }
431 
432     @Reflect
433     static boolean compareFloat(float i, float j) {
434         return i > j;
435     }
436 
437     @Reflect
438     static boolean compareDouble(double i, double j) {
439         return i > j;
440     }
441 
442     @Reflect
443     static int lookupSwitch(int i) {
444         return switch (1000 * i) {
445             case 1000 -> 1;
446             case 2000 -> 2;
447             case 3000 -> 3;
448             default -> 0;
449         };
450     }
451 
452     @Reflect
453     static int tableSwitch(int i) {
454         return switch (i) {
455             case 1 -> 1;
456             case 2 -> 2;
457             case 3 -> 3;
458             default -> 0;
459         };
460     }
461 
462     int instanceField = -1;
463 
464     @Reflect
465     int instanceFieldAccess(int i) {
466         int ret = instanceField;
467         instanceField = i;
468         return ret;
469     }
470 
471     @Reflect
472     static String stringConcat(String a, String b) {
473         return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd";
474     }
475 
476     @Reflect
477     static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) {
478         return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d;
479     }
480 
481     @Reflect
482     static int ifTrue(int i) {
483         if (true) {
484             return i;
485         }
486         return -i;
487     }
488 
489     @Reflect
490     static int excHandlerFollowingSplitTable(boolean b) {
491         try {
492             if (b) return 1;
493             else throw new Exception();
494         } catch (Exception ex) {}
495         return 2;
496     }
497 
498     @Reflect
499     static int varModifiedInTryBlock(boolean b) {
500         int i = 0;
501         try {
502             i++;
503             if (b) throw new Exception();
504             i++;
505             throw new Exception();
506         } catch (Exception ex) {
507             return i;
508         }
509     }
510 
511     @Reflect
512     static boolean finallyWithLoop(boolean b) {
513         try {
514             while (b) {
515                 if (b)
516                     return false;
517                 b = !b;
518             }
519             return true;
520         } finally {
521             b = false;
522         }
523     }
524 
525     @Reflect
526     static long doubleUseOfOperand(int x) {
527         long piece = x;
528         return piece * piece;
529     }
530 
531     @Reflect
532     static String functionLambda(String s) {
533         return ((Function<String, String>)e -> e.substring(1)).apply(s);
534     }
535 
536     record TestData(Method testMethod) {
537         @Override
538         public String toString() {
539             String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
540                     .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
541             if (s.length() > 30) s = s.substring(0, 27) + "...";
542             return s;
543         }
544     }
545 
546     public static Stream<TestData> testMethods() {
547         return Stream.of(TestBytecode.class.getDeclaredMethods())
548                 .filter(m -> m.isAnnotationPresent(Reflect.class))
549                 .map(TestData::new);
550     }
551 
552     private static byte[] CLASS_DATA;
553     private static ClassModel CLASS_MODEL;
554 
555     @BeforeAll
556     public static void setup() throws Exception {
557         CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes();
558         CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
559     }
560 
561     private static MethodTypeDesc toMethodTypeDesc(Method m) {
562         return MethodTypeDesc.of(
563                 m.getReturnType().describeConstable().orElseThrow(),
564                 Arrays.stream(m.getParameterTypes())
565                         .map(cls -> cls.describeConstable().orElseThrow()).toList());
566     }
567 
568 
569     private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
570     private static Object[] values(Object... values) {
571         return values;
572     }
573     private static void initTestArgs(Object[] values, Class<?>... argTypes) {
574         for (var argType : argTypes) TEST_ARGS.put(argType, values);
575     }
576     static {
577         initTestArgs(values(1, 2, 4), int.class, Integer.class);
578         initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class);
579         initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class);
580         initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class);
581         initTestArgs(values(false, true), boolean.class, Boolean.class);
582         initTestArgs(values("Hello World"), String.class);
583         initTestArgs(values(1l, 2l, 4l), long.class, Long.class);
584         initTestArgs(values(1f, 3f, 4f), float.class, Float.class);
585         initTestArgs(values(1d, 2d, 3d), double.class, Double.class);
586     }
587 
588     interface Executor {
589         void execute(Object[] args) throws Throwable;
590     }
591 
592     private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
593         final int argn = argTypes.length;
594         Object[][] argValues = new Object[argn][];
595         for (int i = 0; i < argn; i++) {
596             argValues[i] = TEST_ARGS.get(argTypes[i]);
597         }
598         int[] argIndexes = new int[argn];
599         Object[] args = new Object[argn];
600         while (true) {
601             for (int i = 0; i < argn; i++) {
602                 args[i] = argValues[i][argIndexes[i]];
603             }
604             executor.execute(args);
605             int i = argn - 1;
606             while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
607             if (i < 0) return;
608             argIndexes[i++]++;
609             while (i < argn) argIndexes[i++] = 0;
610         }
611     }
612 
613     @ParameterizedTest
614     @MethodSource("testMethods")
615     public void testGenerate(TestData d) throws Throwable {
616         CoreOp.FuncOp func = Op.ofMethod(d.testMethod).get();
617 
618         try {
619             MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), func);
620             Object receiver1, receiver2;
621             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
622                 receiver1 = null;
623                 receiver2 = null;
624             } else {
625                 receiver1 = new TestBytecode();
626                 receiver2 = new TestBytecode();
627             }
628             permutateAllArgs(d.testMethod.getParameterTypes(), args -> {
629                 List argl = new ArrayList(args.length + 1);
630                 if (receiver1 != null) argl.add(receiver1);
631                 argl.addAll(Arrays.asList(args));
632                 assertEquals(d.testMethod.invoke(receiver2, args), mh.invokeWithArguments(argl));
633             });
634         } catch (Throwable e) {
635             System.out.println(func.toText());
636             String methodName = d.testMethod().getName();
637             for (var mm : CLASS_MODEL.methods()) {
638                 if (mm.methodName().equalsString(methodName)
639                         || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) {
640                     ClassPrinter.toYaml(mm,
641                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
642                                         System.out::print);
643                 }
644             }
645             Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> {
646                 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try {
647                     ClassPrinter.toYaml(ClassFile.of().parse(p),
648                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
649                                         System.out::print);
650                 } catch (IOException ignore) {}
651             });
652             throw e;
653         }
654     }
655 
656     private static void assertEquals(Object expected, Object actual) {
657         switch (expected) {
658             case int[] expArr when actual instanceof int[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
659             case Object[] expArr when actual instanceof Object[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
660             case null, default -> Assertions.assertEquals(expected, actual);
661         }
662     }
663 }