1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import java.io.IOException; 25 import java.lang.classfile.ClassFile; 26 import java.lang.classfile.ClassModel; 27 28 import jdk.incubator.code.dialect.java.JavaOp; 29 import jdk.internal.classfile.components.ClassPrinter; 30 import java.lang.constant.MethodTypeDesc; 31 import java.lang.invoke.MethodHandle; 32 import java.lang.invoke.MethodHandles; 33 import java.lang.reflect.AccessFlag; 34 import org.testng.Assert; 35 import org.testng.SkipException; 36 import org.testng.annotations.BeforeClass; 37 import org.testng.annotations.DataProvider; 38 import org.testng.annotations.Test; 39 40 import jdk.incubator.code.*; 41 import jdk.incubator.code.dialect.core.CoreOp; 42 import jdk.incubator.code.bytecode.BytecodeLift; 43 import jdk.incubator.code.interpreter.Interpreter; 44 import java.lang.reflect.Method; 45 import jdk.incubator.code.bytecode.BytecodeGenerator; 46 import jdk.incubator.code.dialect.java.JavaType; 47 import jdk.incubator.code.CodeReflection; 48 import java.nio.file.Files; 49 import java.nio.file.Path; 50 import java.util.*; 51 import java.util.function.Consumer; 52 import java.util.function.Function; 53 import java.util.stream.Collectors; 54 import java.util.stream.Stream; 55 56 /* 57 * @test 58 * @modules jdk.incubator.code 59 * @modules java.base/jdk.internal.classfile.components 60 * @enablePreview 61 * @run testng/othervm -Djdk.invoke.MethodHandle.dumpClassFiles=true TestBytecode 62 */ 63 64 public class TestBytecode { 65 66 @CodeReflection 67 static int intNumOps(int i, int j, int k) { 68 k++; 69 i = (i + j) / k - i % j; 70 i--; 71 return i; 72 } 73 74 @CodeReflection 75 static byte byteNumOps(byte i, byte j, byte k) { 76 k++; 77 i = (byte) ((i + j) / k - i % j); 78 i--; 79 return i; 80 } 81 82 @CodeReflection 83 static short shortNumOps(short i, short j, short k) { 84 k++; 85 i = (short) ((i + j) / k - i % j); 86 i--; 87 return i; 88 } 89 90 @CodeReflection 91 static char charNumOps(char i, char j, char k) { 92 k++; 93 i = (char) ((i + j) / k - i % j); 94 i--; 95 return i; 96 } 97 98 @CodeReflection 99 static long longNumOps(long i, long j, long k) { 100 k++; 101 i = (i + j) / k - i % j; 102 i--; 103 return i; 104 } 105 106 @CodeReflection 107 static float floatNumOps(float i, float j, float k) { 108 k++; 109 i = (i + j) / k - i % j; 110 i--; 111 return i; 112 } 113 114 @CodeReflection 115 static double doubleNumOps(double i, double j, double k) { 116 k++; 117 i = (i + j) / k - i % j; 118 i--; 119 return i; 120 } 121 122 @CodeReflection 123 static int intBitOps(int i, int j, int k) { 124 return ~(i & j | k ^ j); 125 } 126 127 @CodeReflection 128 static byte byteBitOps(byte i, byte j, byte k) { 129 return (byte) ~(i & j | k ^ j); 130 } 131 132 @CodeReflection 133 static short shortBitOps(short i, short j, short k) { 134 return (short) ~(i & j | k ^ j); 135 } 136 137 @CodeReflection 138 static char charBitOps(char i, char j, char k) { 139 return (char) ~(i & j | k ^ j); 140 } 141 142 @CodeReflection 143 static long longBitOps(long i, long j, long k) { 144 return ~(i & j | k ^ j); 145 } 146 147 @CodeReflection 148 static boolean boolBitOps(boolean i, boolean j, boolean k) { 149 return i & j | k ^ j; 150 } 151 152 @CodeReflection 153 static int intShiftOps(int i, int j, int k) { 154 return ((-1 >> i) << (j << k)) >>> (k - j); 155 } 156 157 @CodeReflection 158 static byte byteShiftOps(byte i, byte j, byte k) { 159 return (byte) (((-1 >> i) << (j << k)) >>> (k - j)); 160 } 161 162 @CodeReflection 163 static short shortShiftOps(short i, short j, short k) { 164 return (short) (((-1 >> i) << (j << k)) >>> (k - j)); 165 } 166 167 @CodeReflection 168 static char charShiftOps(char i, char j, char k) { 169 return (char) (((-1 >> i) << (j << k)) >>> (k - j)); 170 } 171 172 @CodeReflection 173 static long longShiftOps(long i, long j, long k) { 174 return ((-1 >> i) << (j << k)) >>> (k - j); 175 } 176 177 @CodeReflection 178 static Object[] boxingAndUnboxing(int i, byte b, short s, char c, Integer ii, Byte bb, Short ss, Character cc) { 179 ii += i; ii += b; ii += s; ii += c; 180 i += ii; i += bb; i += ss; i += cc; 181 b += ii; b += bb; b += ss; b += cc; 182 s += ii; s += bb; s += ss; s += cc; 183 c += ii; c += bb; c += ss; c += cc; 184 return new Object[]{i, b, s, c}; 185 } 186 187 @CodeReflection 188 static String constructor(String s, int i, int j) { 189 return new String(s.getBytes(), i, j); 190 } 191 192 @CodeReflection 193 static Class<?> classArray(int i, int j) { 194 Class<?>[] ifaces = new Class[1 + i + j]; 195 ifaces[0] = Function.class; 196 return ifaces[0]; 197 } 198 199 @CodeReflection 200 static String[] stringArray(int i, int j) { 201 return new String[i]; 202 } 203 204 @CodeReflection 205 static String[][] stringArray2(int i, int j) { 206 return new String[i][]; 207 } 208 209 @CodeReflection 210 static String[][] stringArrayMulti(int i, int j) { 211 return new String[i][j]; 212 } 213 214 @CodeReflection 215 static int[][] initializedIntArray(int i, int j) { 216 return new int[][]{{i, j}, {i + j}}; 217 } 218 219 @CodeReflection 220 static int ifElseCompare(int i, int j) { 221 if (i < 3) { 222 i += 1; 223 } else { 224 j += 2; 225 } 226 return i + j; 227 } 228 229 @CodeReflection 230 static int ifElseEquality(int i, int j) { 231 if (j != 0) { 232 if (i != 0) { 233 i += 1; 234 } else { 235 i += 2; 236 } 237 } else { 238 if (j != 0) { 239 i += 3; 240 } else { 241 i += 4; 242 } 243 } 244 return i; 245 } 246 247 @CodeReflection 248 static int objectsCompare(Boolean b1, Boolean b2, Boolean b3) { 249 Object a = b1; 250 Object b = b2; 251 Object c = b3; 252 return a == b ? (a != c ? 1 : 2) : (b != c ? 3 : 4); 253 } 254 255 @CodeReflection 256 static int conditionalExpr(int i, int j) { 257 return ((i - 1 >= 0) ? i - 1 : j - 1); 258 } 259 260 @CodeReflection 261 static int nestedConditionalExpr(int i, int j) { 262 return (i < 2) ? (j < 3) ? i : j : i + j; 263 } 264 265 static final int[] MAP = {0, 1, 2, 3, 4}; 266 267 @CodeReflection 268 static int deepStackBranches(boolean a, boolean b) { 269 return MAP[a ? MAP[b ? 1 : 2] : MAP[b ? 3 : 4]]; 270 } 271 272 @CodeReflection 273 static int tryFinally(int i, int j) { 274 try { 275 i = i + j; 276 } finally { 277 i = i + j; 278 } 279 return i; 280 } 281 282 public record A(String s) {} 283 284 @CodeReflection 285 static A newWithArgs(int i, int j) { 286 return new A("hello world".substring(i, i + j)); 287 } 288 289 @CodeReflection 290 static int loop(int n, int j) { 291 int sum = 0; 292 for (int i = 0; i < n; i++) { 293 sum = sum + j; 294 } 295 return sum; 296 } 297 298 299 @CodeReflection 300 static int ifElseNested(int a, int b) { 301 int c = a + b; 302 int d = 10 - a + b; 303 if (b < 3) { 304 if (a < 3) { 305 a += 1; 306 } else { 307 b += 2; 308 } 309 c += 3; 310 } else { 311 if (a > 2) { 312 a += 4; 313 } else { 314 b += 5; 315 } 316 d += 6; 317 } 318 return a + b + c + d; 319 } 320 321 @CodeReflection 322 static int nestedLoop(int m, int n) { 323 int sum = 0; 324 for (int i = 0; i < m; i++) { 325 for (int j = 0; j < n; j++) { 326 sum = sum + i + j; 327 } 328 } 329 return sum; 330 } 331 332 @CodeReflection 333 static int methodCall(int a, int b) { 334 int i = Math.max(a, b); 335 return Math.negateExact(i); 336 } 337 338 @CodeReflection 339 static int[] primitiveArray(int i, int j) { 340 int[] ia = new int[i + 1]; 341 ia[0] = j; 342 return ia; 343 } 344 345 @CodeReflection 346 static boolean not(boolean b) { 347 return !b; 348 } 349 350 @CodeReflection 351 static boolean notCompare(int i, int j) { 352 boolean b = i < j; 353 return !b; 354 } 355 356 @CodeReflection 357 static int mod(int i, int j) { 358 return i % (j + 1); 359 } 360 361 @CodeReflection 362 static int xor(int i, int j) { 363 return i ^ j; 364 } 365 366 @CodeReflection 367 static int whileLoop(int i, int n) { int 368 counter = 0; 369 while (i < n && counter < 3) { 370 counter++; 371 if (counter == 4) { 372 break; 373 } 374 i++; 375 } 376 return counter; 377 } 378 379 public interface Func { 380 int apply(int a); 381 } 382 383 public interface QuotableFunc extends Quotable { 384 int apply(int a); 385 } 386 387 static int consume(int i, Func f) { 388 return f.apply(i + 1); 389 } 390 391 static int consumeQuotable(int i, QuotableFunc f) { 392 Assert.assertNotNull(Op.ofQuotable(f).get()); 393 Assert.assertNotNull(Op.ofQuotable(f).get().op()); 394 Assert.assertTrue(Op.ofQuotable(f).get().op() instanceof JavaOp.LambdaOp); 395 return f.apply(i + 1); 396 } 397 398 @CodeReflection 399 static int lambda(int i) { 400 return consume(i, a -> -a); 401 } 402 403 @CodeReflection 404 static int quotableLambda(int i) { 405 return consumeQuotable(i, a -> -a); 406 } 407 408 @CodeReflection 409 static int lambdaWithCapture(int i, String s) { 410 return consume(i, a -> a + s.length()); 411 } 412 413 @CodeReflection 414 static int quotableLambdaWithCapture(int i, String s) { 415 return consumeQuotable(i, a -> a + s.length()); 416 } 417 418 @CodeReflection 419 static int nestedLambdasWithCaptures(int i, int j, String s) { 420 return consume(i, a -> consume(a, b -> a + b + j - s.length()) + s.length()); 421 } 422 423 @CodeReflection 424 static int nestedQuotableLambdasWithCaptures(int i, int j, String s) { 425 return consumeQuotable(i, a -> consumeQuotable(a, b -> a + b + j - s.length()) + s.length()); 426 } 427 428 @CodeReflection 429 static int methodHandle(int i) { 430 return consume(i, Math::negateExact); 431 } 432 433 int instanceMethod(int i) { 434 return -i + 13; 435 } 436 437 @CodeReflection 438 int instanceMethodHandle(int i) { 439 return consume(i, this::instanceMethod); 440 } 441 442 static void consume(boolean b, Consumer<Object> requireNonNull) { 443 if (b) { 444 requireNonNull.accept(new Object()); 445 } else try { 446 requireNonNull.accept(null); 447 throw new AssertionError("Expectend NPE"); 448 } catch (NullPointerException expected) { 449 } 450 } 451 452 @CodeReflection 453 static void nullReturningMethodHandle(boolean b) { 454 consume(b, Objects::requireNonNull); 455 } 456 457 @CodeReflection 458 static boolean compareLong(long i, long j) { 459 return i > j; 460 } 461 462 @CodeReflection 463 static boolean compareFloat(float i, float j) { 464 return i > j; 465 } 466 467 @CodeReflection 468 static boolean compareDouble(double i, double j) { 469 return i > j; 470 } 471 472 @CodeReflection 473 static int lookupSwitch(int i) { 474 return switch (1000 * i) { 475 case 1000 -> 1; 476 case 2000 -> 2; 477 case 3000 -> 3; 478 default -> 0; 479 }; 480 } 481 482 @CodeReflection 483 static int tableSwitch(int i) { 484 return switch (i) { 485 case 1 -> 1; 486 case 2 -> 2; 487 case 3 -> 3; 488 default -> 0; 489 }; 490 } 491 492 int instanceField = -1; 493 494 @CodeReflection 495 int instanceFieldAccess(int i) { 496 int ret = instanceField; 497 instanceField = i; 498 return ret; 499 } 500 501 @CodeReflection 502 static String stringConcat(String a, String b) { 503 return "a"+ a +"\u0001" + a + "b\u0002c" + b + "\u0001\u0002" + b + "dd"; 504 } 505 506 @CodeReflection 507 static String multiTypeConcat(int i, Boolean b, char c, Short s, float f, Double d) { 508 return "i:"+ i +" b:" + b + " c:" + c + " f:" + f + " d:" + d; 509 } 510 511 @CodeReflection 512 static int ifTrue(int i) { 513 if (true) { 514 return i; 515 } 516 return -i; 517 } 518 519 @CodeReflection 520 static int excHandlerFollowingSplitTable(boolean b) { 521 try { 522 if (b) return 1; 523 else throw new Exception(); 524 } catch (Exception ex) {} 525 return 2; 526 } 527 528 @CodeReflection 529 static int varModifiedInTryBlock(boolean b) { 530 int i = 0; 531 try { 532 i++; 533 if (b) throw new Exception(); 534 i++; 535 throw new Exception(); 536 } catch (Exception ex) { 537 return i; 538 } 539 } 540 541 @CodeReflection 542 static boolean finallyWithLoop(boolean b) { 543 try { 544 while (b) { 545 if (b) 546 return false; 547 b = !b; 548 } 549 return true; 550 } finally { 551 b = false; 552 } 553 } 554 555 @CodeReflection 556 static long doubleUseOfOperand(int x) { 557 long piece = x; 558 return piece * piece; 559 } 560 561 record TestData(Method testMethod) { 562 @Override 563 public String toString() { 564 String s = testMethod.getName() + Arrays.stream(testMethod.getParameterTypes()) 565 .map(Class::getSimpleName).collect(Collectors.joining(",", "(", ")")); 566 if (s.length() > 30) s = s.substring(0, 27) + "..."; 567 return s; 568 } 569 } 570 571 @DataProvider(name = "testMethods") 572 public static TestData[]testMethods() { 573 return Stream.of(TestBytecode.class.getDeclaredMethods()) 574 .filter(m -> m.isAnnotationPresent(CodeReflection.class)) 575 .map(TestData::new).toArray(TestData[]::new); 576 } 577 578 private static byte[] CLASS_DATA; 579 private static ClassModel CLASS_MODEL; 580 581 @BeforeClass 582 public static void setup() throws Exception { 583 CLASS_DATA = TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes(); 584 CLASS_MODEL = ClassFile.of().parse(CLASS_DATA); 585 } 586 587 private static MethodTypeDesc toMethodTypeDesc(Method m) { 588 return MethodTypeDesc.of( 589 m.getReturnType().describeConstable().orElseThrow(), 590 Arrays.stream(m.getParameterTypes()) 591 .map(cls -> cls.describeConstable().orElseThrow()).toList()); 592 } 593 594 595 private static final Map<Class<?>, Object[]> TEST_ARGS = new IdentityHashMap<>(); 596 private static Object[] values(Object... values) { 597 return values; 598 } 599 private static void initTestArgs(Object[] values, Class<?>... argTypes) { 600 for (var argType : argTypes) TEST_ARGS.put(argType, values); 601 } 602 static { 603 initTestArgs(values(1, 2, 4), int.class, Integer.class); 604 initTestArgs(values((byte)1, (byte)3, (byte)4), byte.class, Byte.class); 605 initTestArgs(values((short)1, (short)2, (short)3), short.class, Short.class); 606 initTestArgs(values((char)2, (char)3, (char)4), char.class, Character.class); 607 initTestArgs(values(false, true), boolean.class, Boolean.class); 608 initTestArgs(values("Hello World"), String.class); 609 initTestArgs(values(1l, 2l, 4l), long.class, Long.class); 610 initTestArgs(values(1f, 3f, 4f), float.class, Float.class); 611 initTestArgs(values(1d, 2d, 3d), double.class, Double.class); 612 } 613 614 interface Executor { 615 void execute(Object[] args) throws Throwable; 616 } 617 618 private static void permutateAllArgs(Class<?>[] argTypes, Executor executor) throws Throwable { 619 final int argn = argTypes.length; 620 Object[][] argValues = new Object[argn][]; 621 for (int i = 0; i < argn; i++) { 622 argValues[i] = TEST_ARGS.get(argTypes[i]); 623 } 624 int[] argIndexes = new int[argn]; 625 Object[] args = new Object[argn]; 626 while (true) { 627 for (int i = 0; i < argn; i++) { 628 args[i] = argValues[i][argIndexes[i]]; 629 } 630 executor.execute(args); 631 int i = argn - 1; 632 while (i >= 0 && argIndexes[i] == argValues[i].length - 1) i--; 633 if (i < 0) return; 634 argIndexes[i++]++; 635 while (i < argn) argIndexes[i++] = 0; 636 } 637 } 638 639 @Test(dataProvider = "testMethods") 640 public void testLift(TestData d) throws Throwable { 641 CoreOp.FuncOp flift; 642 try { 643 flift = BytecodeLift.lift(CLASS_DATA, d.testMethod.getName(), toMethodTypeDesc(d.testMethod)); 644 } catch (Throwable e) { 645 ClassPrinter.toYaml(ClassFile.of().parse(TestBytecode.class.getResourceAsStream("TestBytecode.class").readAllBytes()) 646 .methods().stream().filter(m -> m.methodName().equalsString(d.testMethod().getName())).findAny().get(), 647 ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print); 648 System.out.println("Lift failed, compiled model:"); 649 Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out)); 650 throw e; 651 } 652 try { 653 Object receiver1, receiver2; 654 if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) { 655 receiver1 = null; 656 receiver2 = null; 657 } else { 658 receiver1 = new TestBytecode(); 659 receiver2 = new TestBytecode(); 660 } 661 permutateAllArgs(d.testMethod.getParameterTypes(), args -> 662 Assert.assertEquals(invokeAndConvert(flift, receiver1, args), d.testMethod.invoke(receiver2, args))); 663 } catch (Throwable e) { 664 System.out.println("Compiled model:"); 665 Op.ofMethod(d.testMethod).ifPresent(f -> f.writeTo(System.out)); 666 System.out.println("Lifted model:"); 667 flift.writeTo(System.out); 668 throw e; 669 } 670 } 671 672 private static Object invokeAndConvert(CoreOp.FuncOp func, Object receiver, Object... args) { 673 List argl = new ArrayList(args.length + 1); 674 if (receiver != null) argl.add(receiver); 675 argl.addAll(Arrays.asList(args)); 676 Object ret = Interpreter.invoke(MethodHandles.lookup(), func, argl); 677 if (ret instanceof Integer i) { 678 TypeElement rt = func.invokableType().returnType(); 679 if (rt.equals(JavaType.BOOLEAN)) { 680 return i != 0; 681 } else if (rt.equals(JavaType.BYTE)) { 682 return i.byteValue(); 683 } else if (rt.equals(JavaType.CHAR)) { 684 return (short)i.intValue(); 685 } else if (rt.equals(JavaType.SHORT)) { 686 return i.shortValue(); 687 } 688 } 689 return ret; 690 } 691 692 @Test(dataProvider = "testMethods") 693 public void testGenerate(TestData d) throws Throwable { 694 CoreOp.FuncOp func = Op.ofMethod(d.testMethod).get(); 695 696 CoreOp.FuncOp lfunc; 697 try { 698 lfunc = func.transform(CopyContext.create(), OpTransformer.LOWERING_TRANSFORMER); 699 } catch (UnsupportedOperationException uoe) { 700 throw new SkipException("lowering caused:", uoe); 701 } 702 703 try { 704 MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), lfunc); 705 Object receiver1, receiver2; 706 if (d.testMethod.accessFlags().contains(AccessFlag.STATIC)) { 707 receiver1 = null; 708 receiver2 = null; 709 } else { 710 receiver1 = new TestBytecode(); 711 receiver2 = new TestBytecode(); 712 } 713 permutateAllArgs(d.testMethod.getParameterTypes(), args -> { 714 List argl = new ArrayList(args.length + 1); 715 if (receiver1 != null) argl.add(receiver1); 716 argl.addAll(Arrays.asList(args)); 717 Assert.assertEquals(mh.invokeWithArguments(argl), d.testMethod.invoke(receiver2, args)); 718 }); 719 } catch (Throwable e) { 720 func.writeTo(System.out); 721 lfunc.writeTo(System.out); 722 String methodName = d.testMethod().getName(); 723 for (var mm : CLASS_MODEL.methods()) { 724 if (mm.methodName().equalsString(methodName) 725 || mm.methodName().stringValue().startsWith("lambda$" + methodName + "$")) { 726 ClassPrinter.toYaml(mm, 727 ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, 728 System.out::print); 729 } 730 } 731 Files.list(Path.of("DUMP_CLASS_FILES")).forEach(p -> { 732 if (p.getFileName().toString().matches(methodName + "\\..+\\.class")) try { 733 ClassPrinter.toYaml(ClassFile.of().parse(p), 734 ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, 735 System.out::print); 736 } catch (IOException ignore) {} 737 }); 738 throw e; 739 } 740 } 741 }