/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
*/

import jdk.incubator.code.*;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.FieldRef;
import jdk.incubator.code.dialect.java.JavaOp;
import jdk.incubator.code.dialect.java.MethodRef;
import jdk.incubator.code.interpreter.Interpreter;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.io.PrintStream;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.IntBinaryOperator;
import java.util.function.IntUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static jdk.incubator.code.dialect.core.CoreOp.constant;
import static jdk.incubator.code.dialect.java.JavaOp.*;
import static jdk.incubator.code.dialect.java.JavaType.*;
import static jdk.incubator.code.dialect.java.MethodRef.method;

/*
 * @test
 * @modules jdk.incubator.code
 * @enablePreview
 * @run junit TestLocalTransformationsAdaption
 */

/**
 * Tests local, per-operation transformations of the code model of {@link #f(int)}.
 * The transformations insert print operations, replace invocations, and remap values
 * through the transformation's copy context. Each transformed model is then lowered,
 * interpreted, and compared against a direct call of {@code f}.
 */
public class TestLocalTransformationsAdaption {

    /**
     * Reflectable method whose code model the tests transform.
     * It contains nested lambdas, a captured variable, a try/catch, and throws.
     * For {@code i < 0} the lambda {@code f} throws a {@link RuntimeException}, which
     * the catch block rethrows as an {@link IndexOutOfBoundsException}.
     */
    @CodeReflection
    static int f(int i) {
        IntBinaryOperator add = (a, b) -> {
            return add(a, b);
        };

        try {
            IntUnaryOperator add42 = (a) -> {
                return add.applyAsInt(a, 42);
            };

            int j = add42.applyAsInt(i);

            IntBinaryOperator f = (a, b) -> {
                if (i < 0) {
                    throw new RuntimeException();
                }

                IntUnaryOperator g = (c) -> {
                    return add(a, c);
                };

                return g.applyAsInt(b);
            };

            return f.applyAsInt(j, j);
        } catch (RuntimeException e) {
            throw new IndexOutOfBoundsException(i);
        }
    }

    @Test
    public void testInvocation() {
        CoreOp.FuncOp f = getFuncOp("f");
        System.out.println(f.toText());

        f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
        System.out.println(f.toText());

        assertReturnsLikeF(f, 2);
        assertThrowsLikeF(f, -10);
    }

    @Test
    public void testFuncEntryExit() {
        CoreOp.FuncOp f = getFuncOp("f");
        System.out.println(f.toText());

        AtomicBoolean first = new AtomicBoolean(true);
        CoreOp.FuncOp fc = f.transform((block, op) -> {
            // Emit the ENTRY print before the first operation the transformer visits
            if (first.getAndSet(false)) {
                printConstantString(block, "ENTRY");
            }

            // Only returns/throws whose nearest enclosing invokable op is the function itself
            // are treated as function exits; those inside lambda bodies leave the lambda.
            switch (op) {
                case CoreOp.ReturnOp returnOp
                        when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp ->
                        printConstantString(block, "EXIT");
                case JavaOp.ThrowOp throwOp
                        when getNearestInvokeableAncestorOp(throwOp) instanceof CoreOp.FuncOp ->
                        printConstantString(block, "EXIT");
                default -> {
                }
            }

            block.op(op);
            return block;
        });
        System.out.println(fc.toText());

        fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
        System.out.println(fc.toText());

        assertReturnsLikeF(fc, 2);
        assertThrowsLikeF(fc, -10);
    }

    /**
     * Appends operations to {@code opBuilder} that load {@code System.out} and call
     * {@code println} on it with the constant string {@code s}.
     */
    static void printConstantString(Block.Builder opBuilder, String s) {
        Op.Result c = opBuilder.op(constant(J_L_STRING, s));
        Value systemOut = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
        opBuilder.op(JavaOp.invoke(method(PrintStream.class, "println", void.class, String.class), systemOut, c));
    }

    /**
     * Returns the nearest strict ancestor of {@code op} that is an {@link Op.Invokable},
     * or {@code null} if no such ancestor exists.
     */
    static Op getNearestInvokeableAncestorOp(Op op) {
        for (Op ancestor = op.ancestorOp(); ancestor != null; ancestor = ancestor.ancestorOp()) {
            if (ancestor instanceof Op.Invokable) {
                return ancestor;
            }
        }
        return null;
    }


    @Test
    public void testReplaceCall() {
        CoreOp.FuncOp f = getFuncOp("f");
        System.out.println(f.toText());

        CoreOp.FuncOp fc = f.transform((block, op) -> {
            if (op instanceof JavaOp.InvokeOp invokeOp && invokeOp.invokeDescriptor().equals(ADD_METHOD)) {
                // Get the adapted operands, and pass those to the new call method
                List<Value> adaptedOperands = block.context().getValues(invokeOp.operands());
                Op.Result adaptedResult = block.op(JavaOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
                // Map the old call result to the new call result, so existing operations can be
                // adapted to use the new result
                block.context().mapValue(invokeOp.result(), adaptedResult);
            } else {
                block.op(op);
            }
            return block;
        });
        System.out.println(fc.toText());

        fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
        System.out.println(fc.toText());

        assertReturnsLikeF(fc, 2);
    }


    @Test
    public void testCallEntryExit() {
        CoreOp.FuncOp f = getFuncOp("f");
        System.out.println(f.toText());

        CoreOp.FuncOp fc = f.transform((block, op) -> {
            if (op instanceof JavaOp.InvokeOp invokeOp) {
                printCall(block.context(), invokeOp, block);
            } else {
                block.op(op);
            }
            return block;
        });
        System.out.println(fc.toText());

        fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
        System.out.println(fc.toText());

        assertReturnsLikeF(fc, 2);
    }

    /**
     * Appends {@code invokeOp} to {@code opBuilder}, surrounded by {@code printf} calls.
     * The call before the invocation prints an ENTER line with the adapted operands.
     * The call after it prints an EXIT line with the result, or without a value when
     * the invocation's result type is {@code void}.
     *
     * @param cc        copy context used to map the invocation's operands to adapted values
     * @param invokeOp  the invocation to instrument and copy
     * @param opBuilder the block builder operations are appended to
     */
    static void printCall(CopyContext cc, JavaOp.InvokeOp invokeOp, Block.Builder opBuilder) {
        List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());

        // The same array is reused for the EXIT message, which stores the result at index 0,
        // so it needs at least one slot even when the invocation has no operands.
        Value arrayLength = opBuilder.op(
                constant(INT, Math.max(1, adaptedInvokeOperands.size())));
        Value formatArray = opBuilder.op(
                newArray(type(Object[].class), arrayLength));

        for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
            Value index = opBuilder.op(constant(INT, i));
            opBuilder.op(
                    arrayStoreOp(formatArray, index, boxIfInt(opBuilder, adaptedInvokeOperands.get(i))));
        }

        Value systemOut = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
        emitPrintf(opBuilder, systemOut,
                "ENTER: " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n",
                formatArray);

        // Method call
        Op.Result adaptedInvokeResult = opBuilder.op(invokeOp);

        // After method call
        if (adaptedInvokeResult.type().equals(VOID)) {
            // No value to report; storing a void result in the array would be meaningless
            emitPrintf(opBuilder, systemOut,
                    "EXIT: " + invokeOp.invokeDescriptor() + "%n",
                    formatArray);
        } else {
            Value indexZero = opBuilder.op(constant(INT, 0));
            opBuilder.op(
                    arrayStoreOp(formatArray, indexZero, boxIfInt(opBuilder, adaptedInvokeResult)));
            emitPrintf(opBuilder, systemOut,
                    "EXIT: " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n",
                    formatArray);
        }
    }

    /**
     * Returns {@code v} boxed via {@code Integer.valueOf} when its type is {@code int},
     * otherwise {@code v} unchanged, so it can be stored in an {@code Object[]}.
     */
    private static Value boxIfInt(Block.Builder opBuilder, Value v) {
        if (!v.type().equals(INT)) {
            // @@@ Other primitive types
            return v;
        }
        return opBuilder.op(
                JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), v));
    }

    /**
     * Appends operations calling {@code systemOut.printf(format, args)} with {@code format}
     * as a constant string.
     */
    private static void emitPrintf(Block.Builder opBuilder, Value systemOut, String format, Value args) {
        Value formatString = opBuilder.op(constant(J_L_STRING, format));
        opBuilder.op(
                JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
                        systemOut, formatString, args));
    }

    /** Returns one comma-separated format specifier per value, chosen by its type. */
    static String formatString(List<Value> vs) {
        return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
    }

    /** Returns {@code "%d"} for {@code int} and {@code "%s"} for every other type. */
    static String formatString(TypeElement t) {
        return t.equals(INT) ? "%d" : "%s";
    }

    /**
     * Interprets {@code fm} with {@code arg} and asserts that the result equals {@code f(arg)}.
     */
    private static void assertReturnsLikeF(CoreOp.FuncOp fm, int arg) {
        int x = (int) Interpreter.invoke(MethodHandles.lookup(), fm, arg);
        Assertions.assertEquals(f(arg), x);
    }

    /**
     * Interprets {@code fm} with {@code arg} and asserts that it throws exactly an
     * {@link IndexOutOfBoundsException}, as {@code f} does for negative arguments.
     */
    private static void assertThrowsLikeF(CoreOp.FuncOp fm, int arg) {
        // assertThrows, not try { ...; fail(); } catch (Throwable): that catch would also
        // swallow the AssertionFailedError raised by fail() and misreport the failure.
        Throwable t = Assertions.assertThrows(Throwable.class,
                () -> Interpreter.invoke(MethodHandles.lookup(), fm, arg));
        Assertions.assertEquals(IndexOutOfBoundsException.class, t.getClass());
    }


    /** Reference to {@link #add(int, int)}; invocations of it are replaced in testReplaceCall. */
    static final MethodRef ADD_METHOD = MethodRef.method(
            TestLocalTransformationsAdaption.class, "add",
            int.class, int.class, int.class);

    static int add(int a, int b) {
        return a + b;
    }

    /** Reference to {@link #addWithPrint(int, int)}, the replacement target for {@link #ADD_METHOD}. */
    static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
            TestLocalTransformationsAdaption.class, "addWithPrint",
            int.class, int.class, int.class);

    /** Same as {@link #add(int, int)}, but prints its operands first. */
    static int addWithPrint(int a, int b) {
        System.out.printf("Adding %d + %d%n", a, b);
        return a + b;
    }

    /**
     * Returns the code model of the first declared method of this class named {@code name}.
     *
     * @throws IllegalArgumentException if no such method is declared
     * @throws IllegalStateException    if the method has no code model
     */
    static CoreOp.FuncOp getFuncOp(String name) {
        Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
                .filter(m -> m.getName().equals(name))
                .findFirst();

        Method m = om.orElseThrow(() -> new IllegalArgumentException("No method named " + name));
        return Op.ofMethod(m).orElseThrow(
                () -> new IllegalStateException("No code model for method " + name));
    }
}