1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.dialect.core.CoreOp;
26 import jdk.incubator.code.dialect.java.FieldRef;
27 import jdk.incubator.code.dialect.java.JavaOp;
28 import jdk.incubator.code.dialect.java.MethodRef;
29 import jdk.incubator.code.interpreter.Interpreter;
30 import org.junit.jupiter.api.Assertions;
31 import org.junit.jupiter.api.Test;
32
33 import java.io.PrintStream;
34 import java.lang.invoke.MethodHandles;
35 import java.lang.reflect.Method;
36 import java.util.List;
37 import java.util.Optional;
38 import java.util.concurrent.atomic.AtomicBoolean;
39 import java.util.function.IntBinaryOperator;
40 import java.util.function.IntUnaryOperator;
41 import java.util.stream.Collectors;
42 import java.util.stream.Stream;
43
44 import static jdk.incubator.code.dialect.core.CoreOp.constant;
45 import static jdk.incubator.code.dialect.java.JavaOp.*;
46 import static jdk.incubator.code.dialect.java.JavaType.*;
47 import static jdk.incubator.code.dialect.java.MethodRef.method;
48
49 /*
50 * @test
51 * @modules jdk.incubator.code
52 * @enablePreview
53 * @run junit TestLocalTransformationsAdaption
54 */
55
56 public class TestLocalTransformationsAdaption {
57
58 @CodeReflection
59 static int f(int i) {
60 IntBinaryOperator add = (a, b) -> {
61 return add(a, b);
62 };
63
64 try {
65 IntUnaryOperator add42 = (a) -> {
66 return add.applyAsInt(a, 42);
67 };
68
69 int j = add42.applyAsInt(i);
70
71 IntBinaryOperator f = (a, b) -> {
72 if (i < 0) {
73 throw new RuntimeException();
74 }
75
76 IntUnaryOperator g = (c) -> {
77 return add(a, c);
78 };
79
80 return g.applyAsInt(b);
81 };
82
83 return f.applyAsInt(j, j);
84 } catch (RuntimeException e) {
85 throw new IndexOutOfBoundsException(i);
86 }
87 }
88
89 @Test
90 public void testInvocation() {
91 CoreOp.FuncOp f = getFuncOp("f");
92 System.out.println(f.toText());
93
94 f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
95 System.out.println(f.toText());
96
97 int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
98 Assertions.assertEquals(f(2), x);
99
100 try {
101 Interpreter.invoke(MethodHandles.lookup(), f, -10);
102 Assertions.fail();
103 } catch (Throwable e) {
104 Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
105 }
106 }
107
108 @Test
109 public void testFuncEntryExit() {
110 CoreOp.FuncOp f = getFuncOp("f");
111 System.out.println(f.toText());
112
113 AtomicBoolean first = new AtomicBoolean(true);
114 CoreOp.FuncOp fc = f.transform((block, op) -> {
115 if (first.get()) {
116 printConstantString(block, "ENTRY");
117 first.set(false);
118 }
119
120 switch (op) {
121 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: {
122 printConstantString(block, "EXIT");
123 break;
124 }
125 case JavaOp.ThrowOp throwOp: {
126 printConstantString(block, "EXIT");
127 break;
128 }
129 default:
130 }
131
132 block.op(op);
133
134 return block;
135 });
136 System.out.println(fc.toText());
137
138 fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
139 System.out.println(fc.toText());
140
141 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
142 Assertions.assertEquals(f(2), x);
143
144 try {
145 Interpreter.invoke(MethodHandles.lookup(), fc, -10);
146 Assertions.fail();
147 } catch (Throwable e) {
148 Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
149 }
150 }
151
152 static void printConstantString(Block.Builder opBuilder, String s) {
153 Op.Result c = opBuilder.op(constant(J_L_STRING, s));
154 Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
155 opBuilder.op(JavaOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
156 }
157
158 static Op getNearestInvokeableAncestorOp(Op op) {
159 do {
160 op = op.ancestorOp();
161 } while (!(op instanceof Op.Invokable));
162 return op;
163 }
164
165
166 @Test
167 public void testReplaceCall() {
168 CoreOp.FuncOp f = getFuncOp("f");
169 System.out.println(f.toText());
170
171 CoreOp.FuncOp fc = f.transform((block, op) -> {
172 switch (op) {
173 case JavaOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): {
174 // Get the adapted operands, and pass those to the new call method
175 List<Value> adaptedOperands = block.context().getValues(op.operands());
176 Op.Result adaptedResult = block.op(JavaOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
177 // Map the old call result to the new call result, so existing operations can be
178 // adapted to use the new result
179 block.context().mapValue(invokeOp.result(), adaptedResult);
180 break;
181 }
182 default: {
183 block.op(op);
184 }
185 }
186 return block;
187 });
188 System.out.println(fc.toText());
189
190 fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
191 System.out.println(fc.toText());
192
193 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
194 Assertions.assertEquals(f(2), x);
195 }
196
197
198 @Test
199 public void testCallEntryExit() {
200 CoreOp.FuncOp f = getFuncOp("f");
201 System.out.println(f.toText());
202
203 CoreOp.FuncOp fc = f.transform((block, op) -> {
204 switch (op) {
205 case JavaOp.InvokeOp invokeOp: {
206 printCall(block.context(), invokeOp, block);
207 break;
208 }
209 default: {
210 block.op(op);
211 }
212 }
213 return block;
214 });
215 System.out.println(fc.toText());
216
217 fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
218 System.out.println(fc.toText());
219
220 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
221 Assertions.assertEquals(f(2), x);
222 }
223
224 static void printCall(CopyContext cc, JavaOp.InvokeOp invokeOp, Block.Builder opBuilder) {
225 List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());
226
227 String prefix = "ENTER";
228
229 Value arrayLength = opBuilder.op(
230 constant(INT, adaptedInvokeOperands.size()));
231 Value formatArray = opBuilder.op(
232 newArray(type(Object[].class), arrayLength));
233
234 Value indexZero = null;
235 for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
236 Value operand = adaptedInvokeOperands.get(i);
237
238 Value index = opBuilder.op(
239 constant(INT, i));
240 if (i == 0) {
241 indexZero = index;
242 }
243
244 if (operand.type().equals(INT)) {
245 operand = opBuilder.op(
246 JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand));
247 // @@@ Other primitive types
248 }
249 opBuilder.op(
250 arrayStoreOp(formatArray, index, operand));
251 }
252
253 Op.Result formatString = opBuilder.op(
254 constant(J_L_STRING,
255 prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n"));
256 Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
257 opBuilder.op(
258 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
259 System_out, formatString, formatArray));
260
261 // Method call
262
263 Op.Result adaptedInvokeResult = opBuilder.op(invokeOp);
264
265 // After method call
266
267 prefix = "EXIT";
268
269 if (adaptedInvokeResult.type().equals(INT)) {
270 adaptedInvokeResult = opBuilder.op(
271 JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult));
272 // @@@ Other primitive types
273 }
274 opBuilder.op(
275 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult));
276
277 formatString = opBuilder.op(
278 constant(J_L_STRING,
279 prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n"));
280 opBuilder.op(
281 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
282 System_out, formatString, formatArray));
283 }
284
285 static String formatString(List<Value> vs) {
286 return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
287 }
288
289 static String formatString(TypeElement t) {
290 if (t.equals(INT)) {
291 return "%d";
292 } else {
293 return "%s";
294 }
295 }
296
297
298 static final MethodRef ADD_METHOD = MethodRef.method(
299 TestLocalTransformationsAdaption.class, "add",
300 int.class, int.class, int.class);
301
302 static int add(int a, int b) {
303 return a + b;
304 }
305
306 static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
307 TestLocalTransformationsAdaption.class, "addWithPrint",
308 int.class, int.class, int.class);
309
310 static int addWithPrint(int a, int b) {
311 System.out.printf("Adding %d + %d%n", a, b);
312 return a + b;
313 }
314
315 static CoreOp.FuncOp getFuncOp(String name) {
316 Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
317 .filter(m -> m.getName().equals(name))
318 .findFirst();
319
320 Method m = om.get();
321 return Op.ofMethod(m).get();
322 }
323 }