1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import java.io.IOException; 25 import java.lang.classfile.ClassFile; 26 import java.lang.classfile.ClassModel; 27 import java.lang.classfile.components.ClassPrinter; 28 import java.lang.constant.MethodTypeDesc; 29 import java.lang.invoke.MethodHandle; 30 import java.lang.invoke.MethodHandles; 31 import java.lang.reflect.AccessFlag; 32 import org.testng.Assert; 33 import org.testng.SkipException; 34 import org.testng.annotations.BeforeClass; 35 import org.testng.annotations.DataProvider; 36 import org.testng.annotations.Test; 37 38 import jdk.incubator.code.*; 39 import jdk.incubator.code.op.CoreOp; 40 import jdk.incubator.code.bytecode.BytecodeLift; 41 import jdk.incubator.code.interpreter.Interpreter; 42 import java.lang.reflect.Method; 43 import jdk.incubator.code.bytecode.BytecodeGenerator; 44 import jdk.incubator.code.type.JavaType; 45 import jdk.incubator.code.CodeReflection; 46 import java.nio.file.Files; 47 import java.nio.file.Path; 48 import java.util.*; 49 import java.util.function.Consumer; 50 import java.util.function.Function; 51 import java.util.stream.Collectors; 52 import java.util.stream.Stream; 53 54 /* 55 * @test 56 * @modules jdk.incubator.code 57 * @enablePreview 58 * @run testng/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode 59 */ 60 61 public class TestBytecode { 62 63 @CodeReflection 64 static int intNumOps(int i, int j, int k) { 65 k++; 66 i = (i + j) / k - i % j; 67 i--; 68 return i; 69 } 70 71 @CodeReflection 72 static byte byteNumOps(byte i, byte j, byte k) { 73 k++; 74 i = (byte) ((i + j) / k - i % j); 75 i--; 76 return i; 77 } 78 79 @CodeReflection 80 static short shortNumOps(short i, short j, short k) { 81 k++; 82 i = (short) ((i + j) / k - i % j); 83 i--; 84 return i; 85 } 86 87 @CodeReflection 88 static char charNumOps(char i, char j, char k) { 89 k++; 90 i = (char) ((i + j) / k - i % j); 91 i--; 92 return i; 93 } 94 95 @CodeReflection 96 static long longNumOps(long i, long j, long k) { 97 k++; 98 i = (i + j) / k - i % j; 99 i--; 100 return i; 101 } 102 103 @CodeReflection 104 static float floatNumOps(float i, float j, float k) { 105 k++; 106 i = (i + j) / k - i % j; 107 i--; 108 return i; 109 } 110 111 @CodeReflection 112 static double doubleNumOps(double i, double j, double k) { 113 k++; 114 i = (i + j) / k - i % j; 115 i--; 116 return i; 117 } 118 119 @CodeReflection 120 static int intBitOps(int i, int j, int k) { 121 return ~(i & j | k ^ j); 122 } 123 124 @CodeReflection 125 static byte byteBitOps(byte i, byte j, byte k) { 126 return (byte) ~(i & j | k ^ j); 127 } 128 129 @CodeReflection 130 static short shortBitOps(short i, short j, short k) { 131 return (short) ~(i & j | k ^ j); 132 } 133 134 @CodeReflection 135 static char charBitOps(char i, char j, char k) { 136 return (char) ~(i & j | k ^ j); 137 } 138 139 @CodeReflection 140 static long longBitOps(long i, long j, long k) { 141 return ~(i & j | k ^ j); 142 } 143 144 @CodeReflection 145 static boolean boolBitOps(boolean i, boolean j, boolean k) { 146 return i & j | k ^ j; 147 } 148 149 @CodeReflection 150 static int intShiftOps(int i, int j, int k) { 151 return ((-1 >> i) << (j << k)) >>> (k - j); 152 } 153 154 @CodeReflection 155 static byte byteShiftOps(byte i, byte j, byte k) { 156 return (byte) (((-1 >> i) << (j << k)) >>> (k - j)); 157 } 158 159 @CodeReflection 160 static short shortShiftOps(short i, short j, short k) { 161 return (short) (((-1 >> i) << (j << k)) >>> (k - j)); 162 } 163 164 @CodeReflection 165 static char charShiftOps(char i, char j, char k) { 166 return (char) (((-1 >> i) << (j << k)) >>> (k - j)); 167 } 168 169 @CodeReflection 170 static long longShiftOps(long i, long j, long k) { 171 return ((-1 >> i) << (j << k)) >>> (k - j); 172 } 173 174 @CodeReflection 175 static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) { 176 ii += i; ii += b; ii += s; ii += c; 177 i += ii; i += bb; i += ss; i += cc; 178 b += ii; b += bb; b += ss; b += cc; 179 s += ii; s += bb; s += ss; s += cc; 180 c += ii; c += bb; c += ss; c += cc; 181 return new Object[]{i, b, s, c}; 182 } 183 184 @CodeReflection 185 static String constructor(String s, int i, int j) { 186 return new String(s.getBytes(), i, j); 187 } 188 189 @CodeReflection 190 static Class<?> classArray(int i, int j) { 191 Class<?>[] ifaces = new Class[1 + i + j]; 192 ifaces[0] = Function.class; 193 return ifaces[0]; 194 } 195 196 @CodeReflection 197 static String[] stringArray(int i, int j) { 198 return new String[i]; 199 } 200 201 @CodeReflection 202 static String[][] stringArray2(int i, int j) { 203 return new String[i][]; 204 } 205 206 @CodeReflection 207 static String[][] stringArrayMulti(int i, int j) { 208 return new String[i][j]; 209 } 210 211 @CodeReflection 212 static int[][] initializedIntArray(int i, int j) { 213 return new int[][]{{i, j}, {i + j}}; 214 } 215 216 @CodeReflection 217 static int ifElseCompare(int i, int j) { 218 if (i < 3) { 219 i += 1; 220 } else { 221 j += 2; 222 } 223 return i + j; 224 } 225 226 @CodeReflection 227 static int ifElseEquality(int i, int j) { 228 if (j != 0) { 229 if (i != 0) { 230 i += 1; 231 } else { 232 i += 2; 233 } 234 } else { 235 if (j != 0) { 236 i += 3; 237 } else { 238 i += 4; 239 } 240 } 241 return i; 242 } 243 244 @CodeReflection 245 static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) { 246 Object a = b1; 247 Object b = b2; 248 Object c = b3; 249 return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4); 250 } 251 252 @CodeReflection 253 static int conditionalExpr(int i, int j) { 254 return ((i - 1 >= 0) ? i - 1 : j - 1); 255 } 256 257 @CodeReflection 258 static int nestedConditionalExpr(int i, int j) { 259 return (i < 2) ? (j < 3) ? i : j : i + j; 260 } 261 262 static final int[] MAP = {0, 1, 2, 3, 4}; 263 264 @CodeReflection 265 static int deepStackBranches(boolean a, boolean b) { 266 return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]]; 267 } 268 269 @CodeReflection 270 static int tryFinally(int i, int j) { 271 try { 272 i = i + j; 273 } finally { 274 i = i + j; 275 } 276 return i; 277 } 278 279 public record A(String s) {} 280 281 @CodeReflection 282 static A newWithArgs(int i, int j) { 283 return new A("hello world".substring(i, i + j)); 284 } 285 286 @CodeReflection 287 static int loop(int n, int j) { 288 int sum = 0; 289 for (int i = 0; i < n; i++) { 290 sum = sum + j; 291 } 292 return sum; 293 } 294 295 296 @CodeReflection 297 static int ifElseNested(int a, int b) { 298 int c = a + b; 299 int d = 10 - a + b; 300 if (b < 3) { 301 if (a < 3) { 302 a += 1; 303 } else { 304 b += 2; 305 } 306 c += 3; 307 } else { 308 if (a > 2) { 309 a += 4; 310 } else { 311 b += 5; 312 } 313 d += 6; 314 } 315 return a + b + c + d; 316 } 317 318 @CodeReflection 319 static int nestedLoop(int m, int n) { 320 int sum = 0; 321 for (int i = 0; i < m; i++) { 322 for (int j = 0; j < n; j++) { 323 sum = sum + i + j; 324 } 325 } 326 return sum; 327 } 328 329 @CodeReflection 330 static int methodCall(int a, int b) { 331 int i = Math.max(a, b); 332 return Math.negateExact(i); 333 } 334 335 @CodeReflection 336 static int[] primitiveArray(int i, int j) { 337 int[] ia = new int[i + 1]; 338 ia[0] = j; 339 return ia; 340 } 341 342 @CodeReflection 343 static boolean not(boolean b) { 344 return !b; 345 } 346 347 @CodeReflection 348 static boolean notCompare(int i, int j) { 349 boolean b = i < j; 350 return !b; 351 } 352 353 @CodeReflection 354 static int mod(int i, int j) { 355 return i % (j + 1); 356 } 357 358 @CodeReflection 359 static int xor(int i, int j) { 360 return i ^ j; 361 } 362 363 @CodeReflection 364 static int whileLoop(int i, int n) { int 365 counter = 0; 366 while (i < n && counter < 3) { 367 counter++; 368 if (counter == 4) { 369 break; 370 } 371 i++; 372 } 373 return counter; 374 } 375 376 public interface Func { 377 int apply(int a); 378 } 379 380 public interface QuotableFunc extends Quotable { 381 int apply(int a); 382 } 383 384 static int consume(int i, Func f) { 385 return f.apply(i + 1); 386 } 387 388 static int consumeQuotable(int i, QuotableFunc f) { 389 Assert.assertNotNull(f.quoted()); 390 Assert.assertNotNull(f.quoted().op()); 391 Assert.assertTrue(f.quoted().op() instanceof CoreOp.LambdaOp); 392 return f.apply(i + 1); 393 } 394 395 @CodeReflection 396 static int lambda(int i) { 397 return consume(i, a -> -a); 398 } 399 400 @CodeReflection 401 static int quotableLambda(int i) { 402 return consumeQuotable(i, a -> -a); 403 } 404 405 @CodeReflection 406 static int lambdaWithCapture(int i, String s) { 407 return consume(i, a -> a + s.length()); 408 } 409 410 @CodeReflection 411 static int quotableLambdaWithCapture(int i, String s) { 412 return consumeQuotable(i, a -> a + s.length()); 413 } 414 415 @CodeReflection 416 static int nestedLambdasWithCaptures(int i, int j, String s) { 417 return consume(i, a -> consume(a, b -> a + b + j - s.length()) + s.length()); 418 } 419 420 @CodeReflection 421 static int nestedQuotableLambdasWithCaptures(int i, int j, String s) { 422 return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length()); 423 } 424 425 @CodeReflection 426 static int methodHandle(int i) { 427 return consume(i, Math::negateExact); 428 } 429 430 int instanceMethod(int i) { 431 return -i + 13; 432 } 433 434 @CodeReflection 435 int instanceMethodHandle(int i) { 436 return consume(i, this::instanceMethod); 437 } 438 439 static void consume(boolean b, Consumer<Object> requireNonNull) { 440 if (b) { 441 requireNonNull.accept(new Object()); 442 } else try { 443 requireNonNull.accept(null); 444 throw new AssertionError("Expectend NPE"); 445 } catch (NullPointerException expected) { 446 } 447 } 448 449 @CodeReflection 450 static void nullReturningMethodHandle(boolean b) { 451 consume(b, Objects::requireNonNull); 452 } 453 454 @CodeReflection 455 static boolean compareLong(long i, long j) { 456 return i > j; 457 } 458 459 @CodeReflection 460 static boolean compareFloat(float i, float j) { 461 return i > j; 462 } 463 464 @CodeReflection 465 static boolean compareDouble(double i, double j) { 466 return i > j; 467 } 468 469 @CodeReflection 470 static int lookupSwitch(int i) { 471 return switch (1000 * i) { 472 case 1000 -> 1; 473 case 2000 -> 2; 474 case 3000 -> 3; 475 default -> 0; 476 }; 477 } 478 479 @CodeReflection 480 static int tableSwitch(int i) { 481 return switch (i) { 482 case 1 -> 1; 483 case 2 -> 2; 484 case 3 -> 3; 485 default -> 0; 486 }; 487 } 488 489 int instanceField = -1; 490 491 @CodeReflection 492 int instanceFieldAccess(int i) { 493 int ret = instanceField; 494 instanceField = i; 495 return ret; 496 } 497 498 @CodeReflection 499 static String stringConcat(String a, String b) { 500 return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd"; 501 } 502 503 @CodeReflection 504 static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) { 505 return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d; 506 } 507 508 @CodeReflection 509 static int ifTrue(int i) { 510 if (true) { 511 return i; 512 } 513 return -i; 514 } 515 516 @CodeReflection 517 static int excHandlerFollowingSplitTable(boolean b) { 518 try { 519 if (b) return 1; 520 else throw new Exception(); 521 } catch (Exception ex) {} 522 return 2; 523 } 524 525 @CodeReflection 526 static int varModifiedInTryBlock(boolean b) { 527 int i = 0; 528 try { 529 i++; 530 if (b) throw new Exception(); 531 i++; 532 throw new Exception(); 533 } catch (Exception ex) { 534 return i; 535 } 536 } 537 538 @CodeReflection 539 static boolean finallyWithLoop(boolean b) { 540 try { 541 while (b) { 542 if (b) 543 return false; 544 b = !b; 545 } 546 return true; 547 } finally { 548 b = false; 549 } 550 } 551 552 @CodeReflection 553 static long doubleUseOfOperand(int x) { 554 long piece = x; 555 return piece * piece; 556 } 557 558 record TestData(Method testMethod) { 559 @Override 560 public String toString() { 561 String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes()) 562 .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")")); 563 if (s.length() > 30) s = s.substring(0, 27) + "..."; 564 return s; 565 } 566 } 567 568 @DataProvider(name = "testMethods") 569 public static TestData[]testMethods() { 570 return Stream.of(TestBytecode.class.getDeclaredMethods()) 571 .filter(m -> m.isAnnotationPresent(CodeReflection.class)) 572 .map(TestData::new).toArray(TestData[]::new); 573 } 574 575 private static byte[] CLASS_DATA; 576 private static ClassModel CLASS_MODEL; 577 578 @BeforeClass 579 public static void setup() throws Exception { 580 CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes(); 581 CLASS_MODEL = ClassFile.of().parse(CLASS_DATA); 582 } 583 584 private static MethodTypeDesc toMethodTypeDesc(Method m) { 585 return MethodTypeDesc.of( 586 m.getReturnType().describeConstable().orElseThrow(), 587 Arrays.stream(m.getParameterTypes()) 588 .map(cls -> cls.describeConstable().orElseThrow()).toList()); 589 } 590 591 592 private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>(); 593 private static Object[] values(Object... values) { 594 return values; 595 } 596 private static void initTestArgs(Object[] values, Class<?>... argTypes) { 597 for (var argType : argTypes) TEST_ARGS.put(argType, values); 598 } 599 static { 600 initTestArgs(values(1, 2, 4), int.class, Integer.class); 601 initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class); 602 initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class); 603 initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class); 604 initTestArgs(values(false, true), boolean.class, Boolean.class); 605 initTestArgs(values("Hello World"), String.class); 606 initTestArgs(values(1l, 2l, 4l), long.class, Long.class); 607 initTestArgs(values(1f, 3f, 4f), float.class, Float.class); 608 initTestArgs(values(1d, 2d, 3d), double.class, Double.class); 609 } 610 611 interface Executor { 612 void execute(Object[] args) throws Throwable; 613 } 614 615 private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable { 616 final int argn = argTypes.length; 617 Object[][] argValues = new Object[argn][]; 618 for (int i = 0; i < argn; i++) { 619 argValues[i] = TEST_ARGS.get(argTypes[i]); 620 } 621 int[] argIndexes = new int[argn]; 622 Object[] args = new Object[argn]; 623 while (true) { 624 for (int i = 0; i < argn; i++) { 625 args[i] = argValues[i][argIndexes[i]]; 626 } 627 executor.execute(args); 628 int i = argn - 1; 629 while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--; 630 if (i < 0) return; 631 argIndexes[i++]++; 632 while (i < argn) argIndexes[i++] = 0; 633 } 634 } 635 636 @Test(dataProvider = "testMethods") 637 public void testLift(TestData d) throws Throwable { 638 CoreOp.FuncOp flift; 639 try { 640 flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod)); 641 } catch (Throwable e) { 642 ClassPrinter.toYaml(ClassFile.of().parse(TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes()) 643 .methods().stream().filter(m -> m.methodName().equalsString(d.testMethod().getName())).findAny().get(), 644 ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print); 645 System.out.println("Lift failed, compiled model:"); 646 Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out)); 647 throw e; 648 } 649 try { 650 Object receiver1, receiver2; 651 if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) { 652 receiver1 = null; 653 receiver2 = null; 654 } else { 655 receiver1 = new TestBytecode(); 656 receiver2 = new TestBytecode(); 657 } 658 permutateAllArgs(d.testMethod.getParameterTypes(), args -> 659 Assert.assertEquals(invokeAndConvert(flift, receiver1, args), d.testMethod.invoke(receiver2, args))); 660 } catch (Throwable e) { 661 System.out.println("Compiled model:"); 662 Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out)); 663 System.out.println("Lifted model:"); 664 flift.writeTo(System.out); 665 throw e; 666 } 667 } 668 669 private static Object invokeAndConvert(CoreOp.FuncOp func, Object receiver, Object... args) { 670 List argl = new ArrayList(args.length + 1); 671 if (receiver != null) argl.add(receiver); 672 argl.addAll(Arrays.asList(args)); 673 Object ret = Interpreter.invoke(MethodHandles.lookup(), func, argl); 674 if (ret instanceof Integer i) { 675 TypeElement rt = func.invokableType().returnType(); 676 if (rt.equals(JavaType.BOOLEAN)) { 677 return i != 0; 678 } else if (rt.equals(JavaType.BYTE)) { 679 return i.byteValue(); 680 } else if (rt.equals(JavaType.CHAR)) { 681 return (short)i.intValue(); 682 } else if (rt.equals(JavaType.SHORT)) { 683 return i.shortValue(); 684 } 685 } 686 return ret; 687 } 688 689 @Test(dataProvider = "testMethods") 690 public void testGenerate(TestData d) throws Throwable { 691 CoreOp.FuncOp func = Op.ofMethod(d.testMethod).get(); 692 693 CoreOp.FuncOp lfunc; 694 try { 695 lfunc = func.transform(CopyContext.create(), OpTransformer.LOWERING_TRANSFORMER); 696 } catch (UnsupportedOperationException uoe) { 697 throw new SkipException("lowering caused:", uoe); 698 } 699 700 try { 701 MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), lfunc); 702 Object receiver1, receiver2; 703 if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) { 704 receiver1 = null; 705 receiver2 = null; 706 } else { 707 receiver1 = new TestBytecode(); 708 receiver2 = new TestBytecode(); 709 } 710 permutateAllArgs(d.testMethod.getParameterTypes(), args -> { 711 List argl = new ArrayList(args.length + 1); 712 if (receiver1 != null) argl.add(receiver1); 713 argl.addAll(Arrays.asList(args)); 714 Assert.assertEquals(mh.invokeWithArguments(argl), d.testMethod.invoke(receiver2, args)); 715 }); 716 } catch (Throwable e) { 717 func.writeTo(System.out); 718 lfunc.writeTo(System.out); 719 String methodName = d.testMethod().getName(); 720 for (var mm : CLASS_MODEL.methods()) { 721 if (mm.methodName().equalsString(methodName) 722 || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) { 723 ClassPrinter.toYaml(mm, 724 ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, 725 System.out::print); 726 } 727 } 728 Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> { 729 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try { 730 ClassPrinter.toYaml(ClassFile.of().parse(p), 731 ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, 732 System.out::print); 733 } catch (IOException ignore) {} 734 }); 735 throw e; 736 } 737 } 738 }