1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.Reflect;
 26 import jdk.incubator.code.dialect.core.CoreOp;
 27 import jdk.incubator.code.dialect.java.FieldRef;
 28 import jdk.incubator.code.dialect.java.JavaOp;
 29 import jdk.incubator.code.dialect.java.MethodRef;
 30 import jdk.incubator.code.interpreter.Interpreter;
 31 import org.junit.jupiter.api.Assertions;
 32 import org.junit.jupiter.api.Test;
 33 
 34 import java.io.PrintStream;
 35 import java.lang.invoke.MethodHandles;
 36 import java.lang.reflect.Method;
 37 import java.util.List;
 38 import java.util.Optional;
 39 import java.util.concurrent.atomic.AtomicBoolean;
 40 import java.util.function.IntBinaryOperator;
 41 import java.util.function.IntUnaryOperator;
 42 import java.util.stream.Collectors;
 43 import java.util.stream.Stream;
 44 
 45 import static jdk.incubator.code.dialect.core.CoreOp.constant;
 46 import static jdk.incubator.code.dialect.java.JavaOp.*;
 47 import static jdk.incubator.code.dialect.java.JavaType.*;
 48 import static jdk.incubator.code.dialect.java.MethodRef.method;
 49 
 50 /*
 51  * @test
 52  * @modules jdk.incubator.code
 53  * @enablePreview
 54  * @run junit TestLocalTransformationsAdaption
 55  * @run main Unreflect TestLocalTransformationsAdaption
 56  * @run junit TestLocalTransformationsAdaption
 57  */
 58 
 59 public class TestLocalTransformationsAdaption {
 60 
 61     @Reflect
 62     static int f(int i) {
 63         IntBinaryOperator add = (a, b) -> {
 64             return add(a, b);
 65         };
 66 
 67         try {
 68             IntUnaryOperator add42 = (a) -> {
 69                 return add.applyAsInt(a, 42);
 70             };
 71 
 72             int j = add42.applyAsInt(i);
 73 
 74             IntBinaryOperator f = (a, b) -> {
 75                 if (i < 0) {
 76                     throw new RuntimeException();
 77                 }
 78 
 79                 IntUnaryOperator g = (c) -> {
 80                     return add(a, c);
 81                 };
 82 
 83                 return g.applyAsInt(b);
 84             };
 85 
 86             return f.applyAsInt(j, j);
 87         } catch (RuntimeException e) {
 88             throw new IndexOutOfBoundsException(i);
 89         }
 90     }
 91 
 92     @Test
 93     public void testInvocation() {
 94         CoreOp.FuncOp f = getFuncOp("f");
 95         System.out.println(f.toText());
 96 
 97         f = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
 98         System.out.println(f.toText());
 99 
100         int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
101         Assertions.assertEquals(f(2), x);
102 
103         try {
104             Interpreter.invoke(MethodHandles.lookup(), f, -10);
105             Assertions.fail();
106         } catch (Throwable e) {
107             Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
108         }
109     }
110 
111     @Test
112     public void testFuncEntryExit() {
113         CoreOp.FuncOp f = getFuncOp("f");
114         System.out.println(f.toText());
115 
116         AtomicBoolean first = new AtomicBoolean(true);
117         CoreOp.FuncOp fc = f.transform((block, op) -> {
118             if (first.get()) {
119                 printConstantString(block, "ENTRY");
120                 first.set(false);
121             }
122 
123             switch (op) {
124                 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: {
125                     printConstantString(block, "EXIT");
126                     break;
127                 }
128                 case JavaOp.ThrowOp throwOp: {
129                     printConstantString(block, "EXIT");
130                     break;
131                 }
132                 default:
133             }
134 
135             block.op(op);
136 
137             return block;
138         });
139         System.out.println(fc.toText());
140 
141         fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
142         System.out.println(fc.toText());
143 
144         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
145         Assertions.assertEquals(f(2), x);
146 
147         try {
148             Interpreter.invoke(MethodHandles.lookup(), fc, -10);
149             Assertions.fail();
150         } catch (Throwable e) {
151             Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
152         }
153     }
154 
155     static void printConstantString(Block.Builder opBuilder, String s) {
156         Op.Result c = opBuilder.op(constant(J_L_STRING, s));
157         Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
158         opBuilder.op(JavaOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
159     }
160 
161     static Op getNearestInvokeableAncestorOp(Op op) {
162         do {
163             op = op.ancestorOp();
164         } while (!(op instanceof Op.Invokable));
165         return op;
166     }
167 
168 
169     @Test
170     public void testReplaceCall() {
171         CoreOp.FuncOp f = getFuncOp("f");
172         System.out.println(f.toText());
173 
174         CoreOp.FuncOp fc = f.transform((block, op) -> {
175             switch (op) {
176                 case JavaOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): {
177                     // Get the adapted operands, and pass those to the new call method
178                     List<Value> adaptedOperands = block.context().getValues(op.operands());
179                     Op.Result adaptedResult = block.op(JavaOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
180                     // Map the old call result to the new call result, so existing operations can be
181                     // adapted to use the new result
182                     block.context().mapValue(invokeOp.result(), adaptedResult);
183                     break;
184                 }
185                 default: {
186                     block.op(op);
187                 }
188             }
189             return block;
190         });
191         System.out.println(fc.toText());
192 
193         fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
194         System.out.println(fc.toText());
195 
196         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
197         Assertions.assertEquals(f(2), x);
198     }
199 
200 
201     @Test
202     public void testCallEntryExit() {
203         CoreOp.FuncOp f = getFuncOp("f");
204         System.out.println(f.toText());
205 
206         CoreOp.FuncOp fc = f.transform((block, op) -> {
207             switch (op) {
208                 case JavaOp.InvokeOp invokeOp: {
209                     printCall(block.context(), invokeOp, block);
210                     break;
211                 }
212                 default: {
213                     block.op(op);
214                 }
215             }
216             return block;
217         });
218         System.out.println(fc.toText());
219 
220         fc = fc.transform(CodeTransformer.LOWERING_TRANSFORMER);
221         System.out.println(fc.toText());
222 
223         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
224         Assertions.assertEquals(f(2), x);
225     }
226 
227     static void printCall(CodeContext cc, JavaOp.InvokeOp invokeOp, Block.Builder opBuilder) {
228         List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());
229 
230         String prefix = "ENTER";
231 
232         Value arrayLength = opBuilder.op(
233                 constant(INT, adaptedInvokeOperands.size()));
234         Value formatArray = opBuilder.op(
235                 newArray(type(Object[].class), arrayLength));
236 
237         Value indexZero = null;
238         for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
239             Value operand = adaptedInvokeOperands.get(i);
240 
241             Value index = opBuilder.op(
242                     constant(INT, i));
243             if (i == 0) {
244                 indexZero = index;
245             }
246 
247             if (operand.type().equals(INT)) {
248                 operand = opBuilder.op(
249                         JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand));
250                 // @@@ Other primitive types
251             }
252             opBuilder.op(
253                     arrayStoreOp(formatArray, index, operand));
254         }
255 
256         Op.Result formatString = opBuilder.op(
257                 constant(J_L_STRING,
258                         prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n"));
259         Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
260         opBuilder.op(
261                 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
262                         System_out, formatString, formatArray));
263 
264         // Method call
265 
266         Op.Result adaptedInvokeResult = opBuilder.op(invokeOp);
267 
268         // After method call
269 
270         prefix = "EXIT";
271 
272         if (adaptedInvokeResult.type().equals(INT)) {
273             adaptedInvokeResult = opBuilder.op(
274                     JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult));
275             // @@@ Other primitive types
276         }
277         opBuilder.op(
278                 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult));
279 
280         formatString = opBuilder.op(
281                 constant(J_L_STRING,
282                         prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n"));
283         opBuilder.op(
284                 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
285                         System_out, formatString, formatArray));
286     }
287 
288     static String formatString(List<Value> vs) {
289         return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
290     }
291 
292     static String formatString(TypeElement t) {
293         if (t.equals(INT)) {
294             return "%d";
295         } else {
296             return "%s";
297         }
298     }
299 
300 
301     static final MethodRef ADD_METHOD = MethodRef.method(
302             TestLocalTransformationsAdaption.class, "add",
303             int.class, int.class, int.class);
304 
305     static int add(int a, int b) {
306         return a + b;
307     }
308 
309     static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
310             TestLocalTransformationsAdaption.class, "addWithPrint",
311             int.class, int.class, int.class);
312 
313     static int addWithPrint(int a, int b) {
314         System.out.printf("Adding %d + %d%n", a, b);
315         return a + b;
316     }
317 
318     static CoreOp.FuncOp getFuncOp(String name) {
319         Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
320                 .filter(m -> m.getName().equals(name))
321                 .findFirst();
322 
323         Method m = om.get();
324         return Op.ofMethod(m).get();
325     }
326 }