/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
22 */ 23 24 import org.testng.Assert; 25 import org.testng.annotations.Test; 26 27 import java.io.PrintStream; 28 import java.lang.reflect.code.*; 29 import java.lang.reflect.code.op.CoreOp; 30 import java.lang.reflect.code.type.FieldRef; 31 import java.lang.reflect.code.type.MethodRef; 32 import java.lang.reflect.code.interpreter.Interpreter; 33 import java.lang.invoke.MethodHandles; 34 import java.lang.reflect.Method; 35 import java.lang.runtime.CodeReflection; 36 import java.util.List; 37 import java.util.Optional; 38 import java.util.concurrent.atomic.AtomicBoolean; 39 import java.util.function.Function; 40 import java.util.function.IntBinaryOperator; 41 import java.util.function.IntUnaryOperator; 42 import java.util.stream.Collectors; 43 import java.util.stream.Stream; 44 45 import static java.lang.reflect.code.op.CoreOp.arrayStoreOp; 46 import static java.lang.reflect.code.op.CoreOp.constant; 47 import static java.lang.reflect.code.op.CoreOp.fieldLoad; 48 import static java.lang.reflect.code.op.CoreOp.newArray; 49 import static java.lang.reflect.code.type.MethodRef.method; 50 import static java.lang.reflect.code.type.JavaType.*; 51 52 /* 53 * @test 54 * @enablePreview 55 * @run testng TestLocalTransformationsAdaption 56 */ 57 58 public class TestLocalTransformationsAdaption { 59 60 @CodeReflection 61 static int f(int i) { 62 IntBinaryOperator add = (a, b) -> { 63 return add(a, b); 64 }; 65 66 try { 67 IntUnaryOperator add42 = (a) -> { 68 return add.applyAsInt(a, 42); 69 }; 70 71 int j = add42.applyAsInt(i); 72 73 IntBinaryOperator f = (a, b) -> { 74 if (i < 0) { 75 throw new RuntimeException(); 76 } 77 78 IntUnaryOperator g = (c) -> { 79 return add(a, c); 80 }; 81 82 return g.applyAsInt(b); 83 }; 84 85 return f.applyAsInt(j, j); 86 } catch (RuntimeException e) { 87 throw new IndexOutOfBoundsException(i); 88 } 89 } 90 91 @Test 92 public void testInvocation() { 93 CoreOp.FuncOp f = getFuncOp("f"); 94 f.writeTo(System.out); 95 96 f = 
f.transform(OpTransformer.LOWERING_TRANSFORMER); 97 f.writeTo(System.out); 98 99 int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2); 100 Assert.assertEquals(x, f(2)); 101 102 try { 103 Interpreter.invoke(MethodHandles.lookup(), f, -10); 104 Assert.fail(); 105 } catch (Throwable e) { 106 Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass()); 107 } 108 } 109 110 @Test 111 public void testFuncEntryExit() { 112 CoreOp.FuncOp f = getFuncOp("f"); 113 f.writeTo(System.out); 114 115 AtomicBoolean first = new AtomicBoolean(true); 116 CoreOp.FuncOp fc = f.transform((block, op) -> { 117 if (first.get()) { 118 printConstantString(block, "ENTRY"); 119 first.set(false); 120 } 121 122 switch (op) { 123 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: { 124 printConstantString(block, "EXIT"); 125 break; 126 } 127 case CoreOp.ThrowOp throwOp: { 128 printConstantString(block, "EXIT"); 129 break; 130 } 131 default: 132 } 133 134 block.apply(op); 135 136 return block; 137 }); 138 fc.writeTo(System.out); 139 140 fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER); 141 fc.writeTo(System.out); 142 143 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2); 144 Assert.assertEquals(x, f(2)); 145 146 try { 147 Interpreter.invoke(MethodHandles.lookup(), fc, -10); 148 Assert.fail(); 149 } catch (Throwable e) { 150 Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass()); 151 } 152 } 153 154 static void printConstantString(Function<Op, Op.Result> opBuilder, String s) { 155 Op.Result c = opBuilder.apply(constant(J_L_STRING, s)); 156 Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class))); 157 opBuilder.apply(CoreOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c)); 158 } 159 160 static Op getNearestInvokeableAncestorOp(Op op) { 161 do { 162 op = op.ancestorBody().parentOp(); 163 } while (!(op instanceof 
Op.Invokable)); 164 return op; 165 } 166 167 168 @Test 169 public void testReplaceCall() { 170 CoreOp.FuncOp f = getFuncOp("f"); 171 f.writeTo(System.out); 172 173 CoreOp.FuncOp fc = f.transform((block, op) -> { 174 switch (op) { 175 case CoreOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): { 176 // Get the adapted operands, and pass those to the new call method 177 List<Value> adaptedOperands = block.context().getValues(op.operands()); 178 Op.Result adaptedResult = block.apply(CoreOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands)); 179 // Map the old call result to the new call result, so existing operations can be 180 // adapted to use the new result 181 block.context().mapValue(invokeOp.result(), adaptedResult); 182 break; 183 } 184 default: { 185 block.apply(op); 186 } 187 } 188 return block; 189 }); 190 fc.writeTo(System.out); 191 192 fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER); 193 fc.writeTo(System.out); 194 195 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2); 196 Assert.assertEquals(x, f(2)); 197 } 198 199 200 @Test 201 public void testCallEntryExit() { 202 CoreOp.FuncOp f = getFuncOp("f"); 203 f.writeTo(System.out); 204 205 CoreOp.FuncOp fc = f.transform((block, op) -> { 206 switch (op) { 207 case CoreOp.InvokeOp invokeOp: { 208 printCall(block.context(), invokeOp, block); 209 break; 210 } 211 default: { 212 block.apply(op); 213 } 214 } 215 return block; 216 }); 217 fc.writeTo(System.out); 218 219 fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER); 220 fc.writeTo(System.out); 221 222 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2); 223 Assert.assertEquals(x, f(2)); 224 } 225 226 static void printCall(CopyContext cc, CoreOp.InvokeOp invokeOp, Function<Op, Op.Result> opBuilder) { 227 List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands()); 228 229 String prefix = "ENTER"; 230 231 Value arrayLength = opBuilder.apply( 232 constant(INT, adaptedInvokeOperands.size())); 233 Value 
formatArray = opBuilder.apply( 234 newArray(type(Object[].class), arrayLength)); 235 236 Value indexZero = null; 237 for (int i = 0; i < adaptedInvokeOperands.size(); i++) { 238 Value operand = adaptedInvokeOperands.get(i); 239 240 Value index = opBuilder.apply( 241 constant(INT, i)); 242 if (i == 0) { 243 indexZero = index; 244 } 245 246 if (operand.type().equals(INT)) { 247 operand = opBuilder.apply( 248 CoreOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand)); 249 // @@@ Other primitive types 250 } 251 opBuilder.apply( 252 arrayStoreOp(formatArray, index, operand)); 253 } 254 255 Op.Result formatString = opBuilder.apply( 256 constant(J_L_STRING, 257 prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n")); 258 Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class))); 259 opBuilder.apply( 260 CoreOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class), 261 System_out, formatString, formatArray)); 262 263 // Method call 264 265 Op.Result adaptedInvokeResult = opBuilder.apply(invokeOp); 266 267 // After method call 268 269 prefix = "EXIT"; 270 271 if (adaptedInvokeResult.type().equals(INT)) { 272 adaptedInvokeResult = opBuilder.apply( 273 CoreOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult)); 274 // @@@ Other primitive types 275 } 276 opBuilder.apply( 277 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult)); 278 279 formatString = opBuilder.apply( 280 constant(J_L_STRING, 281 prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n")); 282 opBuilder.apply( 283 CoreOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class), 284 System_out, formatString, formatArray)); 285 } 286 287 static String formatString(List<Value> vs) { 288 return vs.stream().map(v -> 
formatString(v.type())).collect(Collectors.joining(",")); 289 } 290 291 static String formatString(TypeElement t) { 292 if (t.equals(INT)) { 293 return "%d"; 294 } else { 295 return "%s"; 296 } 297 } 298 299 300 static final MethodRef ADD_METHOD = MethodRef.method( 301 TestLocalTransformationsAdaption.class, "add", 302 int.class, int.class, int.class); 303 304 static int add(int a, int b) { 305 return a + b; 306 } 307 308 static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method( 309 TestLocalTransformationsAdaption.class, "addWithPrint", 310 int.class, int.class, int.class); 311 312 static int addWithPrint(int a, int b) { 313 System.out.printf("Adding %d + %d%n", a, b); 314 return a + b; 315 } 316 317 static CoreOp.FuncOp getFuncOp(String name) { 318 Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods()) 319 .filter(m -> m.getName().equals(name)) 320 .findFirst(); 321 322 Method m = om.get(); 323 return m.getCodeModel().get(); 324 } 325 }