1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.Reflect;
 26 import jdk.incubator.code.bytecode.BytecodeGenerator;
 27 import jdk.incubator.code.dialect.core.CoreOp;
 28 import jdk.incubator.code.dialect.java.JavaOp;
 29 import jdk.internal.classfile.components.ClassPrinter;
 30 import org.junit.jupiter.api.Assertions;
 31 import org.junit.jupiter.api.BeforeAll;
 32 import org.junit.jupiter.params.ParameterizedTest;
 33 import org.junit.jupiter.params.provider.MethodSource;
 34 
 35 import java.io.IOException;
 36 import java.lang.classfile.ClassFile;
 37 import java.lang.classfile.ClassModel;
 38 import java.lang.constant.MethodTypeDesc;
 39 import java.lang.invoke.MethodHandle;
 40 import java.lang.invoke.MethodHandles;
 41 import java.lang.reflect.AccessFlag;
 42 import java.lang.reflect.Method;
 43 import java.nio.file.Files;
 44 import java.nio.file.Path;
 45 import java.util.*;
 46 import java.util.function.Consumer;
 47 import java.util.function.Function;
 48 import java.util.function.IntUnaryOperator;
 49 import java.util.stream.Collectors;
 50 import java.util.stream.Stream;
 51 
 52 /*
 53  * @test
 54  * @modules jdk.incubator.code
 55  * @modules java.base/jdk.internal.classfile.components
 56  * @enablePreview
 57  * @library ../
 58  * @run junit/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode
 59  * @run main Unreflect TestBytecode
 60  * @run junit/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode
 61  */
 62 
 63 public class TestBytecode {
 64 
 65     @Reflect
 66     static int intNumOps(int i, int j, int k) {
 67         k++;
 68         i = (i + j) / k - i % j;
 69         i--;
 70         return i;
 71     }
 72 
 73     @Reflect
 74     static byte byteNumOps(byte i, byte j, byte k) {
 75         k++;
 76         i = (byte) ((i + j) / k - i % j);
 77         i--;
 78         return i;
 79     }
 80 
 81     @Reflect
 82     static short shortNumOps(short i, short j, short k) {
 83         k++;
 84         i = (short) ((i + j) / k - i % j);
 85         i--;
 86         return i;
 87     }
 88 
 89     @Reflect
 90     static char charNumOps(char i, char j, char k) {
 91         k++;
 92         i = (char) ((i + j) / k - i % j);
 93         i--;
 94         return i;
 95     }
 96 
 97     @Reflect
 98     static long longNumOps(long i, long j, long k) {
 99         k++;
100         i = (i + j) / k - i % j;
101         i--;
102         return i;
103     }
104 
105     @Reflect
106     static float floatNumOps(float i, float j, float k) {
107         k++;
108         i = (i + j) / k - i % j;
109         i--;
110         return i;
111     }
112 
113     @Reflect
114     static double doubleNumOps(double i, double j, double k) {
115         k++;
116         i = (i + j) / k - i % j;
117         i--;
118         return i;
119     }
120 
121     @Reflect
122     static int intBitOps(int i, int j, int k) {
123         return ~(i & j | k ^ j);
124     }
125 
126     @Reflect
127     static byte byteBitOps(byte i, byte j, byte k) {
128         return (byte) ~(i & j | k ^ j);
129     }
130 
131     @Reflect
132     static short shortBitOps(short i, short j, short k) {
133         return (short) ~(i & j | k ^ j);
134     }
135 
136     @Reflect
137     static char charBitOps(char i, char j, char k) {
138         return (char) ~(i & j | k ^ j);
139     }
140 
141     @Reflect
142     static long longBitOps(long i, long j, long k) {
143         return ~(i & j | k ^ j);
144     }
145 
146     @Reflect
147     static boolean boolBitOps(boolean i, boolean j, boolean k) {
148         return i & j | k ^ j;
149     }
150 
151     @Reflect
152     static int intShiftOps(int i, int j, int k) {
153         return ((-1 >> i) << (j << k)) >>> (k - j);
154     }
155 
156     @Reflect
157     static byte byteShiftOps(byte i, byte j, byte k) {
158         return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
159     }
160 
161     @Reflect
162     static short shortShiftOps(short i, short j, short k) {
163         return (short) (((-1 >> i) << (j << k)) >>> (k - j));
164     }
165 
166     @Reflect
167     static char charShiftOps(char i, char j, char k) {
168         return (char) (((-1 >> i) << (j << k)) >>> (k - j));
169     }
170 
171     @Reflect
172     static long longShiftOps(long i, long j, long k) {
173         return ((-1 >> i) << (j << k)) >>> (k - j);
174     }
175 
176     @Reflect
177     static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
178         ii += i; ii += b; ii += s; ii += c;
179         i += ii; i += bb; i += ss; i += cc;
180         b += ii; b += bb; b += ss; b += cc;
181         s += ii; s += bb; s += ss; s += cc;
182         c += ii; c += bb; c += ss; c += cc;
183         return new Object[]{i, b, s, c};
184     }
185 
186     @Reflect
187     static String constructor(String s, int i, int j) {
188         return new String(s.getBytes(), i, j);
189     }
190 
191     @Reflect
192     static Class<?> classArray(int i, int j) {
193         Class<?>[] ifaces = new Class[1 + i + j];
194         ifaces[0] = Function.class;
195         return ifaces[0];
196     }
197 
198     @Reflect
199     static String[] stringArray(int i, int j) {
200         return new String[i];
201     }
202 
203     @Reflect
204     static String[][] stringArray2(int i, int j) {
205         return new String[i][];
206     }
207 
208     @Reflect
209     static String[][] stringArrayMulti(int i, int j) {
210         return new String[i][j];
211     }
212 
213     @Reflect
214     static int[][] initializedIntArray(int i, int j) {
215         return new int[][]{{i, j}, {i + j}};
216     }
217 
218     @Reflect
219     static int ifElseCompare(int i, int j) {
220         if (i < 3) {
221             i += 1;
222         } else {
223             j += 2;
224         }
225         return i + j;
226     }
227 
228     @Reflect
229     static int ifElseEquality(int i, int j) {
230         if (j != 0) {
231             if (i != 0) {
232                 i += 1;
233             } else {
234                 i += 2;
235             }
236         } else {
237             if (j != 0) {
238                 i += 3;
239             } else {
240                 i += 4;
241             }
242         }
243         return i;
244     }
245 
246     @Reflect
247     static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) {
248         Object a = b1;
249         Object b = b2;
250         Object c = b3;
251         return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4);
252     }
253 
254     @Reflect
255     static int conditionalExpr(int i, int j) {
256         return ((i - 1 >= 0) ? i - 1 : j - 1);
257     }
258 
259     @Reflect
260     static int nestedConditionalExpr(int i, int j) {
261         return (i < 2) ? (j < 3) ? i : j : i + j;
262     }
263 
264     static final int[] MAP = {0, 1, 2, 3, 4};
265 
266     @Reflect
267     static int deepStackBranches(boolean a, boolean b) {
268         return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]];
269     }
270 
271     @Reflect
272     static int tryFinally(int i, int j) {
273         try {
274             i = i + j;
275         } finally {
276             i = i + j;
277         }
278         return i;
279     }
280 
281     public record A(String s) {}
282 
283     @Reflect
284     static A newWithArgs(int i, int j) {
285         return new A("hello world".substring(i, i + j));
286     }
287 
288     @Reflect
289     static int loop(int n, int j) {
290         int sum = 0;
291         for (int i = 0; i < n; i++) {
292             sum = sum + j;
293         }
294         return sum;
295     }
296 
297 
298     @Reflect
299     static int ifElseNested(int a, int b) {
300         int c = a + b;
301         int d = 10 - a + b;
302         if (b < 3) {
303             if (a < 3) {
304                 a += 1;
305             } else {
306                 b += 2;
307             }
308             c += 3;
309         } else {
310             if (a > 2) {
311                 a += 4;
312             } else {
313                 b += 5;
314             }
315             d += 6;
316         }
317         return a + b + c + d;
318     }
319 
320     @Reflect
321     static int nestedLoop(int m, int n) {
322         int sum = 0;
323         for (int i = 0; i < m; i++) {
324             for (int j = 0; j < n; j++) {
325                 sum = sum + i + j;
326             }
327         }
328         return sum;
329     }
330 
331     @Reflect
332     static int methodCall(int a, int b) {
333         int i = Math.max(a, b);
334         return Math.negateExact(i);
335     }
336 
337     @Reflect
338     static int[] primitiveArray(int i, int j) {
339         int[] ia = new int[i + 1];
340         ia[0] = j;
341         return ia;
342     }
343 
344     @Reflect
345     static boolean not(boolean b) {
346         return !b;
347     }
348 
349     @Reflect
350     static boolean notCompare(int i, int j) {
351         boolean b = i < j;
352         return !b;
353     }
354 
355     @Reflect
356     static int mod(int i, int j) {
357         return i % (j + 1);
358     }
359 
360     @Reflect
361     static int xor(int i, int j) {
362         return i ^ j;
363     }
364 
365     @Reflect
366     static int whileLoop(int i, int n) { int
367         counter = 0;
368         while (i < n && counter < 3) {
369             counter++;
370             if (counter == 4) {
371                 break;
372             }
373             i++;
374         }
375         return counter;
376     }
377 
378     static int consumeQuotable(int i, IntUnaryOperator f) {
379         Assertions.assertNotNull(Op.ofQuotable(f).get());
380         Assertions.assertNotNull(Op.ofQuotable(f).get().op());
381         Assertions.assertTrue(Op.ofQuotable(f).get().op() instanceof JavaOp.LambdaOp);
382         return f.applyAsInt(i + 1);
383     }
384 
385     @Reflect
386     static int quotableLambda(int i) {
387         return consumeQuotable(i, a -> -a);
388     }
389 
390     @Reflect
391     static int quotableLambdaWithCapture(int i, String s) {
392         return consumeQuotable(i, a -> a + s.length());
393     }
394 
395     @Reflect
396     static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
397         return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length());
398     }
399 
400     @Reflect
401     static int methodHandle(int i) {
402         return consumeQuotable(i, Math::negateExact);
403     }
404 
405     int instanceMethod(int i) {
406         return -i + 13;
407     }
408 
409     @Reflect
410     int instanceMethodHandle(int i) {
411         return consumeQuotable(i, this::instanceMethod);
412     }
413 
414     static void consume(boolean b, Consumer<Object> requireNonNull) {
415         if (b) {
416             requireNonNull.accept(new Object());
417         } else try {
418             requireNonNull.accept(null);
419             throw new AssertionError("Expectend NPE");
420         } catch (NullPointerException expected) {
421         }
422     }
423 
424     @Reflect
425     static void nullReturningMethodHandle(boolean b) {
426         consume(b, Objects::requireNonNull);
427     }
428 
429     @Reflect
430     static boolean compareLong(long i, long j) {
431         return i > j;
432     }
433 
434     @Reflect
435     static boolean compareFloat(float i, float j) {
436         return i > j;
437     }
438 
439     @Reflect
440     static boolean compareDouble(double i, double j) {
441         return i > j;
442     }
443 
444     @Reflect
445     static int lookupSwitch(int i) {
446         return switch (1000 * i) {
447             case 1000 -> 1;
448             case 2000 -> 2;
449             case 3000 -> 3;
450             default -> 0;
451         };
452     }
453 
454     @Reflect
455     static int tableSwitch(int i) {
456         return switch (i) {
457             case 1 -> 1;
458             case 2 -> 2;
459             case 3 -> 3;
460             default -> 0;
461         };
462     }
463 
464     int instanceField = -1;
465 
466     @Reflect
467     int instanceFieldAccess(int i) {
468         int ret = instanceField;
469         instanceField = i;
470         return ret;
471     }
472 
473     @Reflect
474     static String stringConcat(String a, String b) {
475         return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd";
476     }
477 
478     @Reflect
479     static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) {
480         return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d;
481     }
482 
483     @Reflect
484     static int ifTrue(int i) {
485         if (true) {
486             return i;
487         }
488         return -i;
489     }
490 
491     @Reflect
492     static int excHandlerFollowingSplitTable(boolean b) {
493         try {
494             if (b) return 1;
495             else throw new Exception();
496         } catch (Exception ex) {}
497         return 2;
498     }
499 
500     @Reflect
501     static int varModifiedInTryBlock(boolean b) {
502         int i = 0;
503         try {
504             i++;
505             if (b) throw new Exception();
506             i++;
507             throw new Exception();
508         } catch (Exception ex) {
509             return i;
510         }
511     }
512 
513     @Reflect
514     static boolean finallyWithLoop(boolean b) {
515         try {
516             while (b) {
517                 if (b)
518                     return false;
519                 b = !b;
520             }
521             return true;
522         } finally {
523             b = false;
524         }
525     }
526 
527     @Reflect
528     static long doubleUseOfOperand(int x) {
529         long piece = x;
530         return piece * piece;
531     }
532 
533     @Reflect
534     static String functionLambda(String s) {
535         return ((Function<String, String>)e -> e.substring(1)).apply(s);
536     }
537 
538     @Reflect
539     static String staticVarargInvokeWithNoRegularArgs(String s) {
540         String a = "prefix";
541         return Arrays.asList(a, s).toString();
542     }
543 
544     record TestData(Method testMethod) {
545         @Override
546         public String toString() {
547             String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
548                     .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
549             if (s.length() > 30) s = s.substring(0, 27) + "...";
550             return s;
551         }
552     }
553 
554     public static Stream<TestData> testMethods() {
555         return Stream.of(TestBytecode.class.getDeclaredMethods())
556                 .filter(m -> m.isAnnotationPresent(Reflect.class))
557                 .map(TestData::new);
558     }
559 
560     private static byte[] CLASS_DATA;
561     private static ClassModel CLASS_MODEL;
562 
563     @BeforeAll
564     public static void setup() throws Exception {
565         CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes();
566         CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
567     }
568 
569     private static MethodTypeDesc toMethodTypeDesc(Method m) {
570         return MethodTypeDesc.of(
571                 m.getReturnType().describeConstable().orElseThrow(),
572                 Arrays.stream(m.getParameterTypes())
573                         .map(cls -> cls.describeConstable().orElseThrow()).toList());
574     }
575 
576 
577     private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
578     private static Object[] values(Object... values) {
579         return values;
580     }
581     private static void initTestArgs(Object[] values, Class<?>... argTypes) {
582         for (var argType : argTypes) TEST_ARGS.put(argType, values);
583     }
584     static {
585         initTestArgs(values(1, 2, 4), int.class, Integer.class);
586         initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class);
587         initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class);
588         initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class);
589         initTestArgs(values(false, true), boolean.class, Boolean.class);
590         initTestArgs(values("Hello World"), String.class);
591         initTestArgs(values(1l, 2l, 4l), long.class, Long.class);
592         initTestArgs(values(1f, 3f, 4f), float.class, Float.class);
593         initTestArgs(values(1d, 2d, 3d), double.class, Double.class);
594     }
595 
596     interface Executor {
597         void execute(Object[] args) throws Throwable;
598     }
599 
600     private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
601         final int argn = argTypes.length;
602         Object[][] argValues = new Object[argn][];
603         for (int i = 0; i < argn; i++) {
604             argValues[i] = TEST_ARGS.get(argTypes[i]);
605         }
606         int[] argIndexes = new int[argn];
607         Object[] args = new Object[argn];
608         while (true) {
609             for (int i = 0; i < argn; i++) {
610                 args[i] = argValues[i][argIndexes[i]];
611             }
612             executor.execute(args);
613             int i = argn - 1;
614             while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
615             if (i < 0) return;
616             argIndexes[i++]++;
617             while (i < argn) argIndexes[i++] = 0;
618         }
619     }
620 
621     @ParameterizedTest
622     @MethodSource("testMethods")
623     public void testGenerate(TestData d) throws Throwable {
624         CoreOp.FuncOp func = Op.ofMethod(d.testMethod).get();
625 
626         try {
627             MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), func);
628             Object receiver1, receiver2;
629             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
630                 receiver1 = null;
631                 receiver2 = null;
632             } else {
633                 receiver1 = new TestBytecode();
634                 receiver2 = new TestBytecode();
635             }
636             permutateAllArgs(d.testMethod.getParameterTypes(), args -> {
637                 List argl = new ArrayList(args.length + 1);
638                 if (receiver1 != null) argl.add(receiver1);
639                 argl.addAll(Arrays.asList(args));
640                 assertEquals(d.testMethod.invoke(receiver2, args), mh.invokeWithArguments(argl));
641             });
642         } catch (Throwable e) {
643             System.out.println(func.toText());
644             String methodName = d.testMethod().getName();
645             for (var mm : CLASS_MODEL.methods()) {
646                 if (mm.methodName().equalsString(methodName)
647                         || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) {
648                     ClassPrinter.toYaml(mm,
649                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
650                                         System.out::print);
651                 }
652             }
653             Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> {
654                 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try {
655                     ClassPrinter.toYaml(ClassFile.of().parse(p),
656                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
657                                         System.out::print);
658                 } catch (IOException ignore) {}
659             });
660             throw e;
661         }
662     }
663 
664     private static void assertEquals(Object expected, Object actual) {
665         switch (expected) {
666             case int[] expArr when actual instanceof int[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
667             case Object[] expArr when actual instanceof Object[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
668             case null, default -> Assertions.assertEquals(expected, actual);
669         }
670     }
671 }