1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.Reflect;
26 import jdk.incubator.code.dialect.core.CoreOp;
27 import jdk.incubator.code.dialect.java.FieldRef;
28 import jdk.incubator.code.dialect.java.JavaOp;
29 import jdk.incubator.code.dialect.java.MethodRef;
30 import jdk.incubator.code.interpreter.Interpreter;
31 import org.junit.jupiter.api.Assertions;
32 import org.junit.jupiter.api.Test;
33
34 import java.io.PrintStream;
35 import java.lang.invoke.MethodHandles;
36 import java.lang.reflect.Method;
37 import java.util.List;
38 import java.util.Optional;
39 import java.util.concurrent.atomic.AtomicBoolean;
40 import java.util.function.IntBinaryOperator;
41 import java.util.function.IntUnaryOperator;
42 import java.util.stream.Collectors;
43 import java.util.stream.Stream;
44
45 import static jdk.incubator.code.dialect.core.CoreOp.constant;
46 import static jdk.incubator.code.dialect.java.JavaOp.*;
47 import static jdk.incubator.code.dialect.java.JavaType.*;
48 import static jdk.incubator.code.dialect.java.MethodRef.method;
49
50 /*
51 * @test
52 * @modules jdk.incubator.code
53 * @enablePreview
54 * @run junit TestLocalTransformationsAdaption
55 */
56
57 public class TestLocalTransformationsAdaption {
58
/**
 * Reflected method whose code model the tests in this class obtain through
 * {@code getFuncOp("f")}, transform, and then interpret.
 * <p>
 * For a non-negative {@code i} the result is {@code (i + 42) + (i + 42)}, computed
 * through a chain of lambdas. For a negative {@code i} the lambda named {@code f}
 * throws a {@code RuntimeException}, which the enclosing {@code catch} replaces
 * with an {@code IndexOutOfBoundsException} constructed from {@code i}.
 */
59 @Reflect
60 static int f(int i) {
// Lambda that forwards to the static add(int, int) method of this class
61 IntBinaryOperator add = (a, b) -> {
62 return add(a, b);
63 };
64
65 try {
// Captures the 'add' lambda declared outside the try block
66 IntUnaryOperator add42 = (a) -> {
67 return add.applyAsInt(a, 42);
68 };
69
// j == i + 42
70 int j = add42.applyAsInt(i);
71
72 IntBinaryOperator f = (a, b) -> {
// Captures the method parameter 'i'; a negative value takes the exceptional path
73 if (i < 0) {
74 throw new RuntimeException();
75 }
76
// Lambda nested inside a lambda; captures 'a' from the enclosing lambda
77 IntUnaryOperator g = (c) -> {
78 return add(a, c);
79 };
80
81 return g.applyAsInt(b);
82 };
83
// Evaluates to j + j when i is non-negative
84 return f.applyAsInt(j, j);
85 } catch (RuntimeException e) {
// The caught exception is discarded; only 'i' is passed to the new exception
86 throw new IndexOutOfBoundsException(i);
87 }
88 }
89
90 @Test
91 public void testInvocation() {
92 CoreOp.FuncOp f = getFuncOp("f");
93 System.out.println(f.toText());
94
95 f = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
96 System.out.println(f.toText());
97
98 int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
99 Assertions.assertEquals(f(2), x);
100
101 try {
102 Interpreter.invoke(MethodHandles.lookup(), f, -10);
103 Assertions.fail();
104 } catch (Throwable e) {
105 Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
106 }
107 }
108
109 @Test
110 public void testFuncEntryExit() {
111 CoreOp.FuncOp f = getFuncOp("f");
112 System.out.println(f.toText());
113
114 AtomicBoolean first = new AtomicBoolean(true);
115 CoreOp.FuncOp fc = f.transform((block, op) -> {
116 if (first.get()) {
117 printConstantString(block, "ENTRY");
118 first.set(false);
119 }
120
121 switch (op) {
122 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: {
123 printConstantString(block, "EXIT");
124 break;
125 }
126 case JavaOp.ThrowOp throwOp: {
127 printConstantString(block, "EXIT");
128 break;
129 }
130 default:
131 }
132
133 block.op(op);
134
135 return block;
136 });
137 System.out.println(fc.toText());
138
139 fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
140 System.out.println(fc.toText());
141
142 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
143 Assertions.assertEquals(f(2), x);
144
145 try {
146 Interpreter.invoke(MethodHandles.lookup(), fc, -10);
147 Assertions.fail();
148 } catch (Throwable e) {
149 Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
150 }
151 }
152
153 static void printConstantString(Block.Builder opBuilder, String s) {
154 Op.Result c = opBuilder.op(constant(J_L_STRING, s));
155 Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
156 opBuilder.op(JavaOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
157 }
158
159 static Op getNearestInvokeableAncestorOp(Op op) {
160 do {
161 op = op.ancestorOp();
162 } while (!(op instanceof Op.Invokable));
163 return op;
164 }
165
166
167 @Test
168 public void testReplaceCall() {
169 CoreOp.FuncOp f = getFuncOp("f");
170 System.out.println(f.toText());
171
172 CoreOp.FuncOp fc = f.transform((block, op) -> {
173 switch (op) {
174 case JavaOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): {
175 // Get the adapted operands, and pass those to the new call method
176 List<Value> adaptedOperands = block.context().getValues(op.operands());
177 Op.Result adaptedResult = block.op(JavaOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
178 // Map the old call result to the new call result, so existing operations can be
179 // adapted to use the new result
180 block.context().mapValue(invokeOp.result(), adaptedResult);
181 break;
182 }
183 default: {
184 block.op(op);
185 }
186 }
187 return block;
188 });
189 System.out.println(fc.toText());
190
191 fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
192 System.out.println(fc.toText());
193
194 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
195 Assertions.assertEquals(f(2), x);
196 }
197
198
199 @Test
200 public void testCallEntryExit() {
201 CoreOp.FuncOp f = getFuncOp("f");
202 System.out.println(f.toText());
203
204 CoreOp.FuncOp fc = f.transform((block, op) -> {
205 switch (op) {
206 case JavaOp.InvokeOp invokeOp: {
207 printCall(block.context(), invokeOp, block);
208 break;
209 }
210 default: {
211 block.op(op);
212 }
213 }
214 return block;
215 });
216 System.out.println(fc.toText());
217
218 fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
219 System.out.println(fc.toText());
220
221 int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
222 Assertions.assertEquals(f(2), x);
223 }
224
225 static void printCall(CodeContext cc, JavaOp.InvokeOp invokeOp, Block.Builder opBuilder) {
226 List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());
227
228 String prefix = "ENTER";
229
230 Value arrayLength = opBuilder.op(
231 constant(INT, adaptedInvokeOperands.size()));
232 Value formatArray = opBuilder.op(
233 newArray(type(Object[].class), arrayLength));
234
235 Value indexZero = null;
236 for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
237 Value operand = adaptedInvokeOperands.get(i);
238
239 Value index = opBuilder.op(
240 constant(INT, i));
241 if (i == 0) {
242 indexZero = index;
243 }
244
245 if (operand.type().equals(INT)) {
246 operand = opBuilder.op(
247 JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand));
248 // @@@ Other primitive types
249 }
250 opBuilder.op(
251 arrayStoreOp(formatArray, index, operand));
252 }
253
254 Op.Result formatString = opBuilder.op(
255 constant(J_L_STRING,
256 prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n"));
257 Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
258 opBuilder.op(
259 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
260 System_out, formatString, formatArray));
261
262 // Method call
263
264 Op.Result adaptedInvokeResult = opBuilder.op(invokeOp);
265
266 // After method call
267
268 prefix = "EXIT";
269
270 if (adaptedInvokeResult.type().equals(INT)) {
271 adaptedInvokeResult = opBuilder.op(
272 JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult));
273 // @@@ Other primitive types
274 }
275 opBuilder.op(
276 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult));
277
278 formatString = opBuilder.op(
279 constant(J_L_STRING,
280 prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n"));
281 opBuilder.op(
282 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
283 System_out, formatString, formatArray));
284 }
285
286 static String formatString(List<Value> vs) {
287 return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
288 }
289
290 static String formatString(TypeElement t) {
291 if (t.equals(INT)) {
292 return "%d";
293 } else {
294 return "%s";
295 }
296 }
297
298
299 static final MethodRef ADD_METHOD = MethodRef.method(
300 TestLocalTransformationsAdaption.class, "add",
301 int.class, int.class, int.class);
302
303 static int add(int a, int b) {
304 return a + b;
305 }
306
307 static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
308 TestLocalTransformationsAdaption.class, "addWithPrint",
309 int.class, int.class, int.class);
310
311 static int addWithPrint(int a, int b) {
312 System.out.printf("Adding %d + %d%n", a, b);
313 return a + b;
314 }
315
316 static CoreOp.FuncOp getFuncOp(String name) {
317 Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
318 .filter(m -> m.getName().equals(name))
319 .findFirst();
320
321 Method m = om.get();
322 return Op.ofMethod(m).get();
323 }
324 }