1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.Reflect;
26 import jdk.incubator.code.dialect.core.CoreOp;
27 import jdk.incubator.code.dialect.java.FieldRef;
28 import jdk.incubator.code.dialect.java.JavaOp;
29 import jdk.incubator.code.dialect.java.MethodRef;
30 import jdk.incubator.code.interpreter.Interpreter;
31 import org.junit.jupiter.api.Assertions;
32 import org.junit.jupiter.api.Test;
33
34 import java.io.PrintStream;
35 import java.lang.invoke.MethodHandles;
36 import java.lang.reflect.Method;
37 import java.util.List;
38 import java.util.Optional;
39 import java.util.concurrent.atomic.AtomicBoolean;
40 import java.util.function.IntBinaryOperator;
41 import java.util.function.IntUnaryOperator;
42 import java.util.stream.Collectors;
43 import java.util.stream.Stream;
44
45 import static jdk.incubator.code.dialect.core.CoreOp.constant;
46 import static jdk.incubator.code.dialect.java.JavaOp.*;
47 import static jdk.incubator.code.dialect.java.JavaType.*;
48 import static jdk.incubator.code.dialect.java.MethodRef.method;
49
50 /*
51 * @test
52 * @modules jdk.incubator.code
53 * @enablePreview
54 * @run junit TestLocalTransformationsAdaption
55 * @run main Unreflect TestLocalTransformationsAdaption
56 * @run junit TestLocalTransformationsAdaption
57 */
58
59 public class TestLocalTransformationsAdaption {
60
61 @Reflect
62 static int f(int i) {
63 IntBinaryOperator add = (a, b) -> {
64 return add(a, b);
65 };
66
67 try {
68 IntUnaryOperator add42 = (a) -> {
69 return add.applyAsInt(a, 42);
70 };
71
72 int j = add42.applyAsInt(i);
73
74 IntBinaryOperator f = (a, b) -> {
75 if (i < 0) {
76 throw new RuntimeException();
77 }
78
79 IntUnaryOperator g = (c) -> {
80 return add(a, c);
81 };
82
83 return g.applyAsInt(b);
84 };
85
86 return f.applyAsInt(j, j);
87 } catch (RuntimeException e) {
88 throw new IndexOutOfBoundsException(i);
89 }
90 }
91
92 @Test
93 public void testInvocation() {
94 CoreOp.FuncOp f = getFuncOp("f");
95 System.out.println(f.toText());
96
97 f = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
98 System.out.println(f.toText());
99
100 int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
101 Assertions.assertEquals(f(2), x);
102
103 try {
104 Interpreter.invoke(MethodHandles.lookup(), f, -10);
105 Assertions.fail();
106 } catch (Throwable e) {
107 Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
108 }
109 }
110
111 @Test
112 public void testFuncEntryExit() {
113 CoreOp.FuncOp f = getFuncOp("f");
114 System.out.println(f.toText());
115
116 AtomicBoolean first = new AtomicBoolean(true);
117 CoreOp.FuncOp fc = f.transform((block, op) -> {
118 if (first.get()) {
119 printConstantString(block, "ENTRY");
120 first.set(false);
121 }
122
123 switch (op) {
124 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: {
125 printConstantString(block, "EXIT");
126 break;
127 }
128 case JavaOp.ThrowOp throwOp: {
129 printConstantString(block, "EXIT");
130 break;
131 }
132 default:
133 }
134
135 block.op(op);
136
137 return block;
138 });
139 System.out.println(fc.toText());
140
141 fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
142 System.out.println(fc.toText());
143
144 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
145 Assertions.assertEquals(f(2), x);
146
147 try {
148 Interpreter.invoke(MethodHandles.lookup(), fc, -10);
149 Assertions.fail();
150 } catch (Throwable e) {
151 Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
152 }
153 }
154
155 static void printConstantString(Block.Builder opBuilder, String s) {
156 Op.Result c = opBuilder.op(constant(J_L_STRING, s));
157 Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
158 opBuilder.op(JavaOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
159 }
160
161 static Op getNearestInvokeableAncestorOp(Op op) {
162 do {
163 op = op.ancestorOp();
164 } while (!(op instanceof Op.Invokable));
165 return op;
166 }
167
168
169 @Test
170 public void testReplaceCall() {
171 CoreOp.FuncOp f = getFuncOp("f");
172 System.out.println(f.toText());
173
174 CoreOp.FuncOp fc = f.transform((block, op) -> {
175 switch (op) {
176 case JavaOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): {
177 // Get the adapted operands, and pass those to the new call method
178 List<Value> adaptedOperands = block.context().getValues(op.operands());
179 Op.Result adaptedResult = block.op(JavaOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
180 // Map the old call result to the new call result, so existing operations can be
181 // adapted to use the new result
182 block.context().mapValue(invokeOp.result(), adaptedResult);
183 break;
184 }
185 default: {
186 block.op(op);
187 }
188 }
189 return block;
190 });
191 System.out.println(fc.toText());
192
193 fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
194 System.out.println(fc.toText());
195
196 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
197 Assertions.assertEquals(f(2), x);
198 }
199
200
201 @Test
202 public void testCallEntryExit() {
203 CoreOp.FuncOp f = getFuncOp("f");
204 System.out.println(f.toText());
205
206 CoreOp.FuncOp fc = f.transform((block, op) -> {
207 switch (op) {
208 case JavaOp.InvokeOp invokeOp: {
209 printCall(block.context(), invokeOp, block);
210 break;
211 }
212 default: {
213 block.op(op);
214 }
215 }
216 return block;
217 });
218 System.out.println(fc.toText());
219
220 fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
221 System.out.println(fc.toText());
222
223 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
224 Assertions.assertEquals(f(2), x);
225 }
226
227 static void printCall(CodeContext cc, JavaOp.InvokeOp invokeOp, Block.Builder opBuilder) {
228 List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());
229
230 String prefix = "ENTER";
231
232 Value arrayLength = opBuilder.op(
233 constant(INT, adaptedInvokeOperands.size()));
234 Value formatArray = opBuilder.op(
235 newArray(type(Object[].class), arrayLength));
236
237 Value indexZero = null;
238 for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
239 Value operand = adaptedInvokeOperands.get(i);
240
241 Value index = opBuilder.op(
242 constant(INT, i));
243 if (i == 0) {
244 indexZero = index;
245 }
246
247 if (operand.type().equals(INT)) {
248 operand = opBuilder.op(
249 JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand));
250 // @@@ Other primitive types
251 }
252 opBuilder.op(
253 arrayStoreOp(formatArray, index, operand));
254 }
255
256 Op.Result formatString = opBuilder.op(
257 constant(J_L_STRING,
258 prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n"));
259 Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
260 opBuilder.op(
261 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
262 System_out, formatString, formatArray));
263
264 // Method call
265
266 Op.Result adaptedInvokeResult = opBuilder.op(invokeOp);
267
268 // After method call
269
270 prefix = "EXIT";
271
272 if (adaptedInvokeResult.type().equals(INT)) {
273 adaptedInvokeResult = opBuilder.op(
274 JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult));
275 // @@@ Other primitive types
276 }
277 opBuilder.op(
278 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult));
279
280 formatString = opBuilder.op(
281 constant(J_L_STRING,
282 prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n"));
283 opBuilder.op(
284 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
285 System_out, formatString, formatArray));
286 }
287
288 static String formatString(List<Value> vs) {
289 return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
290 }
291
292 static String formatString(TypeElement t) {
293 if (t.equals(INT)) {
294 return "%d";
295 } else {
296 return "%s";
297 }
298 }
299
300
301 static final MethodRef ADD_METHOD = MethodRef.method(
302 TestLocalTransformationsAdaption.class, "add",
303 int.class, int.class, int.class);
304
305 static int add(int a, int b) {
306 return a + b;
307 }
308
309 static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
310 TestLocalTransformationsAdaption.class, "addWithPrint",
311 int.class, int.class, int.class);
312
313 static int addWithPrint(int a, int b) {
314 System.out.printf("Adding %d + %d%n", a, b);
315 return a + b;
316 }
317
318 static CoreOp.FuncOp getFuncOp(String name) {
319 Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
320 .filter(m -> m.getName().equals(name))
321 .findFirst();
322
323 Method m = om.get();
324 return Op.ofMethod(m).get();
325 }
326 }