1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.dialect.core.CoreOp;
 26 import jdk.incubator.code.dialect.java.FieldRef;
 27 import jdk.incubator.code.dialect.java.JavaOp;
 28 import jdk.incubator.code.dialect.java.MethodRef;
 29 import jdk.incubator.code.interpreter.Interpreter;
 30 import org.junit.jupiter.api.Assertions;
 31 import org.junit.jupiter.api.Test;
 32 
 33 import java.io.PrintStream;
 34 import java.lang.invoke.MethodHandles;
 35 import java.lang.reflect.Method;
 36 import java.util.List;
 37 import java.util.Optional;
 38 import java.util.concurrent.atomic.AtomicBoolean;
 39 import java.util.function.IntBinaryOperator;
 40 import java.util.function.IntUnaryOperator;
 41 import java.util.stream.Collectors;
 42 import java.util.stream.Stream;
 43 
 44 import static jdk.incubator.code.dialect.core.CoreOp.constant;
 45 import static jdk.incubator.code.dialect.java.JavaOp.*;
 46 import static jdk.incubator.code.dialect.java.JavaType.*;
 47 import static jdk.incubator.code.dialect.java.MethodRef.method;
 48 
 49 /*
 50  * @test
 51  * @modules jdk.incubator.code
 52  * @enablePreview
 53  * @run junit TestLocalTransformationsAdaption
 54  */
 55 
 56 public class TestLocalTransformationsAdaption {
 57 
 58     @CodeReflection
 59     static int f(int i) {
 60         IntBinaryOperator add = (a, b) -> {
 61             return add(a, b);
 62         };
 63 
 64         try {
 65             IntUnaryOperator add42 = (a) -> {
 66                 return add.applyAsInt(a, 42);
 67             };
 68 
 69             int j = add42.applyAsInt(i);
 70 
 71             IntBinaryOperator f = (a, b) -> {
 72                 if (i < 0) {
 73                     throw new RuntimeException();
 74                 }
 75 
 76                 IntUnaryOperator g = (c) -> {
 77                     return add(a, c);
 78                 };
 79 
 80                 return g.applyAsInt(b);
 81             };
 82 
 83             return f.applyAsInt(j, j);
 84         } catch (RuntimeException e) {
 85             throw new IndexOutOfBoundsException(i);
 86         }
 87     }
 88 
 89     @Test
 90     public void testInvocation() {
 91         CoreOp.FuncOp f = getFuncOp("f");
 92         System.out.println(f.toText());
 93 
 94         f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
 95         System.out.println(f.toText());
 96 
 97         int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
 98         Assertions.assertEquals(f(2), x);
 99 
100         try {
101             Interpreter.invoke(MethodHandles.lookup(), f, -10);
102             Assertions.fail();
103         } catch (Throwable e) {
104             Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
105         }
106     }
107 
108     @Test
109     public void testFuncEntryExit() {
110         CoreOp.FuncOp f = getFuncOp("f");
111         System.out.println(f.toText());
112 
113         AtomicBoolean first = new AtomicBoolean(true);
114         CoreOp.FuncOp fc = f.transform((block, op) -> {
115             if (first.get()) {
116                 printConstantString(block, "ENTRY");
117                 first.set(false);
118             }
119 
120             switch (op) {
121                 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: {
122                     printConstantString(block, "EXIT");
123                     break;
124                 }
125                 case JavaOp.ThrowOp throwOp: {
126                     printConstantString(block, "EXIT");
127                     break;
128                 }
129                 default:
130             }
131 
132             block.op(op);
133 
134             return block;
135         });
136         System.out.println(fc.toText());
137 
138         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
139         System.out.println(fc.toText());
140 
141         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
142         Assertions.assertEquals(f(2), x);
143 
144         try {
145             Interpreter.invoke(MethodHandles.lookup(), fc, -10);
146             Assertions.fail();
147         } catch (Throwable e) {
148             Assertions.assertEquals(e.getClass(), IndexOutOfBoundsException.class);
149         }
150     }
151 
152     static void printConstantString(Block.Builder opBuilder, String s) {
153         Op.Result c = opBuilder.op(constant(J_L_STRING, s));
154         Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
155         opBuilder.op(JavaOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
156     }
157 
158     static Op getNearestInvokeableAncestorOp(Op op) {
159         do {
160             op = op.ancestorOp();
161         } while (!(op instanceof Op.Invokable));
162         return op;
163     }
164 
165 
166     @Test
167     public void testReplaceCall() {
168         CoreOp.FuncOp f = getFuncOp("f");
169         System.out.println(f.toText());
170 
171         CoreOp.FuncOp fc = f.transform((block, op) -> {
172             switch (op) {
173                 case JavaOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): {
174                     // Get the adapted operands, and pass those to the new call method
175                     List<Value> adaptedOperands = block.context().getValues(op.operands());
176                     Op.Result adaptedResult = block.op(JavaOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
177                     // Map the old call result to the new call result, so existing operations can be
178                     // adapted to use the new result
179                     block.context().mapValue(invokeOp.result(), adaptedResult);
180                     break;
181                 }
182                 default: {
183                     block.op(op);
184                 }
185             }
186             return block;
187         });
188         System.out.println(fc.toText());
189 
190         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
191         System.out.println(fc.toText());
192 
193         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
194         Assertions.assertEquals(f(2), x);
195     }
196 
197 
198     @Test
199     public void testCallEntryExit() {
200         CoreOp.FuncOp f = getFuncOp("f");
201         System.out.println(f.toText());
202 
203         CoreOp.FuncOp fc = f.transform((block, op) -> {
204             switch (op) {
205                 case JavaOp.InvokeOp invokeOp: {
206                     printCall(block.context(), invokeOp, block);
207                     break;
208                 }
209                 default: {
210                     block.op(op);
211                 }
212             }
213             return block;
214         });
215         System.out.println(fc.toText());
216 
217         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
218         System.out.println(fc.toText());
219 
220         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
221         Assertions.assertEquals(f(2), x);
222     }
223 
224     static void printCall(CopyContext cc, JavaOp.InvokeOp invokeOp, Block.Builder opBuilder) {
225         List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());
226 
227         String prefix = "ENTER";
228 
229         Value arrayLength = opBuilder.op(
230                 constant(INT, adaptedInvokeOperands.size()));
231         Value formatArray = opBuilder.op(
232                 newArray(type(Object[].class), arrayLength));
233 
234         Value indexZero = null;
235         for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
236             Value operand = adaptedInvokeOperands.get(i);
237 
238             Value index = opBuilder.op(
239                     constant(INT, i));
240             if (i == 0) {
241                 indexZero = index;
242             }
243 
244             if (operand.type().equals(INT)) {
245                 operand = opBuilder.op(
246                         JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand));
247                 // @@@ Other primitive types
248             }
249             opBuilder.op(
250                     arrayStoreOp(formatArray, index, operand));
251         }
252 
253         Op.Result formatString = opBuilder.op(
254                 constant(J_L_STRING,
255                         prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n"));
256         Value System_out = opBuilder.op(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
257         opBuilder.op(
258                 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
259                         System_out, formatString, formatArray));
260 
261         // Method call
262 
263         Op.Result adaptedInvokeResult = opBuilder.op(invokeOp);
264 
265         // After method call
266 
267         prefix = "EXIT";
268 
269         if (adaptedInvokeResult.type().equals(INT)) {
270             adaptedInvokeResult = opBuilder.op(
271                     JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult));
272             // @@@ Other primitive types
273         }
274         opBuilder.op(
275                 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult));
276 
277         formatString = opBuilder.op(
278                 constant(J_L_STRING,
279                         prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n"));
280         opBuilder.op(
281                 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
282                         System_out, formatString, formatArray));
283     }
284 
285     static String formatString(List<Value> vs) {
286         return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
287     }
288 
289     static String formatString(TypeElement t) {
290         if (t.equals(INT)) {
291             return "%d";
292         } else {
293             return "%s";
294         }
295     }
296 
297 
298     static final MethodRef ADD_METHOD = MethodRef.method(
299             TestLocalTransformationsAdaption.class, "add",
300             int.class, int.class, int.class);
301 
302     static int add(int a, int b) {
303         return a + b;
304     }
305 
306     static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
307             TestLocalTransformationsAdaption.class, "addWithPrint",
308             int.class, int.class, int.class);
309 
310     static int addWithPrint(int a, int b) {
311         System.out.printf("Adding %d + %d%n", a, b);
312         return a + b;
313     }
314 
315     static CoreOp.FuncOp getFuncOp(String name) {
316         Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
317                 .filter(m -> m.getName().equals(name))
318                 .findFirst();
319 
320         Method m = om.get();
321         return Op.ofMethod(m).get();
322     }
323 }