1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.bytecode.BytecodeGenerator;
 26 import jdk.incubator.code.dialect.core.CoreOp;
 27 import jdk.incubator.code.dialect.java.JavaOp;
 28 import jdk.internal.classfile.components.ClassPrinter;
 29 import org.junit.jupiter.api.Assertions;
 30 import org.junit.jupiter.api.BeforeAll;
 31 import org.junit.jupiter.params.ParameterizedTest;
 32 import org.junit.jupiter.params.provider.MethodSource;
 33 import org.opentest4j.TestSkippedException;
 34 
 35 import java.io.IOException;
 36 import java.lang.classfile.ClassFile;
 37 import java.lang.classfile.ClassModel;
 38 import java.lang.constant.MethodTypeDesc;
 39 import java.lang.invoke.MethodHandle;
 40 import java.lang.invoke.MethodHandles;
 41 import java.lang.reflect.AccessFlag;
 42 import java.lang.reflect.Method;
 43 import java.nio.file.Files;
 44 import java.nio.file.Path;
 45 import java.util.*;
 46 import java.util.function.Consumer;
 47 import java.util.function.Function;
 48 import java.util.stream.Collectors;
 49 import java.util.stream.Stream;
 50 
 51 /*
 52  * @test
 53  * @modules jdk.incubator.code
 54  * @modules java.base/jdk.internal.classfile.components
 55  * @enablePreview
 56  * @run junit/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode
 57  */
 58 
 59 public class TestBytecode {
 60 
 61     @CodeReflection
 62     static int intNumOps(int i, int j, int k) {
 63         k++;
 64         i = (i + j) / k - i % j;
 65         i--;
 66         return i;
 67     }
 68 
 69     @CodeReflection
 70     static byte byteNumOps(byte i, byte j, byte k) {
 71         k++;
 72         i = (byte) ((i + j) / k - i % j);
 73         i--;
 74         return i;
 75     }
 76 
 77     @CodeReflection
 78     static short shortNumOps(short i, short j, short k) {
 79         k++;
 80         i = (short) ((i + j) / k - i % j);
 81         i--;
 82         return i;
 83     }
 84 
 85     @CodeReflection
 86     static char charNumOps(char i, char j, char k) {
 87         k++;
 88         i = (char) ((i + j) / k - i % j);
 89         i--;
 90         return i;
 91     }
 92 
 93     @CodeReflection
 94     static long longNumOps(long i, long j, long k) {
 95         k++;
 96         i = (i + j) / k - i % j;
 97         i--;
 98         return i;
 99     }
100 
101     @CodeReflection
102     static float floatNumOps(float i, float j, float k) {
103         k++;
104         i = (i + j) / k - i % j;
105         i--;
106         return i;
107     }
108 
109     @CodeReflection
110     static double doubleNumOps(double i, double j, double k) {
111         k++;
112         i = (i + j) / k - i % j;
113         i--;
114         return i;
115     }
116 
117     @CodeReflection
118     static int intBitOps(int i, int j, int k) {
119         return ~(i & j | k ^ j);
120     }
121 
122     @CodeReflection
123     static byte byteBitOps(byte i, byte j, byte k) {
124         return (byte) ~(i & j | k ^ j);
125     }
126 
127     @CodeReflection
128     static short shortBitOps(short i, short j, short k) {
129         return (short) ~(i & j | k ^ j);
130     }
131 
132     @CodeReflection
133     static char charBitOps(char i, char j, char k) {
134         return (char) ~(i & j | k ^ j);
135     }
136 
137     @CodeReflection
138     static long longBitOps(long i, long j, long k) {
139         return ~(i & j | k ^ j);
140     }
141 
142     @CodeReflection
143     static boolean boolBitOps(boolean i, boolean j, boolean k) {
144         return i & j | k ^ j;
145     }
146 
147     @CodeReflection
148     static int intShiftOps(int i, int j, int k) {
149         return ((-1 >> i) << (j << k)) >>> (k - j);
150     }
151 
152     @CodeReflection
153     static byte byteShiftOps(byte i, byte j, byte k) {
154         return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
155     }
156 
157     @CodeReflection
158     static short shortShiftOps(short i, short j, short k) {
159         return (short) (((-1 >> i) << (j << k)) >>> (k - j));
160     }
161 
162     @CodeReflection
163     static char charShiftOps(char i, char j, char k) {
164         return (char) (((-1 >> i) << (j << k)) >>> (k - j));
165     }
166 
167     @CodeReflection
168     static long longShiftOps(long i, long j, long k) {
169         return ((-1 >> i) << (j << k)) >>> (k - j);
170     }
171 
172     @CodeReflection
173     static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
174         ii += i; ii += b; ii += s; ii += c;
175         i += ii; i += bb; i += ss; i += cc;
176         b += ii; b += bb; b += ss; b += cc;
177         s += ii; s += bb; s += ss; s += cc;
178         c += ii; c += bb; c += ss; c += cc;
179         return new Object[]{i, b, s, c};
180     }
181 
182     @CodeReflection
183     static String constructor(String s, int i, int j) {
184         return new String(s.getBytes(), i, j);
185     }
186 
187     @CodeReflection
188     static Class<?> classArray(int i, int j) {
189         Class<?>[] ifaces = new Class[1 + i + j];
190         ifaces[0] = Function.class;
191         return ifaces[0];
192     }
193 
194     @CodeReflection
195     static String[] stringArray(int i, int j) {
196         return new String[i];
197     }
198 
199     @CodeReflection
200     static String[][] stringArray2(int i, int j) {
201         return new String[i][];
202     }
203 
204     @CodeReflection
205     static String[][] stringArrayMulti(int i, int j) {
206         return new String[i][j];
207     }
208 
209     @CodeReflection
210     static int[][] initializedIntArray(int i, int j) {
211         return new int[][]{{i, j}, {i + j}};
212     }
213 
214     @CodeReflection
215     static int ifElseCompare(int i, int j) {
216         if (i < 3) {
217             i += 1;
218         } else {
219             j += 2;
220         }
221         return i + j;
222     }
223 
224     @CodeReflection
225     static int ifElseEquality(int i, int j) {
226         if (j != 0) {
227             if (i != 0) {
228                 i += 1;
229             } else {
230                 i += 2;
231             }
232         } else {
233             if (j != 0) {
234                 i += 3;
235             } else {
236                 i += 4;
237             }
238         }
239         return i;
240     }
241 
242     @CodeReflection
243     static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) {
244         Object a = b1;
245         Object b = b2;
246         Object c = b3;
247         return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4);
248     }
249 
250     @CodeReflection
251     static int conditionalExpr(int i, int j) {
252         return ((i - 1 >= 0) ? i - 1 : j - 1);
253     }
254 
255     @CodeReflection
256     static int nestedConditionalExpr(int i, int j) {
257         return (i < 2) ? (j < 3) ? i : j : i + j;
258     }
259 
260     static final int[] MAP = {0, 1, 2, 3, 4};
261 
262     @CodeReflection
263     static int deepStackBranches(boolean a, boolean b) {
264         return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]];
265     }
266 
267     @CodeReflection
268     static int tryFinally(int i, int j) {
269         try {
270             i = i + j;
271         } finally {
272             i = i + j;
273         }
274         return i;
275     }
276 
277     public record A(String s) {}
278 
279     @CodeReflection
280     static A newWithArgs(int i, int j) {
281         return new A("hello world".substring(i, i + j));
282     }
283 
284     @CodeReflection
285     static int loop(int n, int j) {
286         int sum = 0;
287         for (int i = 0; i < n; i++) {
288             sum = sum + j;
289         }
290         return sum;
291     }
292 
293 
294     @CodeReflection
295     static int ifElseNested(int a, int b) {
296         int c = a + b;
297         int d = 10 - a + b;
298         if (b < 3) {
299             if (a < 3) {
300                 a += 1;
301             } else {
302                 b += 2;
303             }
304             c += 3;
305         } else {
306             if (a > 2) {
307                 a += 4;
308             } else {
309                 b += 5;
310             }
311             d += 6;
312         }
313         return a + b + c + d;
314     }
315 
316     @CodeReflection
317     static int nestedLoop(int m, int n) {
318         int sum = 0;
319         for (int i = 0; i < m; i++) {
320             for (int j = 0; j < n; j++) {
321                 sum = sum + i + j;
322             }
323         }
324         return sum;
325     }
326 
327     @CodeReflection
328     static int methodCall(int a, int b) {
329         int i = Math.max(a, b);
330         return Math.negateExact(i);
331     }
332 
333     @CodeReflection
334     static int[] primitiveArray(int i, int j) {
335         int[] ia = new int[i + 1];
336         ia[0] = j;
337         return ia;
338     }
339 
340     @CodeReflection
341     static boolean not(boolean b) {
342         return !b;
343     }
344 
345     @CodeReflection
346     static boolean notCompare(int i, int j) {
347         boolean b = i < j;
348         return !b;
349     }
350 
351     @CodeReflection
352     static int mod(int i, int j) {
353         return i % (j + 1);
354     }
355 
356     @CodeReflection
357     static int xor(int i, int j) {
358         return i ^ j;
359     }
360 
361     @CodeReflection
362     static int whileLoop(int i, int n) { int
363         counter = 0;
364         while (i < n && counter < 3) {
365             counter++;
366             if (counter == 4) {
367                 break;
368             }
369             i++;
370         }
371         return counter;
372     }
373 
374     public interface Func {
375         int apply(int a);
376     }
377 
378     public interface QuotableFunc extends Quotable {
379         int apply(int a);
380     }
381 
382     static int consume(int i, Func f) {
383         return f.apply(i + 1);
384     }
385 
386     static int consumeQuotable(int i, QuotableFunc f) {
387         Assertions.assertNotNull(Op.ofQuotable(f).get());
388         Assertions.assertNotNull(Op.ofQuotable(f).get().op());
389         Assertions.assertTrue(Op.ofQuotable(f).get().op() instanceof JavaOp.LambdaOp);
390         return f.apply(i + 1);
391     }
392 
393     @CodeReflection
394     static int lambda(int i) {
395         return consume(i, a -> -a);
396     }
397 
398     @CodeReflection
399     static int quotableLambda(int i) {
400         return consumeQuotable(i, a -> -a);
401     }
402 
403     @CodeReflection
404     static int lambdaWithCapture(int i, String s) {
405         return consume(i, a -> a + s.length());
406     }
407 
408     @CodeReflection
409     static int quotableLambdaWithCapture(int i, String s) {
410         return consumeQuotable(i, a -> a + s.length());
411     }
412 
413     @CodeReflection
414     static int nestedLambdasWithCaptures(int i, int j, String s) {
415         return consume(i, a -> consume(a, b -> a + b + j - s.length()) + s.length());
416     }
417 
418     @CodeReflection
419     static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
420         return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length());
421     }
422 
423     @CodeReflection
424     static int methodHandle(int i) {
425         return consume(i, Math::negateExact);
426     }
427 
428     int instanceMethod(int i) {
429         return -i + 13;
430     }
431 
432     @CodeReflection
433     int instanceMethodHandle(int i) {
434         return consume(i, this::instanceMethod);
435     }
436 
437     static void consume(boolean b, Consumer<Object> requireNonNull) {
438         if (b) {
439             requireNonNull.accept(new Object());
440         } else try {
441             requireNonNull.accept(null);
442             throw new AssertionError("Expectend NPE");
443         } catch (NullPointerException expected) {
444         }
445     }
446 
447     @CodeReflection
448     static void nullReturningMethodHandle(boolean b) {
449         consume(b, Objects::requireNonNull);
450     }
451 
452     @CodeReflection
453     static boolean compareLong(long i, long j) {
454         return i > j;
455     }
456 
457     @CodeReflection
458     static boolean compareFloat(float i, float j) {
459         return i > j;
460     }
461 
462     @CodeReflection
463     static boolean compareDouble(double i, double j) {
464         return i > j;
465     }
466 
467     @CodeReflection
468     static int lookupSwitch(int i) {
469         return switch (1000 * i) {
470             case 1000 -> 1;
471             case 2000 -> 2;
472             case 3000 -> 3;
473             default -> 0;
474         };
475     }
476 
477     @CodeReflection
478     static int tableSwitch(int i) {
479         return switch (i) {
480             case 1 -> 1;
481             case 2 -> 2;
482             case 3 -> 3;
483             default -> 0;
484         };
485     }
486 
487     int instanceField = -1;
488 
489     @CodeReflection
490     int instanceFieldAccess(int i) {
491         int ret = instanceField;
492         instanceField = i;
493         return ret;
494     }
495 
496     @CodeReflection
497     static String stringConcat(String a, String b) {
498         return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd";
499     }
500 
501     @CodeReflection
502     static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) {
503         return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d;
504     }
505 
506     @CodeReflection
507     static int ifTrue(int i) {
508         if (true) {
509             return i;
510         }
511         return -i;
512     }
513 
514     @CodeReflection
515     static int excHandlerFollowingSplitTable(boolean b) {
516         try {
517             if (b) return 1;
518             else throw new Exception();
519         } catch (Exception ex) {}
520         return 2;
521     }
522 
523     @CodeReflection
524     static int varModifiedInTryBlock(boolean b) {
525         int i = 0;
526         try {
527             i++;
528             if (b) throw new Exception();
529             i++;
530             throw new Exception();
531         } catch (Exception ex) {
532             return i;
533         }
534     }
535 
536     @CodeReflection
537     static boolean finallyWithLoop(boolean b) {
538         try {
539             while (b) {
540                 if (b)
541                     return false;
542                 b = !b;
543             }
544             return true;
545         } finally {
546             b = false;
547         }
548     }
549 
550     @CodeReflection
551     static long doubleUseOfOperand(int x) {
552         long piece = x;
553         return piece * piece;
554     }
555 
556     record TestData(Method testMethod) {
557         @Override
558         public String toString() {
559             String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
560                     .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
561             if (s.length() > 30) s = s.substring(0, 27) + "...";
562             return s;
563         }
564     }
565 
566     public static Stream<TestData> testMethods() {
567         return Stream.of(TestBytecode.class.getDeclaredMethods())
568                 .filter(m -> m.isAnnotationPresent(CodeReflection.class))
569                 .map(TestData::new);
570     }
571 
572     private static byte[] CLASS_DATA;
573     private static ClassModel CLASS_MODEL;
574 
575     @BeforeAll
576     public static void setup() throws Exception {
577         CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes();
578         CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
579     }
580 
581     private static MethodTypeDesc toMethodTypeDesc(Method m) {
582         return MethodTypeDesc.of(
583                 m.getReturnType().describeConstable().orElseThrow(),
584                 Arrays.stream(m.getParameterTypes())
585                         .map(cls -> cls.describeConstable().orElseThrow()).toList());
586     }
587 
588 
589     private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
590     private static Object[] values(Object... values) {
591         return values;
592     }
593     private static void initTestArgs(Object[] values, Class<?>... argTypes) {
594         for (var argType : argTypes) TEST_ARGS.put(argType, values);
595     }
596     static {
597         initTestArgs(values(1, 2, 4), int.class, Integer.class);
598         initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class);
599         initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class);
600         initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class);
601         initTestArgs(values(false, true), boolean.class, Boolean.class);
602         initTestArgs(values("Hello World"), String.class);
603         initTestArgs(values(1l, 2l, 4l), long.class, Long.class);
604         initTestArgs(values(1f, 3f, 4f), float.class, Float.class);
605         initTestArgs(values(1d, 2d, 3d), double.class, Double.class);
606     }
607 
608     interface Executor {
609         void execute(Object[] args) throws Throwable;
610     }
611 
612     private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
613         final int argn = argTypes.length;
614         Object[][] argValues = new Object[argn][];
615         for (int i = 0; i < argn; i++) {
616             argValues[i] = TEST_ARGS.get(argTypes[i]);
617         }
618         int[] argIndexes = new int[argn];
619         Object[] args = new Object[argn];
620         while (true) {
621             for (int i = 0; i < argn; i++) {
622                 args[i] = argValues[i][argIndexes[i]];
623             }
624             executor.execute(args);
625             int i = argn - 1;
626             while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
627             if (i < 0) return;
628             argIndexes[i++]++;
629             while (i < argn) argIndexes[i++] = 0;
630         }
631     }
632 
633     @ParameterizedTest
634     @MethodSource("testMethods")
635     public void testGenerate(TestData d) throws Throwable {
636         CoreOp.FuncOp func = Op.ofMethod(d.testMethod).get();
637 
638         CoreOp.FuncOp lfunc;
639         try {
640             lfunc = func.transform(CopyContext.create(), OpTransformer.LOWERING_TRANSFORMER);
641         } catch (UnsupportedOperationException uoe) {
642             throw new TestSkippedException("lowering caused:", uoe);
643         }
644 
645         try {
646             MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), lfunc);
647             Object receiver1, receiver2;
648             if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) {
649                 receiver1 = null;
650                 receiver2 = null;
651             } else {
652                 receiver1 = new TestBytecode();
653                 receiver2 = new TestBytecode();
654             }
655             permutateAllArgs(d.testMethod.getParameterTypes(), args -> {
656                 List argl = new ArrayList(args.length + 1);
657                 if (receiver1 != null) argl.add(receiver1);
658                 argl.addAll(Arrays.asList(args));
659                 assertEquals(d.testMethod.invoke(receiver2, args), mh.invokeWithArguments(argl));
660             });
661         } catch (Throwable e) {
662             System.out.println(func.toText());
663             System.out.println(lfunc.toText());
664             String methodName = d.testMethod().getName();
665             for (var mm : CLASS_MODEL.methods()) {
666                 if (mm.methodName().equalsString(methodName)
667                         || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) {
668                     ClassPrinter.toYaml(mm,
669                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
670                                         System.out::print);
671                 }
672             }
673             Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> {
674                 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try {
675                     ClassPrinter.toYaml(ClassFile.of().parse(p),
676                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
677                                         System.out::print);
678                 } catch (IOException ignore) {}
679             });
680             throw e;
681         }
682     }
683 
684     private static void assertEquals(Object expected, Object actual) {
685         switch (expected) {
686             case int[] expArr when actual instanceof int[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
687             case Object[] expArr when actual instanceof Object[] actArr -> Assertions.assertArrayEquals(expArr, actArr);
688             case null, default -> Assertions.assertEquals(expected, actual);
689         }
690     }
691 }