/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
*/

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.PrintStream;
import jdk.incubator.code.*;
import jdk.incubator.code.op.CoreOp;
import jdk.incubator.code.type.FieldRef;
import jdk.incubator.code.type.MethodRef;
import jdk.incubator.code.interpreter.Interpreter;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import jdk.incubator.code.CodeReflection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.IntBinaryOperator;
import java.util.function.IntUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static jdk.incubator.code.op.CoreOp.arrayStoreOp;
import static jdk.incubator.code.op.CoreOp.constant;
import static jdk.incubator.code.op.CoreOp.fieldLoad;
import static jdk.incubator.code.op.CoreOp.newArray;
import static jdk.incubator.code.type.MethodRef.method;
import static jdk.incubator.code.type.JavaType.*;

/*
 * @test
 * @modules jdk.incubator.code
 * @enablePreview
 * @run testng TestLocalTransformationsAdaption
 */

/**
 * Tests local transformations of the code model of method f that adapt
 * individual operations (inserting prints around them, or replacing calls)
 * while copying everything else. Each transformed model is lowered and run
 * with the interpreter, and the result is compared against calling f directly.
 */
public class TestLocalTransformationsAdaption {

    @CodeReflection
    static int f(int i) {
        IntBinaryOperator add = (a, b) -> {
            return add(a, b);
        };

        try {
            IntUnaryOperator add42 = (a) -> {
                return add.applyAsInt(a, 42);
            };

            int j = add42.applyAsInt(i);

            IntBinaryOperator f = (a, b) -> {
                if (i < 0) {
                    throw new RuntimeException();
                }

                IntUnaryOperator g = (c) -> {
                    return add(a, c);
                };

                return g.applyAsInt(b);
            };

            return f.applyAsInt(j, j);
        } catch (RuntimeException e) {
            throw new IndexOutOfBoundsException(i);
        }
    }

    /**
     * Lowers the unmodified model of f and checks that interpreting it matches
     * direct invocation, both for a normal result and for the exceptional path.
     */
    @Test
    public void testInvocation() {
        CoreOp.FuncOp f = getFuncOp("f");
        f.writeTo(System.out);

        f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
        f.writeTo(System.out);

        int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
        Assert.assertEquals(x, f(2));

        // A negative argument makes f throw IndexOutOfBoundsException
        assertInterpreterThrows(IndexOutOfBoundsException.class, f, -10);
    }

    /**
     * Inserts an "ENTRY" print before the first copied operation and an "EXIT"
     * print before every throw and before every return whose nearest invokable
     * ancestor is the function itself, then checks the transformed model still
     * computes the same results as f.
     */
    @Test
    public void testFuncEntryExit() {
        CoreOp.FuncOp f = getFuncOp("f");
        f.writeTo(System.out);

        AtomicBoolean first = new AtomicBoolean(true);
        CoreOp.FuncOp fc = f.transform((block, op) -> {
            // Only the very first operation visited triggers the ENTRY print
            if (first.getAndSet(false)) {
                printConstantString(block, "ENTRY");
            }

            switch (op) {
                // Returns nested in other invokable ops (e.g. lambda bodies) are skipped
                case CoreOp.ReturnOp returnOp
                        when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp ->
                        printConstantString(block, "EXIT");
                case CoreOp.ThrowOp throwOp -> printConstantString(block, "EXIT");
                default -> {
                }
            }

            block.apply(op);
            return block;
        });
        fc.writeTo(System.out);

        fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
        fc.writeTo(System.out);

        int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
        Assert.assertEquals(x, f(2));

        assertInterpreterThrows(IndexOutOfBoundsException.class, fc, -10);
    }

    /**
     * Interprets {@code f} with {@code args} and asserts that the invocation
     * throws an exception whose class is exactly {@code expected}.
     * <p>
     * The assertions are made outside the try block, so that a missing
     * exception is reported as such rather than being caught and compared
     * against the expected class.
     */
    private static void assertInterpreterThrows(Class<? extends Throwable> expected,
                                                CoreOp.FuncOp f, Object... args) {
        Throwable thrown = null;
        try {
            Interpreter.invoke(MethodHandles.lookup(), f, args);
        } catch (Throwable t) {
            thrown = t;
        }
        Assert.assertNotNull(thrown, "Expected " + expected.getName() + " to be thrown");
        Assert.assertEquals(thrown.getClass(), expected);
    }

    /**
     * Emits operations, through {@code opBuilder}, that print the constant
     * string {@code s} with {@code System.out.println}.
     */
    static void printConstantString(Function<Op, Op.Result> opBuilder, String s) {
        Op.Result c = opBuilder.apply(constant(J_L_STRING, s));
        Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
        opBuilder.apply(CoreOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
    }

    /**
     * Walks up the enclosing bodies of {@code op} and returns the first
     * ancestor operation that is an {@link Op.Invokable}.
     *
     * @return the nearest invokable ancestor, or null if the walk reaches an
     *         operation with no ancestor body before finding one
     */
    static Op getNearestInvokeableAncestorOp(Op op) {
        Op current = op;
        do {
            Body ancestor = current.ancestorBody();
            if (ancestor == null) {
                return null;
            }
            current = ancestor.parentOp();
        } while (!(current instanceof Op.Invokable));
        return current;
    }


    /**
     * Replaces every call to add with a call to addWithPrint on the same
     * (adapted) operands, and checks the transformed model still computes the
     * same result as f.
     */
    @Test
    public void testReplaceCall() {
        CoreOp.FuncOp f = getFuncOp("f");
        f.writeTo(System.out);

        CoreOp.FuncOp fc = f.transform((block, op) -> {
            switch (op) {
                case CoreOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD) -> {
                    // Get the adapted operands, and pass those to the new call method
                    List<Value> adaptedOperands = block.context().getValues(op.operands());
                    Op.Result adaptedResult = block.apply(CoreOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
                    // Map the old call result to the new call result, so existing operations can be
                    // adapted to use the new result
                    block.context().mapValue(invokeOp.result(), adaptedResult);
                }
                default -> block.apply(op);
            }
            return block;
        });
        fc.writeTo(System.out);

        fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
        fc.writeTo(System.out);

        int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
        Assert.assertEquals(x, f(2));
    }


    /**
     * Wraps every invocation with printf-based ENTER/EXIT tracing (see
     * printCall), and checks the transformed model still computes the same
     * result as f.
     */
    @Test
    public void testCallEntryExit() {
        CoreOp.FuncOp f = getFuncOp("f");
        f.writeTo(System.out);

        CoreOp.FuncOp fc = f.transform((block, op) -> {
            if (op instanceof CoreOp.InvokeOp invokeOp) {
                printCall(block.context(), invokeOp, block);
            } else {
                block.apply(op);
            }
            return block;
        });
        fc.writeTo(System.out);

        fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
        fc.writeTo(System.out);

        int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
        Assert.assertEquals(x, f(2));
    }

    /**
     * Emits, through {@code opBuilder}, a copy of {@code invokeOp} surrounded
     * by printf calls.
     * <p>
     * The "ENTER" line prints the call's (adapted) operands. The "EXIT" line
     * prints the call's result; for a void call it prints no result.
     * Primitive values are boxed before being stored into the Object[]
     * argument array.
     *
     * @param cc        copy context used to look up the adapted operand values
     * @param invokeOp  the invocation being traced
     * @param opBuilder sink that appends operations and returns their results
     */
    static void printCall(CopyContext cc, CoreOp.InvokeOp invokeOp, Function<Op, Op.Result> opBuilder) {
        List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());

        // The array is reused for the EXIT line, which stores the result at index
        // zero, so it needs at least one slot even for a call with no operands
        Value arrayLength = opBuilder.apply(
                constant(INT, Math.max(adaptedInvokeOperands.size(), 1)));
        Value formatArray = opBuilder.apply(
                newArray(type(Object[].class), arrayLength));

        // Before method call

        for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
            Value index = opBuilder.apply(constant(INT, i));
            Value operand = boxIfPrimitive(opBuilder, adaptedInvokeOperands.get(i));
            opBuilder.apply(arrayStoreOp(formatArray, index, operand));
        }

        Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
        emitPrintf(opBuilder, System_out,
                "ENTER" + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n",
                formatArray);

        // Method call

        Op.Result adaptedInvokeResult = opBuilder.apply(invokeOp);

        // After method call

        String exitPrefix = "EXIT" + ": " + invokeOp.invokeDescriptor();
        if (adaptedInvokeResult.type().equals(VOID)) {
            // A void call produces no value that could be stored and printed
            emitPrintf(opBuilder, System_out, exitPrefix + "%n", formatArray);
        } else {
            Value indexZero = opBuilder.apply(constant(INT, 0));
            Value result = boxIfPrimitive(opBuilder, adaptedInvokeResult);
            opBuilder.apply(arrayStoreOp(formatArray, indexZero, result));
            emitPrintf(opBuilder, System_out,
                    exitPrefix + " -> " + formatString(adaptedInvokeResult.type()) + "%n",
                    formatArray);
        }
    }

    /**
     * Emits, through {@code opBuilder}, a call to
     * {@code out.printf(format, args)}, where format is a string constant.
     */
    private static void emitPrintf(Function<Op, Op.Result> opBuilder, Value out, String format, Value args) {
        Op.Result formatString = opBuilder.apply(constant(J_L_STRING, format));
        opBuilder.apply(
                CoreOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
                        out, formatString, args));
    }

    /** Associates a primitive code-model type with its Java primitive and wrapper classes. */
    private record Boxing(TypeElement primitiveType, Class<?> primitive, Class<?> wrapper) {
    }

    /** Boxing conversions for every primitive type that can be stored into an Object[]. */
    private static final List<Boxing> BOXINGS = List.of(
            new Boxing(BOOLEAN, boolean.class, Boolean.class),
            new Boxing(BYTE, byte.class, Byte.class),
            new Boxing(CHAR, char.class, Character.class),
            new Boxing(SHORT, short.class, Short.class),
            new Boxing(INT, int.class, Integer.class),
            new Boxing(LONG, long.class, Long.class),
            new Boxing(FLOAT, float.class, Float.class),
            new Boxing(DOUBLE, double.class, Double.class));

    /**
     * Returns {@code v} boxed by a call to its wrapper class's valueOf method
     * (emitted through {@code opBuilder}) when its type is primitive, and
     * {@code v} unchanged otherwise.
     */
    private static Value boxIfPrimitive(Function<Op, Op.Result> opBuilder, Value v) {
        for (Boxing b : BOXINGS) {
            if (v.type().equals(b.primitiveType())) {
                return opBuilder.apply(
                        CoreOp.invoke(method(b.wrapper(), "valueOf", b.wrapper(), b.primitive()), v));
            }
        }
        return v;
    }

    /** Returns a comma-separated printf specifier list, one specifier per value. */
    static String formatString(List<Value> vs) {
        return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
    }

    /** Returns "%d" for the int type and "%s" for any other type. */
    static String formatString(TypeElement t) {
        if (t.equals(INT)) {
            return "%d";
        } else {
            return "%s";
        }
    }


    static final MethodRef ADD_METHOD = MethodRef.method(
            TestLocalTransformationsAdaption.class, "add",
            int.class, int.class, int.class);

    static int add(int a, int b) {
        return a + b;
    }

    static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
            TestLocalTransformationsAdaption.class, "addWithPrint",
            int.class, int.class, int.class);

    static int addWithPrint(int a, int b) {
        System.out.printf("Adding %d + %d%n", a, b);
        return a + b;
    }

    /**
     * Returns the code model of the first method declared in this class named
     * {@code name}.
     *
     * @throws IllegalArgumentException if no such method is declared
     * @throws IllegalStateException    if the method has no code model
     */
    static CoreOp.FuncOp getFuncOp(String name) {
        Method m = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
                .filter(candidate -> candidate.getName().equals(name))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("No method named " + name));
        return Op.ofMethod(m)
                .orElseThrow(() -> new IllegalStateException("No code model for method " + name));
    }
}