1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.Reflect;
 26 import jdk.incubator.code.dialect.core.CoreOp;
 27 import jdk.incubator.code.dialect.java.FieldRef;
 28 import jdk.incubator.code.dialect.java.JavaOp;
 29 import jdk.incubator.code.dialect.java.MethodRef;
 30 import jdk.incubator.code.interpreter.Interpreter;
 31 import org.junit.jupiter.api.Assertions;
 32 import org.junit.jupiter.api.Test;
 33 
 34 import java.io.PrintStream;
 35 import java.lang.invoke.MethodHandles;
 36 import java.lang.reflect.Method;
 37 import java.util.List;
 38 import java.util.Optional;
 39 import java.util.concurrent.atomic.AtomicBoolean;
 40 import java.util.function.IntBinaryOperator;
 41 import java.util.function.IntUnaryOperator;
 42 import java.util.stream.Collectors;
 43 import java.util.stream.Stream;
 44 
 45 import static jdk.incubator.code.dialect.core.CoreOp.constant;
 46 import static jdk.incubator.code.dialect.java.JavaOp.*;
 47 import static jdk.incubator.code.dialect.java.JavaType.*;
 48 import static jdk.incubator.code.dialect.java.MethodRef.method;
 49 
 50 /*
 51  * @test
 52  * @modules jdk.incubator.code
 53  * @enablePreview
 54  * @run junit TestLocalTransformationsAdaption
 55  */
 56 
 57 public class TestLocalTransformationsAdaption {
 58 
 59     @Reflect
 60     static int f(int i) {
 61         IntBinaryOperator add = (a, b) -> {
 62             return add(a, b);
 63         };
 64 
 65         try {
 66             IntUnaryOperator add42 = (a) -> {
 67                 return add.applyAsInt(a, 42);
 68             };
 69 
 70             int j = add42.applyAsInt(i);
 71 
 72             IntBinaryOperator f = (a, b) -> {
 73                 if (i < 0) {
 74                     throw new RuntimeException();
 75                 }
 76 
 77                 IntUnaryOperator g = (c) -> {
 78                     return add(a, c);
 79                 };
 80 
 81                 return g.applyAsInt(b);
 82             };
 83 
 84             return f.applyAsInt(j, j);
 85         } catch (RuntimeException e) {
 86             throw new IndexOutOfBoundsException(i);
 87         }
 88     }
 89 
 90     @Test
 91     public void testInvocation() {
 92         CoreOp.FuncOp f = getFuncOp("f");
 93         System.out.println(f.toText());
 94 
 95         f = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
 96         System.out.println(f.toText());
 97 
 98         int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
 99         Assertions.assertEquals(f(2), x);
100 
101         try {
102             Interpreter.invoke(MethodHandles.lookup(), f, -10);
103             Assertions.fail();
104         } catch (Throwable e) {
105             Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
106         }
107     }
108 
109     @Test
110     public void testFuncEntryExit() {
111         CoreOp.FuncOp f = getFuncOp("f");
112         System.out.println(f.toText());
113 
114         AtomicBoolean first = new AtomicBoolean(true);
115         CoreOp.FuncOp fc = f.transform((block, op) -> {
116             if (first.get()) {
117                 printConstantString(block, "ENTRY");
118                 first.set(false);
119             }
120 
121             switch (op) {
122                 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: {
123                     printConstantString(block, "EXIT");
124                     break;
125                 }
126                 case JavaOp.ThrowOp throwOp: {
127                     printConstantString(block, "EXIT");
128                     break;
129                 }
130                 default:
131             }
132 
133             block.op(op);
134 
135             return block;
136         });
137         System.out.println(fc.toText());
138 
139         fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
140         System.out.println(fc.toText());
141 
142         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
143         Assertions.assertEquals(f(2), x);
144 
145         try {
146             Interpreter.invoke(MethodHandles.lookup(), fc, -10);
147             Assertions.fail();
148         } catch (Throwable e) {
149             Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
150         }
151     }
152 
153     static void printConstantString(Block.Builder opBuilder, String s) {
154         Op.Result c = opBuilder.op(constant(J_L_STRING, s));
155         Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
156         opBuilder.op(JavaOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
157     }
158 
159     static Op getNearestInvokeableAncestorOp(Op op) {
160         do {
161             op = op.ancestorOp();
162         } while (!(op instanceof Op.Invokable));
163         return op;
164     }
165 
166 
167     @Test
168     public void testReplaceCall() {
169         CoreOp.FuncOp f = getFuncOp("f");
170         System.out.println(f.toText());
171 
172         CoreOp.FuncOp fc = f.transform((block, op) -> {
173             switch (op) {
174                 case JavaOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): {
175                     // Get the adapted operands, and pass those to the new call method
176                     List<Value> adaptedOperands = block.context().getValues(op.operands());
177                     Op.Result adaptedResult = block.op(JavaOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
178                     // Map the old call result to the new call result, so existing operations can be
179                     // adapted to use the new result
180                     block.context().mapValue(invokeOp.result(), adaptedResult);
181                     break;
182                 }
183                 default: {
184                     block.op(op);
185                 }
186             }
187             return block;
188         });
189         System.out.println(fc.toText());
190 
191         fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
192         System.out.println(fc.toText());
193 
194         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
195         Assertions.assertEquals(f(2), x);
196     }
197 
198 
199     @Test
200     public void testCallEntryExit() {
201         CoreOp.FuncOp f = getFuncOp("f");
202         System.out.println(f.toText());
203 
204         CoreOp.FuncOp fc = f.transform((block, op) -> {
205             switch (op) {
206                 case JavaOp.InvokeOp invokeOp: {
207                     printCall(block.context(), invokeOp, block);
208                     break;
209                 }
210                 default: {
211                     block.op(op);
212                 }
213             }
214             return block;
215         });
216         System.out.println(fc.toText());
217 
218         fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
219         System.out.println(fc.toText());
220 
221         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
222         Assertions.assertEquals(f(2), x);
223     }
224 
225     static void printCall(CodeContext cc, JavaOp.InvokeOp invokeOp, Block.Builder opBuilder) {
226         List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());
227 
228         String prefix = "ENTER";
229 
230         Value arrayLength = opBuilder.op(
231                 constant(INT, adaptedInvokeOperands.size()));
232         Value formatArray = opBuilder.op(
233                 newArray(type(Object[].class), arrayLength));
234 
235         Value indexZero = null;
236         for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
237             Value operand = adaptedInvokeOperands.get(i);
238 
239             Value index = opBuilder.op(
240                     constant(INT, i));
241             if (i == 0) {
242                 indexZero = index;
243             }
244 
245             if (operand.type().equals(INT)) {
246                 operand = opBuilder.op(
247                         JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand));
248                 // @@@ Other primitive types
249             }
250             opBuilder.op(
251                     arrayStoreOp(formatArray, index, operand));
252         }
253 
254         Op.Result formatString = opBuilder.op(
255                 constant(J_L_STRING,
256                         prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n"));
257         Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
258         opBuilder.op(
259                 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
260                         System_out, formatString, formatArray));
261 
262         // Method call
263 
264         Op.Result adaptedInvokeResult = opBuilder.op(invokeOp);
265 
266         // After method call
267 
268         prefix = "EXIT";
269 
270         if (adaptedInvokeResult.type().equals(INT)) {
271             adaptedInvokeResult = opBuilder.op(
272                     JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult));
273             // @@@ Other primitive types
274         }
275         opBuilder.op(
276                 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult));
277 
278         formatString = opBuilder.op(
279                 constant(J_L_STRING,
280                         prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n"));
281         opBuilder.op(
282                 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
283                         System_out, formatString, formatArray));
284     }
285 
286     static String formatString(List<Value> vs) {
287         return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
288     }
289 
290     static String formatString(TypeElement t) {
291         if (t.equals(INT)) {
292             return "%d";
293         } else {
294             return "%s";
295         }
296     }
297 
298 
299     static final MethodRef ADD_METHOD = MethodRef.method(
300             TestLocalTransformationsAdaption.class, "add",
301             int.class, int.class, int.class);
302 
303     static int add(int a, int b) {
304         return a + b;
305     }
306 
307     static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
308             TestLocalTransformationsAdaption.class, "addWithPrint",
309             int.class, int.class, int.class);
310 
311     static int addWithPrint(int a, int b) {
312         System.out.printf("Adding %d + %d%n", a, b);
313         return a + b;
314     }
315 
316     static CoreOp.FuncOp getFuncOp(String name) {
317         Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
318                 .filter(m -> m.getName().equals(name))
319                 .findFirst();
320 
321         Method m = om.get();
322         return Op.ofMethod(m).get();
323     }
324 }