1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import org.testng.Assert;
 25 import org.testng.annotations.Test;
 26 
 27 import java.io.PrintStream;
 28 import jdk.incubator.code.*;
 29 import jdk.incubator.code.op.CoreOp;
 30 import jdk.incubator.code.type.FieldRef;
 31 import jdk.incubator.code.type.MethodRef;
 32 import jdk.incubator.code.interpreter.Interpreter;
 33 import java.lang.invoke.MethodHandles;
 34 import java.lang.reflect.Method;
 35 import jdk.incubator.code.CodeReflection;
 36 import java.util.List;
 37 import java.util.Optional;
 38 import java.util.concurrent.atomic.AtomicBoolean;
 39 import java.util.function.Function;
 40 import java.util.function.IntBinaryOperator;
 41 import java.util.function.IntUnaryOperator;
 42 import java.util.stream.Collectors;
 43 import java.util.stream.Stream;
 44 
 45 import static jdk.incubator.code.op.CoreOp.arrayStoreOp;
 46 import static jdk.incubator.code.op.CoreOp.constant;
 47 import static jdk.incubator.code.op.CoreOp.fieldLoad;
 48 import static jdk.incubator.code.op.CoreOp.newArray;
 49 import static jdk.incubator.code.type.MethodRef.method;
 50 import static jdk.incubator.code.type.JavaType.*;
 51 
 52 /*
 53  * @test
 54  * @modules jdk.incubator.code
 55  * @enablePreview
 56  * @run testng TestLocalTransformationsAdaption
 57  */
 58 
 59 public class TestLocalTransformationsAdaption {
 60 
 61     @CodeReflection
 62     static int f(int i) {
 63         IntBinaryOperator add = (a, b) -> {
 64             return add(a, b);
 65         };
 66 
 67         try {
 68             IntUnaryOperator add42 = (a) -> {
 69                 return add.applyAsInt(a, 42);
 70             };
 71 
 72             int j = add42.applyAsInt(i);
 73 
 74             IntBinaryOperator f = (a, b) -> {
 75                 if (i < 0) {
 76                     throw new RuntimeException();
 77                 }
 78 
 79                 IntUnaryOperator g = (c) -> {
 80                     return add(a, c);
 81                 };
 82 
 83                 return g.applyAsInt(b);
 84             };
 85 
 86             return f.applyAsInt(j, j);
 87         } catch (RuntimeException e) {
 88             throw new IndexOutOfBoundsException(i);
 89         }
 90     }
 91 
 92     @Test
 93     public void testInvocation() {
 94         CoreOp.FuncOp f = getFuncOp("f");
 95         f.writeTo(System.out);
 96 
 97         f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
 98         f.writeTo(System.out);
 99 
100         int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
101         Assert.assertEquals(x, f(2));
102 
103         try {
104             Interpreter.invoke(MethodHandles.lookup(), f, -10);
105             Assert.fail();
106         } catch (Throwable e) {
107             Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass());
108         }
109     }
110 
111     @Test
112     public void testFuncEntryExit() {
113         CoreOp.FuncOp f = getFuncOp("f");
114         f.writeTo(System.out);
115 
116         AtomicBoolean first = new AtomicBoolean(true);
117         CoreOp.FuncOp fc = f.transform((block, op) -> {
118             if (first.get()) {
119                 printConstantString(block, "ENTRY");
120                 first.set(false);
121             }
122 
123             switch (op) {
124                 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: {
125                     printConstantString(block, "EXIT");
126                     break;
127                 }
128                 case CoreOp.ThrowOp throwOp: {
129                     printConstantString(block, "EXIT");
130                     break;
131                 }
132                 default:
133             }
134 
135             block.apply(op);
136 
137             return block;
138         });
139         fc.writeTo(System.out);
140 
141         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
142         fc.writeTo(System.out);
143 
144         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
145         Assert.assertEquals(x, f(2));
146 
147         try {
148             Interpreter.invoke(MethodHandles.lookup(), fc, -10);
149             Assert.fail();
150         } catch (Throwable e) {
151             Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass());
152         }
153     }
154 
155     static void printConstantString(Function<Op, Op.Result> opBuilder, String s) {
156         Op.Result c = opBuilder.apply(constant(J_L_STRING, s));
157         Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
158         opBuilder.apply(CoreOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
159     }
160 
161     static Op getNearestInvokeableAncestorOp(Op op) {
162         do {
163             op = op.ancestorBody().parentOp();
164         } while (!(op instanceof Op.Invokable));
165         return op;
166     }
167 
168 
169     @Test
170     public void testReplaceCall() {
171         CoreOp.FuncOp f = getFuncOp("f");
172         f.writeTo(System.out);
173 
174         CoreOp.FuncOp fc = f.transform((block, op) -> {
175             switch (op) {
176                 case CoreOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): {
177                     // Get the adapted operands, and pass those to the new call method
178                     List<Value> adaptedOperands = block.context().getValues(op.operands());
179                     Op.Result adaptedResult = block.apply(CoreOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
180                     // Map the old call result to the new call result, so existing operations can be
181                     // adapted to use the new result
182                     block.context().mapValue(invokeOp.result(), adaptedResult);
183                     break;
184                 }
185                 default: {
186                     block.apply(op);
187                 }
188             }
189             return block;
190         });
191         fc.writeTo(System.out);
192 
193         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
194         fc.writeTo(System.out);
195 
196         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
197         Assert.assertEquals(x, f(2));
198     }
199 
200 
201     @Test
202     public void testCallEntryExit() {
203         CoreOp.FuncOp f = getFuncOp("f");
204         f.writeTo(System.out);
205 
206         CoreOp.FuncOp fc = f.transform((block, op) -> {
207             switch (op) {
208                 case CoreOp.InvokeOp invokeOp: {
209                     printCall(block.context(), invokeOp, block);
210                     break;
211                 }
212                 default: {
213                     block.apply(op);
214                 }
215             }
216             return block;
217         });
218         fc.writeTo(System.out);
219 
220         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
221         fc.writeTo(System.out);
222 
223         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
224         Assert.assertEquals(x, f(2));
225     }
226 
227     static void printCall(CopyContext cc, CoreOp.InvokeOp invokeOp, Function<Op, Op.Result> opBuilder) {
228         List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());
229 
230         String prefix = "ENTER";
231 
232         Value arrayLength = opBuilder.apply(
233                 constant(INT, adaptedInvokeOperands.size()));
234         Value formatArray = opBuilder.apply(
235                 newArray(type(Object[].class), arrayLength));
236 
237         Value indexZero = null;
238         for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
239             Value operand = adaptedInvokeOperands.get(i);
240 
241             Value index = opBuilder.apply(
242                     constant(INT, i));
243             if (i == 0) {
244                 indexZero = index;
245             }
246 
247             if (operand.type().equals(INT)) {
248                 operand = opBuilder.apply(
249                         CoreOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand));
250                 // @@@ Other primitive types
251             }
252             opBuilder.apply(
253                     arrayStoreOp(formatArray, index, operand));
254         }
255 
256         Op.Result formatString = opBuilder.apply(
257                 constant(J_L_STRING,
258                         prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n"));
259         Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
260         opBuilder.apply(
261                 CoreOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
262                         System_out, formatString, formatArray));
263 
264         // Method call
265 
266         Op.Result adaptedInvokeResult = opBuilder.apply(invokeOp);
267 
268         // After method call
269 
270         prefix = "EXIT";
271 
272         if (adaptedInvokeResult.type().equals(INT)) {
273             adaptedInvokeResult = opBuilder.apply(
274                     CoreOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult));
275             // @@@ Other primitive types
276         }
277         opBuilder.apply(
278                 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult));
279 
280         formatString = opBuilder.apply(
281                 constant(J_L_STRING,
282                         prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n"));
283         opBuilder.apply(
284                 CoreOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
285                         System_out, formatString, formatArray));
286     }
287 
288     static String formatString(List<Value> vs) {
289         return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
290     }
291 
292     static String formatString(TypeElement t) {
293         if (t.equals(INT)) {
294             return "%d";
295         } else {
296             return "%s";
297         }
298     }
299 
300 
301     static final MethodRef ADD_METHOD = MethodRef.method(
302             TestLocalTransformationsAdaption.class, "add",
303             int.class, int.class, int.class);
304 
305     static int add(int a, int b) {
306         return a + b;
307     }
308 
309     static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
310             TestLocalTransformationsAdaption.class, "addWithPrint",
311             int.class, int.class, int.class);
312 
313     static int addWithPrint(int a, int b) {
314         System.out.printf("Adding %d + %d%n", a, b);
315         return a + b;
316     }
317 
318     static CoreOp.FuncOp getFuncOp(String name) {
319         Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
320                 .filter(m -> m.getName().equals(name))
321                 .findFirst();
322 
323         Method m = om.get();
324         return Op.ofMethod(m).get();
325     }
326 }