1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import java.io.IOException; 25 import java.lang.annotation.Retention; 26 import java.lang.annotation.RetentionPolicy; 27 import java.lang.classfile.ClassFile; 28 import java.lang.classfile.ClassModel; 29 import java.lang.classfile.components.ClassPrinter; 30 import java.lang.constant.MethodTypeDesc; 31 import java.lang.invoke.MethodHandle; 32 import java.lang.invoke.MethodHandles; 33 import org.testng.Assert; 34 import org.testng.SkipException; 35 import org.testng.annotations.BeforeClass; 36 import org.testng.annotations.DataProvider; 37 import org.testng.annotations.Test; 38 39 import java.lang.reflect.code.*; 40 import java.lang.reflect.code.op.CoreOp; 41 import java.lang.reflect.code.bytecode.BytecodeLift; 42 import java.lang.reflect.code.interpreter.Interpreter; 43 import java.lang.reflect.Method; 44 import java.lang.reflect.code.bytecode.BytecodeGenerator; 45 import java.lang.reflect.code.type.JavaType; 46 import java.lang.runtime.CodeReflection; 47 import java.nio.file.Files; 48 import java.nio.file.Path; 49 import java.util.Arrays; 50 import java.util.IdentityHashMap; 51 import java.util.Map; 52 import java.util.function.Function; 53 import java.util.stream.Collectors; 54 import java.util.stream.Stream; 55 56 /* 57 * @test 58 * @enablePreview 59 * @run testng/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode 60 */ 61 62 public class TestBytecode { 63 64 @CodeReflection 65 static int intNumOps(int i, int j, int k) { 66 k++; 67 i = (i + j) / k - i % j; 68 i--; 69 return i; 70 } 71 72 @CodeReflection 73 @SkipLift 74 static byte byteNumOps(byte i, byte j, byte k) { 75 k++; 76 i = (byte) ((i + j) / k - i % j); 77 i--; 78 return i; 79 } 80 81 @CodeReflection 82 @SkipLift 83 static short shortNumOps(short i, short j, short k) { 84 k++; 85 i = (short) ((i + j) / k - i % j); 86 i--; 87 return i; 88 } 89 90 @CodeReflection 91 @SkipLift 92 static char charNumOps(char i, char j, char k) { 93 k++; 94 i = (char) ((i + j) / k - i % j); 95 i--; 96 return i; 97 } 98 99 @CodeReflection 100 static long longNumOps(long i, long j, long k) { 101 k++; 102 i = (i + j) / k - i % j; 103 i--; 104 return i; 105 } 106 107 @CodeReflection 108 static float floatNumOps(float i, float j, float k) { 109 k++; 110 i = (i + j) / k - i % j; 111 i--; 112 return i; 113 } 114 115 @CodeReflection 116 static double doubleNumOps(double i, double j, double k) { 117 k++; 118 i = (i + j) / k - i % j; 119 i--; 120 return i; 121 } 122 123 @CodeReflection 124 static int intBitOps(int i, int j, int k) { 125 return i & j | k ^ j; 126 } 127 128 @CodeReflection 129 @SkipLift 130 static byte byteBitOps(byte i, byte j, byte k) { 131 return (byte) (i & j | k ^ j); 132 } 133 134 @CodeReflection 135 @SkipLift 136 static short shortBitOps(short i, short j, short k) { 137 return (short) (i & j | k ^ j); 138 } 139 140 @CodeReflection 141 @SkipLift 142 static char charBitOps(char i, char j, char k) { 143 return (char) (i & j | k ^ j); 144 } 145 146 @CodeReflection 147 static long longBitOps(long i, long j, long k) { 148 return i & j | k ^ j; 149 } 150 151 @CodeReflection 152 static boolean boolBitOps(boolean i, boolean j, boolean k) { 153 return i & j | k ^ j; 154 } 155 156 @CodeReflection 157 @SkipLift 158 static int intShiftOps(int i, int j, int k) { 159 return ((-1 >> i) << (j << k)) >>> (k - j); 160 } 161 162 @CodeReflection 163 @SkipLift 164 static byte byteShiftOps(byte i, byte j, byte k) { 165 return (byte) (((-1 >> i) << (j << k)) >>> (k - j)); 166 } 167 168 @CodeReflection 169 @SkipLift 170 static short shortShiftOps(short i, short j, short k) { 171 return (short) (((-1 >> i) << (j << k)) >>> (k - j)); 172 } 173 174 @CodeReflection 175 @SkipLift 176 static char charShiftOps(char i, char j, char k) { 177 return (char) (((-1 >> i) << (j << k)) >>> (k - j)); 178 } 179 180 @CodeReflection 181 @SkipLift 182 static long longShiftOps(long i, long j, long k) { 183 return ((-1 >> i) << (j << k)) >>> (k - j); 184 } 185 186 @CodeReflection 187 @SkipLift 188 static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) { 189 ii += i; ii += b; ii += s; ii += c; 190 i += ii; i += bb; i += ss; i += cc; 191 b += ii; b += bb; b += ss; b += cc; 192 s += ii; s += bb; s += ss; s += cc; 193 c += ii; c += bb; c += ss; c += cc; 194 return new Object[]{i, b, s, c}; 195 } 196 197 @CodeReflection 198 static String constructor(String s, int i, int j) { 199 return new String(s.getBytes(), i, j); 200 } 201 202 @CodeReflection 203 static Class<?> classArray(int i, int j) { 204 Class<?>[] ifaces = new Class[1 + i + j]; 205 ifaces[0] = Function.class; 206 return ifaces[0]; 207 } 208 209 @CodeReflection 210 static String[] stringArray(int i, int j) { 211 return new String[i]; 212 } 213 214 @CodeReflection 215 static String[][] stringArray2(int i, int j) { 216 return new String[i][]; 217 } 218 219 @CodeReflection 220 static String[][] stringArrayMulti(int i, int j) { 221 return new String[i][j]; 222 } 223 224 @CodeReflection 225 static int[][] initializedIntArray(int i, int j) { 226 return new int[][]{{i, j}, {i + j}}; 227 } 228 229 @CodeReflection 230 static int ifElseCompare(int i, int j) { 231 if (i < 3) { 232 i += 1; 233 } else { 234 j += 2; 235 } 236 return i + j; 237 } 238 239 @CodeReflection 240 static int ifElseEquality(int i, int j) { 241 if (j != 0) { 242 if (i != 0) { 243 i += 1; 244 } else { 245 i += 2; 246 } 247 } else { 248 if (j != 0) { 249 i += 3; 250 } else { 251 i += 4; 252 } 253 } 254 return i; 255 } 256 257 @CodeReflection 258 static int conditionalExpr(int i, int j) { 259 return ((i - 1 >= 0) ? i - 1 : j - 1); 260 } 261 262 @CodeReflection 263 static int nestedConditionalExpr(int i, int j) { 264 return (i < 2) ? (j < 3) ? i : j : i + j; 265 } 266 267 @CodeReflection 268 static int tryFinally(int i, int j) { 269 try { 270 i = i + j; 271 } finally { 272 i = i + j; 273 } 274 return i; 275 } 276 277 public record A(String s) {} 278 279 @CodeReflection 280 static A newWithArgs(int i, int j) { 281 return new A("hello world".substring(i, i + j)); 282 } 283 284 @CodeReflection 285 static int loop(int n, int j) { 286 int sum = 0; 287 for (int i = 0; i < n; i++) { 288 sum = sum + j; 289 } 290 return sum; 291 } 292 293 294 @CodeReflection 295 static int ifElseNested(int a, int b) { 296 int c = a + b; 297 int d = 10 - a + b; 298 if (b < 3) { 299 if (a < 3) { 300 a += 1; 301 } else { 302 b += 2; 303 } 304 c += 3; 305 } else { 306 if (a > 2) { 307 a += 4; 308 } else { 309 b += 5; 310 } 311 d += 6; 312 } 313 return a + b + c + d; 314 } 315 316 @CodeReflection 317 static int nestedLoop(int m, int n) { 318 int sum = 0; 319 for (int i = 0; i < m; i++) { 320 for (int j = 0; j < n; j++) { 321 sum = sum + i + j; 322 } 323 } 324 return sum; 325 } 326 327 @CodeReflection 328 static int methodCall(int a, int b) { 329 int i = Math.max(a, b); 330 return Math.negateExact(i); 331 } 332 333 @CodeReflection 334 static int[] primitiveArray(int i, int j) { 335 int[] ia = new int[i + 1]; 336 ia[0] = j; 337 return ia; 338 } 339 340 @CodeReflection 341 @SkipLift 342 static boolean not(boolean b) { 343 return !b; 344 } 345 346 @CodeReflection 347 static boolean notCompare(int i, int j) { 348 boolean b = i < j; 349 return !b; 350 } 351 352 @CodeReflection 353 static int mod(int i, int j) { 354 return i % (j + 1); 355 } 356 357 @CodeReflection 358 static int xor(int i, int j) { 359 return i ^ j; 360 } 361 362 @CodeReflection 363 static int whileLoop(int i, int n) { int 364 counter = 0; 365 while (i < n && counter < 3) { 366 counter++; 367 if (counter == 4) { 368 break; 369 } 370 i++; 371 } 372 return counter; 373 } 374 375 public interface Func { 376 int apply(int a); 377 } 378 379 public interface QuotableFunc extends Quotable { 380 int apply(int a); 381 } 382 383 static int consume(int i, Func f) { 384 return f.apply(i + 1); 385 } 386 387 static int consumeQuotable(int i, QuotableFunc f) { 388 Assert.assertNotNull(f.quoted()); 389 Assert.assertNotNull(f.quoted().op()); 390 Assert.assertTrue(f.quoted().op() instanceof CoreOp.LambdaOp); 391 return f.apply(i + 1); 392 } 393 394 @CodeReflection 395 @SkipLift 396 static int lambda(int i) { 397 return consume(i, a -> -a); 398 } 399 400 @CodeReflection 401 @SkipLift 402 static int quotableLambda(int i) { 403 return consumeQuotable(i, a -> -a); 404 } 405 406 @CodeReflection 407 @SkipLift 408 static int lambdaWithCapture(int i, String s) { 409 return consume(i, a -> a + s.length()); 410 } 411 412 @CodeReflection 413 @SkipLift 414 static int quotableLambdaWithCapture(int i, String s) { 415 return consumeQuotable(i, a -> a + s.length()); 416 } 417 418 @CodeReflection 419 @SkipLift 420 static int nestedLambdasWithCaptures(int i, int j, String s) { 421 return consume(i, a -> consume(a, b -> a + b + j) + s.length()); 422 } 423 424 @CodeReflection 425 @SkipLift 426 static int nestedQuotableLambdasWithCaptures(int i, int j, String s) { 427 return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j) + s.length()); 428 } 429 430 @CodeReflection 431 @SkipLift 432 static int methodHandle(int i) { 433 return consume(i, Math::negateExact); 434 } 435 436 @Retention(RetentionPolicy.RUNTIME) 437 @interface SkipLift {} 438 439 record TestData(Method testMethod) { 440 @Override 441 public String toString() { 442 String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes()) 443 .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")")); 444 if (s.length() > 30) s = s.substring(0, 27) + "..."; 445 return s; 446 } 447 } 448 449 @DataProvider(name = "testMethods") 450 public static TestData[]testMethods() { 451 return Stream.of(TestBytecode.class.getDeclaredMethods()) 452 .filter(m -> m.isAnnotationPresent(CodeReflection.class)) 453 .map(TestData::new).toArray(TestData[]::new); 454 } 455 456 private static byte[] CLASS_DATA; 457 private static ClassModel CLASS_MODEL; 458 459 @BeforeClass 460 public static void setup() throws Exception { 461 CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes(); 462 CLASS_MODEL = ClassFile.of().parse(CLASS_DATA); 463 } 464 465 private static MethodTypeDesc toMethodTypeDesc(Method m) { 466 return MethodTypeDesc.of( 467 m.getReturnType().describeConstable().orElseThrow(), 468 Arrays.stream(m.getParameterTypes()) 469 .map(cls -> cls.describeConstable().orElseThrow()).toList()); 470 } 471 472 473 private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>(); 474 private static Object[] values(Object... values) { 475 return values; 476 } 477 private static void initTestArgs(Object[] values, Class<?>... argTypes) { 478 for (var argType : argTypes) TEST_ARGS.put(argType, values); 479 } 480 static { 481 initTestArgs(values(1, 2, 3, 4), int.class, Integer.class); 482 initTestArgs(values((byte)1, (byte)2, (byte)3, (byte)4), byte.class, Byte.class); 483 initTestArgs(values((short)1, (short)2, (short)3, (short)4), short.class, Short.class); 484 initTestArgs(values((char)1, (char)2, (char)3, (char)4), char.class, Character.class); 485 initTestArgs(values(false, true), boolean.class, Boolean.class); 486 initTestArgs(values("Hello World"), String.class); 487 initTestArgs(values(1l, 2l, 3l, 4l), long.class, Long.class); 488 initTestArgs(values(1f, 2f, 3f, 4f), float.class, Float.class); 489 initTestArgs(values(1d, 2d, 3d, 4d), double.class, Double.class); 490 } 491 492 interface Executor { 493 void execute(Object[] args) throws Throwable; 494 } 495 496 private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable { 497 final int argn = argTypes.length; 498 Object[][] argValues = new Object[argn][]; 499 for (int i = 0; i < argn; i++) { 500 argValues[i] = TEST_ARGS.get(argTypes[i]); 501 } 502 int[] argIndexes = new int[argn]; 503 Object[] args = new Object[argn]; 504 while (true) { 505 for (int i = 0; i < argn; i++) { 506 args[i] = argValues[i][argIndexes[i]]; 507 } 508 executor.execute(args); 509 int i = argn - 1; 510 while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--; 511 if (i < 0) return; 512 argIndexes[i++]++; 513 while (i < argn) argIndexes[i++] = 0; 514 } 515 } 516 517 @Test(dataProvider = "testMethods") 518 public void testLift(TestData d) throws Throwable { 519 if (d.testMethod.getAnnotation(SkipLift.class) != null) { 520 throw new SkipException("skipped"); 521 } 522 CoreOp.FuncOp flift; 523 try { 524 flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod)); 525 } catch (Throwable e) { 526 System.out.println("Lift failed, expected:"); 527 d.testMethod.getCodeModel().ifPresent(f -> f.writeTo(System.out)); 528 throw e; 529 } 530 try { 531 permutateAllArgs(d.testMethod.getParameterTypes(), args -> 532 Assert.assertEquals(invokeAndConvert(flift, args), d.testMethod.invoke(null, args))); 533 } catch (Throwable e) { 534 flift.writeTo(System.out); 535 throw e; 536 } 537 } 538 539 private static Object invokeAndConvert(CoreOp.FuncOp func, Object[] args) { 540 Object ret = Interpreter.invoke(func, args); 541 if (ret instanceof Integer i) { 542 TypeElement rt = func.invokableType().returnType(); 543 if (rt.equals(JavaType.BOOLEAN)) { 544 return i != 0; 545 } else if (rt.equals(JavaType.BYTE)) { 546 return i.byteValue(); 547 } else if (rt.equals(JavaType.CHAR)) { 548 return (short)i.intValue(); 549 } else if (rt.equals(JavaType.SHORT)) { 550 return i.shortValue(); 551 } 552 } 553 return ret; 554 } 555 556 @Test(dataProvider = "testMethods") 557 public void testGenerate(TestData d) throws Throwable { 558 CoreOp.FuncOp func = d.testMethod.getCodeModel().get(); 559 560 CoreOp.FuncOp lfunc = func.transform(CopyContext.create(), OpTransformer.LOWERING_TRANSFORMER); 561 562 try { 563 MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), lfunc); 564 permutateAllArgs(d.testMethod.getParameterTypes(), args -> 565 Assert.assertEquals(mh.invokeWithArguments(args), d.testMethod.invoke(null, args))); 566 } catch (Throwable e) { 567 func.writeTo(System.out); 568 lfunc.writeTo(System.out); 569 String methodName = d.testMethod().getName(); 570 for (var mm : CLASS_MODEL.methods()) { 571 if (mm.methodName().equalsString(methodName) 572 || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) { 573 ClassPrinter.toYaml(mm, 574 ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, 575 System.out::print); 576 } 577 } 578 Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> { 579 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try { 580 ClassPrinter.toYaml(ClassFile.of().parse(p), 581 ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, 582 System.out::print); 583 } catch (IOException ignore) {} 584 }); 585 throw e; 586 } 587 } 588 }