1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import java.io.IOException;
 25 import java.lang.annotation.Retention;
 26 import java.lang.annotation.RetentionPolicy;
 27 import java.lang.classfile.ClassFile;
 28 import java.lang.classfile.ClassModel;
 29 import java.lang.classfile.components.ClassPrinter;
 30 import java.lang.constant.MethodTypeDesc;
 31 import java.lang.invoke.MethodHandle;
 32 import java.lang.invoke.MethodHandles;
 33 import org.testng.Assert;
 34 import org.testng.SkipException;
 35 import org.testng.annotations.BeforeClass;
 36 import org.testng.annotations.DataProvider;
 37 import org.testng.annotations.Test;
 38 
 39 import java.lang.reflect.code.*;
 40 import java.lang.reflect.code.op.CoreOp;
 41 import java.lang.reflect.code.bytecode.BytecodeLift;
 42 import java.lang.reflect.code.interpreter.Interpreter;
 43 import java.lang.reflect.Method;
 44 import java.lang.reflect.code.bytecode.BytecodeGenerator;
 45 import java.lang.reflect.code.type.JavaType;
 46 import java.lang.runtime.CodeReflection;
 47 import java.nio.file.Files;
 48 import java.nio.file.Path;
 49 import java.util.Arrays;
 50 import java.util.IdentityHashMap;
 51 import java.util.Map;
 52 import java.util.function.Function;
 53 import java.util.stream.Collectors;
 54 import java.util.stream.Stream;
 55 
 56 /*
 57  * @test
 58  * @enablePreview
 59  * @run testng/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode
 60  */
 61 
 62 public class TestBytecode {
 63 
 64     @CodeReflection
 65     static int intNumOps(int i, int j, int k) {
 66         k++;
 67         i = (i + j) / k - i % j;
 68         i--;
 69         return i;
 70     }
 71 
 72     @CodeReflection
 73     @SkipLift
 74     static byte byteNumOps(byte i, byte j, byte k) {
 75         k++;
 76         i = (byte) ((i + j) / k - i % j);
 77         i--;
 78         return i;
 79     }
 80 
 81     @CodeReflection
 82     @SkipLift
 83     static short shortNumOps(short i, short j, short k) {
 84         k++;
 85         i = (short) ((i + j) / k - i % j);
 86         i--;
 87         return i;
 88     }
 89 
 90     @CodeReflection
 91     @SkipLift
 92     static char charNumOps(char i, char j, char k) {
 93         k++;
 94         i = (char) ((i + j) / k - i % j);
 95         i--;
 96         return i;
 97     }
 98 
 99     @CodeReflection
100     static long longNumOps(long i, long j, long k) {
101         k++;
102         i = (i + j) / k - i % j;
103         i--;
104         return i;
105     }
106 
107     @CodeReflection
108     static float floatNumOps(float i, float j, float k) {
109         k++;
110         i = (i + j) / k - i % j;
111         i--;
112         return i;
113     }
114 
115     @CodeReflection
116     static double doubleNumOps(double i, double j, double k) {
117         k++;
118         i = (i + j) / k - i % j;
119         i--;
120         return i;
121     }
122 
123     @CodeReflection
124     static int intBitOps(int i, int j, int k) {
125         return i & j | k ^ j;
126     }
127 
128     @CodeReflection
129     @SkipLift
130     static byte byteBitOps(byte i, byte j, byte k) {
131         return (byte) (i & j | k ^ j);
132     }
133 
134     @CodeReflection
135     @SkipLift
136     static short shortBitOps(short i, short j, short k) {
137         return (short) (i & j | k ^ j);
138     }
139 
140     @CodeReflection
141     @SkipLift
142     static char charBitOps(char i, char j, char k) {
143         return (char) (i & j | k ^ j);
144     }
145 
146     @CodeReflection
147     static long longBitOps(long i, long j, long k) {
148         return i & j | k ^ j;
149     }
150 
151     @CodeReflection
152     static boolean boolBitOps(boolean i, boolean j, boolean k) {
153         return i & j | k ^ j;
154     }
155 
156     @CodeReflection
157     @SkipLift
158     static int intShiftOps(int i, int j, int k) {
159         return ((-1 >> i) << (j << k)) >>> (k - j);
160     }
161 
162     @CodeReflection
163     @SkipLift
164     static byte byteShiftOps(byte i, byte j, byte k) {
165         return (byte) (((-1 >> i) << (j << k)) >>> (k - j));
166     }
167 
168     @CodeReflection
169     @SkipLift
170     static short shortShiftOps(short i, short j, short k) {
171         return (short) (((-1 >> i) << (j << k)) >>> (k - j));
172     }
173 
174     @CodeReflection
175     @SkipLift
176     static char charShiftOps(char i, char j, char k) {
177         return (char) (((-1 >> i) << (j << k)) >>> (k - j));
178     }
179 
180     @CodeReflection
181     @SkipLift
182     static long longShiftOps(long i, long j, long k) {
183         return ((-1 >> i) << (j << k)) >>> (k - j);
184     }
185 
186     @CodeReflection
187     @SkipLift
188     static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) {
189         ii += i; ii += b; ii += s; ii += c;
190         i += ii; i += bb; i += ss; i += cc;
191         b += ii; b += bb; b += ss; b += cc;
192         s += ii; s += bb; s += ss; s += cc;
193         c += ii; c += bb; c += ss; c += cc;
194         return new Object[]{i, b, s, c};
195     }
196 
197     @CodeReflection
198     static String constructor(String s, int i, int j) {
199         return new String(s.getBytes(), i, j);
200     }
201 
202     @CodeReflection
203     static Class<?> classArray(int i, int j) {
204         Class<?>[] ifaces = new Class[1 + i + j];
205         ifaces[0] = Function.class;
206         return ifaces[0];
207     }
208 
209     @CodeReflection
210     static String[] stringArray(int i, int j) {
211         return new String[i];
212     }
213 
214     @CodeReflection
215     static String[][] stringArray2(int i, int j) {
216         return new String[i][];
217     }
218 
219     @CodeReflection
220     static String[][] stringArrayMulti(int i, int j) {
221         return new String[i][j];
222     }
223 
224     @CodeReflection
225     static int[][] initializedIntArray(int i, int j) {
226         return new int[][]{{i, j}, {i + j}};
227     }
228 
229     @CodeReflection
230     static int ifElseCompare(int i, int j) {
231         if (i < 3) {
232             i += 1;
233         } else {
234             j += 2;
235         }
236         return i + j;
237     }
238 
239     @CodeReflection
240     static int ifElseEquality(int i, int j) {
241         if (j != 0) {
242             if (i != 0) {
243                 i += 1;
244             } else {
245                 i += 2;
246             }
247         } else {
248             if (j != 0) {
249                 i += 3;
250             } else {
251                 i += 4;
252             }
253         }
254         return i;
255     }
256 
257     @CodeReflection
258     static int conditionalExpr(int i, int j) {
259         return ((i - 1 >= 0) ? i - 1 : j - 1);
260     }
261 
262     @CodeReflection
263     static int nestedConditionalExpr(int i, int j) {
264         return (i < 2) ? (j < 3) ? i : j : i + j;
265     }
266 
267     @CodeReflection
268     static int tryFinally(int i, int j) {
269         try {
270             i = i + j;
271         } finally {
272             i = i + j;
273         }
274         return i;
275     }
276 
277     public record A(String s) {}
278 
279     @CodeReflection
280     static A newWithArgs(int i, int j) {
281         return new A("hello world".substring(i, i + j));
282     }
283 
284     @CodeReflection
285     static int loop(int n, int j) {
286         int sum = 0;
287         for (int i = 0; i < n; i++) {
288             sum = sum + j;
289         }
290         return sum;
291     }
292 
293 
294     @CodeReflection
295     static int ifElseNested(int a, int b) {
296         int c = a + b;
297         int d = 10 - a + b;
298         if (b < 3) {
299             if (a < 3) {
300                 a += 1;
301             } else {
302                 b += 2;
303             }
304             c += 3;
305         } else {
306             if (a > 2) {
307                 a += 4;
308             } else {
309                 b += 5;
310             }
311             d += 6;
312         }
313         return a + b + c + d;
314     }
315 
316     @CodeReflection
317     static int nestedLoop(int m, int n) {
318         int sum = 0;
319         for (int i = 0; i < m; i++) {
320             for (int j = 0; j < n; j++) {
321                 sum = sum + i + j;
322             }
323         }
324         return sum;
325     }
326 
327     @CodeReflection
328     static int methodCall(int a, int b) {
329         int i = Math.max(a, b);
330         return Math.negateExact(i);
331     }
332 
333     @CodeReflection
334     static int[] primitiveArray(int i, int j) {
335         int[] ia = new int[i + 1];
336         ia[0] = j;
337         return ia;
338     }
339 
340     @CodeReflection
341     @SkipLift
342     static boolean not(boolean b) {
343         return !b;
344     }
345 
346     @CodeReflection
347     static boolean notCompare(int i, int j) {
348         boolean b = i < j;
349         return !b;
350     }
351 
352     @CodeReflection
353     static int mod(int i, int j) {
354         return i % (j + 1);
355     }
356 
357     @CodeReflection
358     static int xor(int i, int j) {
359         return i ^ j;
360     }
361 
362     @CodeReflection
363     static int whileLoop(int i, int n) { int
364         counter = 0;
365         while (i < n && counter < 3) {
366             counter++;
367             if (counter == 4) {
368                 break;
369             }
370             i++;
371         }
372         return counter;
373     }
374 
375     public interface Func {
376         int apply(int a);
377     }
378 
379     public interface QuotableFunc extends Quotable {
380         int apply(int a);
381     }
382 
383     static int consume(int i, Func f) {
384         return f.apply(i + 1);
385     }
386 
387     static int consumeQuotable(int i, QuotableFunc f) {
388         Assert.assertNotNull(f.quoted());
389         Assert.assertNotNull(f.quoted().op());
390         Assert.assertTrue(f.quoted().op() instanceof CoreOp.LambdaOp);
391         return f.apply(i + 1);
392     }
393 
394     @CodeReflection
395     @SkipLift
396     static int lambda(int i) {
397         return consume(i, a -> -a);
398     }
399 
400     @CodeReflection
401     @SkipLift
402     static int quotableLambda(int i) {
403         return consumeQuotable(i, a -> -a);
404     }
405 
406     @CodeReflection
407     @SkipLift
408     static int lambdaWithCapture(int i, String s) {
409         return consume(i, a -> a + s.length());
410     }
411 
412     @CodeReflection
413     @SkipLift
414     static int quotableLambdaWithCapture(int i, String s) {
415         return consumeQuotable(i, a -> a + s.length());
416     }
417 
418     @CodeReflection
419     @SkipLift
420     static int nestedLambdasWithCaptures(int i, int j, String s) {
421         return consume(i, a -> consume(a, b -> a + b + j) + s.length());
422     }
423 
424     @CodeReflection
425     @SkipLift
426     static int nestedQuotableLambdasWithCaptures(int i, int j, String s) {
427         return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j) + s.length());
428     }
429 
430     @CodeReflection
431     @SkipLift
432     static int methodHandle(int i) {
433         return consume(i, Math::negateExact);
434     }
435 
436     @Retention(RetentionPolicy.RUNTIME)
437     @interface SkipLift {}
438 
439     record TestData(Method testMethod) {
440         @Override
441         public String toString() {
442             String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes())
443                     .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")"));
444             if (s.length() > 30) s = s.substring(0, 27) + "...";
445             return s;
446         }
447     }
448 
449     @DataProvider(name = "testMethods")
450     public static TestData[]testMethods() {
451         return Stream.of(TestBytecode.class.getDeclaredMethods())
452                 .filter(m -> m.isAnnotationPresent(CodeReflection.class))
453                 .map(TestData::new).toArray(TestData[]::new);
454     }
455 
456     private static byte[] CLASS_DATA;
457     private static ClassModel CLASS_MODEL;
458 
459     @BeforeClass
460     public static void setup() throws Exception {
461         CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes();
462         CLASS_MODEL = ClassFile.of().parse(CLASS_DATA);
463     }
464 
465     private static MethodTypeDesc toMethodTypeDesc(Method m) {
466         return MethodTypeDesc.of(
467                 m.getReturnType().describeConstable().orElseThrow(),
468                 Arrays.stream(m.getParameterTypes())
469                         .map(cls -> cls.describeConstable().orElseThrow()).toList());
470     }
471 
472 
473     private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>();
474     private static Object[] values(Object... values) {
475         return values;
476     }
477     private static void initTestArgs(Object[] values, Class<?>... argTypes) {
478         for (var argType : argTypes) TEST_ARGS.put(argType, values);
479     }
480     static {
481         initTestArgs(values(1, 2, 3, 4), int.class, Integer.class);
482         initTestArgs(values((byte)1, (byte)2, (byte)3, (byte)4), byte.class, Byte.class);
483         initTestArgs(values((short)1, (short)2, (short)3, (short)4), short.class, Short.class);
484         initTestArgs(values((char)1, (char)2, (char)3, (char)4), char.class, Character.class);
485         initTestArgs(values(false, true), boolean.class, Boolean.class);
486         initTestArgs(values("Hello World"), String.class);
487         initTestArgs(values(1l, 2l, 3l, 4l), long.class, Long.class);
488         initTestArgs(values(1f, 2f, 3f, 4f), float.class, Float.class);
489         initTestArgs(values(1d, 2d, 3d, 4d), double.class, Double.class);
490     }
491 
492     interface Executor {
493         void execute(Object[] args) throws Throwable;
494     }
495 
496     private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable {
497         final int argn = argTypes.length;
498         Object[][] argValues = new Object[argn][];
499         for (int i = 0; i < argn; i++) {
500             argValues[i] = TEST_ARGS.get(argTypes[i]);
501         }
502         int[] argIndexes = new int[argn];
503         Object[] args = new Object[argn];
504         while (true) {
505             for (int i = 0; i < argn; i++) {
506                 args[i] = argValues[i][argIndexes[i]];
507             }
508             executor.execute(args);
509             int i = argn - 1;
510             while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--;
511             if (i < 0) return;
512             argIndexes[i++]++;
513             while (i < argn) argIndexes[i++] = 0;
514         }
515     }
516 
517     @Test(dataProvider = "testMethods")
518     public void testLift(TestData d) throws Throwable {
519         if (d.testMethod.getAnnotation(SkipLift.class) != null) {
520             throw new SkipException("skipped");
521         }
522         CoreOp.FuncOp flift;
523         try {
524             flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod));
525         } catch (Throwable e) {
526             System.out.println("Lift failed, expected:");
527             d.testMethod.getCodeModel().ifPresent(f -> f.writeTo(System.out));
528             throw e;
529         }
530         try {
531             permutateAllArgs(d.testMethod.getParameterTypes(), args ->
532                 Assert.assertEquals(invokeAndConvert(flift, args), d.testMethod.invoke(null, args)));
533         } catch (Throwable e) {
534             flift.writeTo(System.out);
535             throw e;
536         }
537     }
538 
539     private static Object invokeAndConvert(CoreOp.FuncOp func, Object[] args) {
540         Object ret = Interpreter.invoke(func, args);
541         if (ret instanceof Integer i) {
542             TypeElement rt = func.invokableType().returnType();
543             if (rt.equals(JavaType.BOOLEAN)) {
544                 return i != 0;
545             } else if (rt.equals(JavaType.BYTE)) {
546                 return i.byteValue();
547             } else if (rt.equals(JavaType.CHAR)) {
548                 return (short)i.intValue();
549             } else if (rt.equals(JavaType.SHORT)) {
550                 return i.shortValue();
551             }
552         }
553         return ret;
554     }
555 
556     @Test(dataProvider = "testMethods")
557     public void testGenerate(TestData d) throws Throwable {
558         CoreOp.FuncOp func = d.testMethod.getCodeModel().get();
559 
560         CoreOp.FuncOp lfunc = func.transform(CopyContext.create(), OpTransformer.LOWERING_TRANSFORMER);
561 
562         try {
563             MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), lfunc);
564             permutateAllArgs(d.testMethod.getParameterTypes(), args ->
565                     Assert.assertEquals(mh.invokeWithArguments(args), d.testMethod.invoke(null, args)));
566         } catch (Throwable e) {
567             func.writeTo(System.out);
568             lfunc.writeTo(System.out);
569             String methodName = d.testMethod().getName();
570             for (var mm : CLASS_MODEL.methods()) {
571                 if (mm.methodName().equalsString(methodName)
572                         || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) {
573                     ClassPrinter.toYaml(mm,
574                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
575                                         System.out::print);
576                 }
577             }
578             Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> {
579                 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try {
580                     ClassPrinter.toYaml(ClassFile.of().parse(p),
581                                         ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES,
582                                         System.out::print);
583                 } catch (IOException ignore) {}
584             });
585             throw e;
586         }
587     }
588 }