1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import jdk.incubator.code.dialect.java.JavaOp; 25 import org.testng.Assert; 26 import org.testng.annotations.Test; 27 28 import java.io.PrintStream; 29 import jdk.incubator.code.*; 30 import jdk.incubator.code.dialect.core.CoreOp; 31 import jdk.incubator.code.dialect.java.FieldRef; 32 import jdk.incubator.code.dialect.java.MethodRef; 33 import jdk.incubator.code.interpreter.Interpreter; 34 import java.lang.invoke.MethodHandles; 35 import java.lang.reflect.Method; 36 import jdk.incubator.code.CodeReflection; 37 import java.util.List; 38 import java.util.Optional; 39 import java.util.concurrent.atomic.AtomicBoolean; 40 import java.util.function.Function; 41 import java.util.function.IntBinaryOperator; 42 import java.util.function.IntUnaryOperator; 43 import java.util.stream.Collectors; 44 import java.util.stream.Stream; 45 46 import static jdk.incubator.code.dialect.java.JavaOp.arrayStoreOp; 47 import static jdk.incubator.code.dialect.core.CoreOp.constant; 48 import static jdk.incubator.code.dialect.java.JavaOp.fieldLoad; 49 import static jdk.incubator.code.dialect.java.JavaOp.newArray; 50 import static jdk.incubator.code.dialect.java.MethodRef.method; 51 import static jdk.incubator.code.dialect.java.JavaType.*; 52 53 /* 54 * @test 55 * @modules jdk.incubator.code 56 * @enablePreview 57 * @run testng TestLocalTransformationsAdaption 58 */ 59 60 public class TestLocalTransformationsAdaption { 61 62 @CodeReflection 63 static int f(int i) { 64 IntBinaryOperator add = (a, b) -> { 65 return add(a, b); 66 }; 67 68 try { 69 IntUnaryOperator add42 = (a) -> { 70 return add.applyAsInt(a, 42); 71 }; 72 73 int j = add42.applyAsInt(i); 74 75 IntBinaryOperator f = (a, b) -> { 76 if (i < 0) { 77 throw new RuntimeException(); 78 } 79 80 IntUnaryOperator g = (c) -> { 81 return add(a, c); 82 }; 83 84 return g.applyAsInt(b); 85 }; 86 87 return f.applyAsInt(j, j); 88 } catch (RuntimeException e) { 89 throw new IndexOutOfBoundsException(i); 90 } 91 } 92 93 @Test 94 public void testInvocation() { 95 CoreOp.FuncOp f = getFuncOp("f"); 96 f.writeTo(System.out); 97 98 f = f.transform(OpTransformer.LOWERING_TRANSFORMER); 99 f.writeTo(System.out); 100 101 int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2); 102 Assert.assertEquals(x, f(2)); 103 104 try { 105 Interpreter.invoke(MethodHandles.lookup(), f, -10); 106 Assert.fail(); 107 } catch (Throwable e) { 108 Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass()); 109 } 110 } 111 112 @Test 113 public void testFuncEntryExit() { 114 CoreOp.FuncOp f = getFuncOp("f"); 115 f.writeTo(System.out); 116 117 AtomicBoolean first = new AtomicBoolean(true); 118 CoreOp.FuncOp fc = f.transform((block, op) -> { 119 if (first.get()) { 120 printConstantString(block, "ENTRY"); 121 first.set(false); 122 } 123 124 switch (op) { 125 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: { 126 printConstantString(block, "EXIT"); 127 break; 128 } 129 case JavaOp.ThrowOp throwOp: { 130 printConstantString(block, "EXIT"); 131 break; 132 } 133 default: 134 } 135 136 block.apply(op); 137 138 return block; 139 }); 140 fc.writeTo(System.out); 141 142 fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER); 143 fc.writeTo(System.out); 144 145 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2); 146 Assert.assertEquals(x, f(2)); 147 148 try { 149 Interpreter.invoke(MethodHandles.lookup(), fc, -10); 150 Assert.fail(); 151 } catch (Throwable e) { 152 Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass()); 153 } 154 } 155 156 static void printConstantString(Function<Op, Op.Result> opBuilder, String s) { 157 Op.Result c = opBuilder.apply(constant(J_L_STRING, s)); 158 Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class))); 159 opBuilder.apply(JavaOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c)); 160 } 161 162 static Op getNearestInvokeableAncestorOp(Op op) { 163 do { 164 op = op.ancestorBody().parentOp(); 165 } while (!(op instanceof Op.Invokable)); 166 return op; 167 } 168 169 170 @Test 171 public void testReplaceCall() { 172 CoreOp.FuncOp f = getFuncOp("f"); 173 f.writeTo(System.out); 174 175 CoreOp.FuncOp fc = f.transform((block, op) -> { 176 switch (op) { 177 case JavaOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): { 178 // Get the adapted operands, and pass those to the new call method 179 List<Value> adaptedOperands = block.context().getValues(op.operands()); 180 Op.Result adaptedResult = block.apply(JavaOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands)); 181 // Map the old call result to the new call result, so existing operations can be 182 // adapted to use the new result 183 block.context().mapValue(invokeOp.result(), adaptedResult); 184 break; 185 } 186 default: { 187 block.apply(op); 188 } 189 } 190 return block; 191 }); 192 fc.writeTo(System.out); 193 194 fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER); 195 fc.writeTo(System.out); 196 197 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2); 198 Assert.assertEquals(x, f(2)); 199 } 200 201 202 @Test 203 public void testCallEntryExit() { 204 CoreOp.FuncOp f = getFuncOp("f"); 205 f.writeTo(System.out); 206 207 CoreOp.FuncOp fc = f.transform((block, op) -> { 208 switch (op) { 209 case JavaOp.InvokeOp invokeOp: { 210 printCall(block.context(), invokeOp, block); 211 break; 212 } 213 default: { 214 block.apply(op); 215 } 216 } 217 return block; 218 }); 219 fc.writeTo(System.out); 220 221 fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER); 222 fc.writeTo(System.out); 223 224 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2); 225 Assert.assertEquals(x, f(2)); 226 } 227 228 static void printCall(CopyContext cc, JavaOp.InvokeOp invokeOp, Function<Op, Op.Result> opBuilder) { 229 List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands()); 230 231 String prefix = "ENTER"; 232 233 Value arrayLength = opBuilder.apply( 234 constant(INT, adaptedInvokeOperands.size())); 235 Value formatArray = opBuilder.apply( 236 newArray(type(Object[].class), arrayLength)); 237 238 Value indexZero = null; 239 for (int i = 0; i < adaptedInvokeOperands.size(); i++) { 240 Value operand = adaptedInvokeOperands.get(i); 241 242 Value index = opBuilder.apply( 243 constant(INT, i)); 244 if (i == 0) { 245 indexZero = index; 246 } 247 248 if (operand.type().equals(INT)) { 249 operand = opBuilder.apply( 250 JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand)); 251 // @@@ Other primitive types 252 } 253 opBuilder.apply( 254 arrayStoreOp(formatArray, index, operand)); 255 } 256 257 Op.Result formatString = opBuilder.apply( 258 constant(J_L_STRING, 259 prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n")); 260 Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class))); 261 opBuilder.apply( 262 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class), 263 System_out, formatString, formatArray)); 264 265 // Method call 266 267 Op.Result adaptedInvokeResult = opBuilder.apply(invokeOp); 268 269 // After method call 270 271 prefix = "EXIT"; 272 273 if (adaptedInvokeResult.type().equals(INT)) { 274 adaptedInvokeResult = opBuilder.apply( 275 JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult)); 276 // @@@ Other primitive types 277 } 278 opBuilder.apply( 279 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult)); 280 281 formatString = opBuilder.apply( 282 constant(J_L_STRING, 283 prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n")); 284 opBuilder.apply( 285 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class), 286 System_out, formatString, formatArray)); 287 } 288 289 static String formatString(List<Value> vs) { 290 return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(",")); 291 } 292 293 static String formatString(TypeElement t) { 294 if (t.equals(INT)) { 295 return "%d"; 296 } else { 297 return "%s"; 298 } 299 } 300 301 302 static final MethodRef ADD_METHOD = MethodRef.method( 303 TestLocalTransformationsAdaption.class, "add", 304 int.class, int.class, int.class); 305 306 static int add(int a, int b) { 307 return a + b; 308 } 309 310 static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method( 311 TestLocalTransformationsAdaption.class, "addWithPrint", 312 int.class, int.class, int.class); 313 314 static int addWithPrint(int a, int b) { 315 System.out.printf("Adding %d + %d%n", a, b); 316 return a + b; 317 } 318 319 static CoreOp.FuncOp getFuncOp(String name) { 320 Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods()) 321 .filter(m -> m.getName().equals(name)) 322 .findFirst(); 323 324 Method m = om.get(); 325 return Op.ofMethod(m).get(); 326 } 327 }