1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.dialect.java.JavaOp;
 25 import org.testng.Assert;
 26 import org.testng.annotations.Test;
 27 
 28 import java.io.PrintStream;
 29 import jdk.incubator.code.*;
 30 import jdk.incubator.code.dialect.core.CoreOp;
 31 import jdk.incubator.code.dialect.java.FieldRef;
 32 import jdk.incubator.code.dialect.java.MethodRef;
 33 import jdk.incubator.code.interpreter.Interpreter;
 34 import java.lang.invoke.MethodHandles;
 35 import java.lang.reflect.Method;
 36 import jdk.incubator.code.CodeReflection;
 37 import java.util.List;
 38 import java.util.Optional;
 39 import java.util.concurrent.atomic.AtomicBoolean;
 40 import java.util.function.Function;
 41 import java.util.function.IntBinaryOperator;
 42 import java.util.function.IntUnaryOperator;
 43 import java.util.stream.Collectors;
 44 import java.util.stream.Stream;
 45 
 46 import static jdk.incubator.code.dialect.java.JavaOp.arrayStoreOp;
 47 import static jdk.incubator.code.dialect.core.CoreOp.constant;
 48 import static jdk.incubator.code.dialect.java.JavaOp.fieldLoad;
 49 import static jdk.incubator.code.dialect.java.JavaOp.newArray;
 50 import static jdk.incubator.code.dialect.java.MethodRef.method;
 51 import static jdk.incubator.code.dialect.java.JavaType.*;
 52 
 53 /*
 54  * @test
 55  * @modules jdk.incubator.code
 56  * @enablePreview
 57  * @run testng TestLocalTransformationsAdaption
 58  */
 59 
 60 public class TestLocalTransformationsAdaption {
 61 
 62     @CodeReflection
 63     static int f(int i) {
 64         IntBinaryOperator add = (a, b) -> {
 65             return add(a, b);
 66         };
 67 
 68         try {
 69             IntUnaryOperator add42 = (a) -> {
 70                 return add.applyAsInt(a, 42);
 71             };
 72 
 73             int j = add42.applyAsInt(i);
 74 
 75             IntBinaryOperator f = (a, b) -> {
 76                 if (i < 0) {
 77                     throw new RuntimeException();
 78                 }
 79 
 80                 IntUnaryOperator g = (c) -> {
 81                     return add(a, c);
 82                 };
 83 
 84                 return g.applyAsInt(b);
 85             };
 86 
 87             return f.applyAsInt(j, j);
 88         } catch (RuntimeException e) {
 89             throw new IndexOutOfBoundsException(i);
 90         }
 91     }
 92 
 93     @Test
 94     public void testInvocation() {
 95         CoreOp.FuncOp f = getFuncOp("f");
 96         f.writeTo(System.out);
 97 
 98         f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
 99         f.writeTo(System.out);
100 
101         int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
102         Assert.assertEquals(x, f(2));
103 
104         try {
105             Interpreter.invoke(MethodHandles.lookup(), f, -10);
106             Assert.fail();
107         } catch (Throwable e) {
108             Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass());
109         }
110     }
111 
112     @Test
113     public void testFuncEntryExit() {
114         CoreOp.FuncOp f = getFuncOp("f");
115         f.writeTo(System.out);
116 
117         AtomicBoolean first = new AtomicBoolean(true);
118         CoreOp.FuncOp fc = f.transform((block, op) -> {
119             if (first.get()) {
120                 printConstantString(block, "ENTRY");
121                 first.set(false);
122             }
123 
124             switch (op) {
125                 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: {
126                     printConstantString(block, "EXIT");
127                     break;
128                 }
129                 case JavaOp.ThrowOp throwOp: {
130                     printConstantString(block, "EXIT");
131                     break;
132                 }
133                 default:
134             }
135 
136             block.apply(op);
137 
138             return block;
139         });
140         fc.writeTo(System.out);
141 
142         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
143         fc.writeTo(System.out);
144 
145         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
146         Assert.assertEquals(x, f(2));
147 
148         try {
149             Interpreter.invoke(MethodHandles.lookup(), fc, -10);
150             Assert.fail();
151         } catch (Throwable e) {
152             Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass());
153         }
154     }
155 
156     static void printConstantString(Function<Op, Op.Result> opBuilder, String s) {
157         Op.Result c = opBuilder.apply(constant(J_L_STRING, s));
158         Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
159         opBuilder.apply(JavaOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
160     }
161 
162     static Op getNearestInvokeableAncestorOp(Op op) {
163         do {
164             op = op.ancestorBody().parentOp();
165         } while (!(op instanceof Op.Invokable));
166         return op;
167     }
168 
169 
170     @Test
171     public void testReplaceCall() {
172         CoreOp.FuncOp f = getFuncOp("f");
173         f.writeTo(System.out);
174 
175         CoreOp.FuncOp fc = f.transform((block, op) -> {
176             switch (op) {
177                 case JavaOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): {
178                     // Get the adapted operands, and pass those to the new call method
179                     List<Value> adaptedOperands = block.context().getValues(op.operands());
180                     Op.Result adaptedResult = block.apply(JavaOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
181                     // Map the old call result to the new call result, so existing operations can be
182                     // adapted to use the new result
183                     block.context().mapValue(invokeOp.result(), adaptedResult);
184                     break;
185                 }
186                 default: {
187                     block.apply(op);
188                 }
189             }
190             return block;
191         });
192         fc.writeTo(System.out);
193 
194         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
195         fc.writeTo(System.out);
196 
197         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
198         Assert.assertEquals(x, f(2));
199     }
200 
201 
202     @Test
203     public void testCallEntryExit() {
204         CoreOp.FuncOp f = getFuncOp("f");
205         f.writeTo(System.out);
206 
207         CoreOp.FuncOp fc = f.transform((block, op) -> {
208             switch (op) {
209                 case JavaOp.InvokeOp invokeOp: {
210                     printCall(block.context(), invokeOp, block);
211                     break;
212                 }
213                 default: {
214                     block.apply(op);
215                 }
216             }
217             return block;
218         });
219         fc.writeTo(System.out);
220 
221         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
222         fc.writeTo(System.out);
223 
224         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
225         Assert.assertEquals(x, f(2));
226     }
227 
228     static void printCall(CopyContext cc, JavaOp.InvokeOp invokeOp, Function<Op, Op.Result> opBuilder) {
229         List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());
230 
231         String prefix = "ENTER";
232 
233         Value arrayLength = opBuilder.apply(
234                 constant(INT, adaptedInvokeOperands.size()));
235         Value formatArray = opBuilder.apply(
236                 newArray(type(Object[].class), arrayLength));
237 
238         Value indexZero = null;
239         for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
240             Value operand = adaptedInvokeOperands.get(i);
241 
242             Value index = opBuilder.apply(
243                     constant(INT, i));
244             if (i == 0) {
245                 indexZero = index;
246             }
247 
248             if (operand.type().equals(INT)) {
249                 operand = opBuilder.apply(
250                         JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand));
251                 // @@@ Other primitive types
252             }
253             opBuilder.apply(
254                     arrayStoreOp(formatArray, index, operand));
255         }
256 
257         Op.Result formatString = opBuilder.apply(
258                 constant(J_L_STRING,
259                         prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n"));
260         Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
261         opBuilder.apply(
262                 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
263                         System_out, formatString, formatArray));
264 
265         // Method call
266 
267         Op.Result adaptedInvokeResult = opBuilder.apply(invokeOp);
268 
269         // After method call
270 
271         prefix = "EXIT";
272 
273         if (adaptedInvokeResult.type().equals(INT)) {
274             adaptedInvokeResult = opBuilder.apply(
275                     JavaOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult));
276             // @@@ Other primitive types
277         }
278         opBuilder.apply(
279                 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult));
280 
281         formatString = opBuilder.apply(
282                 constant(J_L_STRING,
283                         prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n"));
284         opBuilder.apply(
285                 JavaOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
286                         System_out, formatString, formatArray));
287     }
288 
289     static String formatString(List<Value> vs) {
290         return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
291     }
292 
293     static String formatString(TypeElement t) {
294         if (t.equals(INT)) {
295             return "%d";
296         } else {
297             return "%s";
298         }
299     }
300 
301 
302     static final MethodRef ADD_METHOD = MethodRef.method(
303             TestLocalTransformationsAdaption.class, "add",
304             int.class, int.class, int.class);
305 
306     static int add(int a, int b) {
307         return a + b;
308     }
309 
310     static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
311             TestLocalTransformationsAdaption.class, "addWithPrint",
312             int.class, int.class, int.class);
313 
314     static int addWithPrint(int a, int b) {
315         System.out.printf("Adding %d + %d%n", a, b);
316         return a + b;
317     }
318 
319     static CoreOp.FuncOp getFuncOp(String name) {
320         Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
321                 .filter(m -> m.getName().equals(name))
322                 .findFirst();
323 
324         Method m = om.get();
325         return Op.ofMethod(m).get();
326     }
327 }