1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import org.testng.Assert;
 25 import org.testng.annotations.Test;
 26 
 27 import java.io.PrintStream;
 28 import java.lang.reflect.code.*;
 29 import java.lang.reflect.code.op.CoreOp;
 30 import java.lang.reflect.code.type.FieldRef;
 31 import java.lang.reflect.code.type.MethodRef;
 32 import java.lang.reflect.code.interpreter.Interpreter;
 33 import java.lang.invoke.MethodHandles;
 34 import java.lang.reflect.Method;
 35 import java.lang.runtime.CodeReflection;
 36 import java.util.List;
 37 import java.util.Optional;
 38 import java.util.concurrent.atomic.AtomicBoolean;
 39 import java.util.function.Function;
 40 import java.util.function.IntBinaryOperator;
 41 import java.util.function.IntUnaryOperator;
 42 import java.util.stream.Collectors;
 43 import java.util.stream.Stream;
 44 
 45 import static java.lang.reflect.code.op.CoreOp.arrayStoreOp;
 46 import static java.lang.reflect.code.op.CoreOp.constant;
 47 import static java.lang.reflect.code.op.CoreOp.fieldLoad;
 48 import static java.lang.reflect.code.op.CoreOp.newArray;
 49 import static java.lang.reflect.code.type.MethodRef.method;
 50 import static java.lang.reflect.code.type.JavaType.*;
 51 
 52 /*
 53  * @test
 54  * @enablePreview
 55  * @run testng TestLocalTransformationsAdaption
 56  */
 57 
 58 public class TestLocalTransformationsAdaption {
 59 
 60     @CodeReflection
 61     static int f(int i) {
 62         IntBinaryOperator add = (a, b) -> {
 63             return add(a, b);
 64         };
 65 
 66         try {
 67             IntUnaryOperator add42 = (a) -> {
 68                 return add.applyAsInt(a, 42);
 69             };
 70 
 71             int j = add42.applyAsInt(i);
 72 
 73             IntBinaryOperator f = (a, b) -> {
 74                 if (i < 0) {
 75                     throw new RuntimeException();
 76                 }
 77 
 78                 IntUnaryOperator g = (c) -> {
 79                     return add(a, c);
 80                 };
 81 
 82                 return g.applyAsInt(b);
 83             };
 84 
 85             return f.applyAsInt(j, j);
 86         } catch (RuntimeException e) {
 87             throw new IndexOutOfBoundsException(i);
 88         }
 89     }
 90 
 91     @Test
 92     public void testInvocation() {
 93         CoreOp.FuncOp f = getFuncOp("f");
 94         f.writeTo(System.out);
 95 
 96         f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
 97         f.writeTo(System.out);
 98 
 99         int x = (int) Interpreter.invoke(MethodHandles.lookup(), f, 2);
100         Assert.assertEquals(x, f(2));
101 
102         try {
103             Interpreter.invoke(MethodHandles.lookup(), f, -10);
104             Assert.fail();
105         } catch (Throwable e) {
106             Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass());
107         }
108     }
109 
110     @Test
111     public void testFuncEntryExit() {
112         CoreOp.FuncOp f = getFuncOp("f");
113         f.writeTo(System.out);
114 
115         AtomicBoolean first = new AtomicBoolean(true);
116         CoreOp.FuncOp fc = f.transform((block, op) -> {
117             if (first.get()) {
118                 printConstantString(block, "ENTRY");
119                 first.set(false);
120             }
121 
122             switch (op) {
123                 case CoreOp.ReturnOp returnOp when getNearestInvokeableAncestorOp(returnOp) instanceof CoreOp.FuncOp: {
124                     printConstantString(block, "EXIT");
125                     break;
126                 }
127                 case CoreOp.ThrowOp throwOp: {
128                     printConstantString(block, "EXIT");
129                     break;
130                 }
131                 default:
132             }
133 
134             block.apply(op);
135 
136             return block;
137         });
138         fc.writeTo(System.out);
139 
140         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
141         fc.writeTo(System.out);
142 
143         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
144         Assert.assertEquals(x, f(2));
145 
146         try {
147             Interpreter.invoke(MethodHandles.lookup(), fc, -10);
148             Assert.fail();
149         } catch (Throwable e) {
150             Assert.assertEquals(IndexOutOfBoundsException.class, e.getClass());
151         }
152     }
153 
154     static void printConstantString(Function<Op, Op.Result> opBuilder, String s) {
155         Op.Result c = opBuilder.apply(constant(J_L_STRING, s));
156         Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
157         opBuilder.apply(CoreOp.invoke(method(PrintStream.class, "println", void.class, String.class), System_out, c));
158     }
159 
160     static Op getNearestInvokeableAncestorOp(Op op) {
161         do {
162             op = op.ancestorBody().parentOp();
163         } while (!(op instanceof Op.Invokable));
164         return op;
165     }
166 
167 
168     @Test
169     public void testReplaceCall() {
170         CoreOp.FuncOp f = getFuncOp("f");
171         f.writeTo(System.out);
172 
173         CoreOp.FuncOp fc = f.transform((block, op) -> {
174             switch (op) {
175                 case CoreOp.InvokeOp invokeOp when invokeOp.invokeDescriptor().equals(ADD_METHOD): {
176                     // Get the adapted operands, and pass those to the new call method
177                     List<Value> adaptedOperands = block.context().getValues(op.operands());
178                     Op.Result adaptedResult = block.apply(CoreOp.invoke(ADD_WITH_PRINT_METHOD, adaptedOperands));
179                     // Map the old call result to the new call result, so existing operations can be
180                     // adapted to use the new result
181                     block.context().mapValue(invokeOp.result(), adaptedResult);
182                     break;
183                 }
184                 default: {
185                     block.apply(op);
186                 }
187             }
188             return block;
189         });
190         fc.writeTo(System.out);
191 
192         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
193         fc.writeTo(System.out);
194 
195         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
196         Assert.assertEquals(x, f(2));
197     }
198 
199 
200     @Test
201     public void testCallEntryExit() {
202         CoreOp.FuncOp f = getFuncOp("f");
203         f.writeTo(System.out);
204 
205         CoreOp.FuncOp fc = f.transform((block, op) -> {
206             switch (op) {
207                 case CoreOp.InvokeOp invokeOp: {
208                     printCall(block.context(), invokeOp, block);
209                     break;
210                 }
211                 default: {
212                     block.apply(op);
213                 }
214             }
215             return block;
216         });
217         fc.writeTo(System.out);
218 
219         fc = fc.transform(OpTransformer.LOWERING_TRANSFORMER);
220         fc.writeTo(System.out);
221 
222         int x = (int) Interpreter.invoke(MethodHandles.lookup(), fc, 2);
223         Assert.assertEquals(x, f(2));
224     }
225 
226     static void printCall(CopyContext cc, CoreOp.InvokeOp invokeOp, Function<Op, Op.Result> opBuilder) {
227         List<Value> adaptedInvokeOperands = cc.getValues(invokeOp.operands());
228 
229         String prefix = "ENTER";
230 
231         Value arrayLength = opBuilder.apply(
232                 constant(INT, adaptedInvokeOperands.size()));
233         Value formatArray = opBuilder.apply(
234                 newArray(type(Object[].class), arrayLength));
235 
236         Value indexZero = null;
237         for (int i = 0; i < adaptedInvokeOperands.size(); i++) {
238             Value operand = adaptedInvokeOperands.get(i);
239 
240             Value index = opBuilder.apply(
241                     constant(INT, i));
242             if (i == 0) {
243                 indexZero = index;
244             }
245 
246             if (operand.type().equals(INT)) {
247                 operand = opBuilder.apply(
248                         CoreOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), operand));
249                 // @@@ Other primitive types
250             }
251             opBuilder.apply(
252                     arrayStoreOp(formatArray, index, operand));
253         }
254 
255         Op.Result formatString = opBuilder.apply(
256                 constant(J_L_STRING,
257                         prefix + ": " + invokeOp.invokeDescriptor() + "(" + formatString(adaptedInvokeOperands) + ")%n"));
258         Value System_out = opBuilder.apply(fieldLoad(FieldRef.field(System.class, "out", PrintStream.class)));
259         opBuilder.apply(
260                 CoreOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
261                         System_out, formatString, formatArray));
262 
263         // Method call
264 
265         Op.Result adaptedInvokeResult = opBuilder.apply(invokeOp);
266 
267         // After method call
268 
269         prefix = "EXIT";
270 
271         if (adaptedInvokeResult.type().equals(INT)) {
272             adaptedInvokeResult = opBuilder.apply(
273                     CoreOp.invoke(method(Integer.class, "valueOf", Integer.class, int.class), adaptedInvokeResult));
274             // @@@ Other primitive types
275         }
276         opBuilder.apply(
277                 arrayStoreOp(formatArray, indexZero, adaptedInvokeResult));
278 
279         formatString = opBuilder.apply(
280                 constant(J_L_STRING,
281                         prefix + ": " + invokeOp.invokeDescriptor() + " -> " + formatString(adaptedInvokeResult.type()) + "%n"));
282         opBuilder.apply(
283                 CoreOp.invoke(method(PrintStream.class, "printf", PrintStream.class, String.class, Object[].class),
284                         System_out, formatString, formatArray));
285     }
286 
287     static String formatString(List<Value> vs) {
288         return vs.stream().map(v -> formatString(v.type())).collect(Collectors.joining(","));
289     }
290 
291     static String formatString(TypeElement t) {
292         if (t.equals(INT)) {
293             return "%d";
294         } else {
295             return "%s";
296         }
297     }
298 
299 
300     static final MethodRef ADD_METHOD = MethodRef.method(
301             TestLocalTransformationsAdaption.class, "add",
302             int.class, int.class, int.class);
303 
304     static int add(int a, int b) {
305         return a + b;
306     }
307 
308     static final MethodRef ADD_WITH_PRINT_METHOD = MethodRef.method(
309             TestLocalTransformationsAdaption.class, "addWithPrint",
310             int.class, int.class, int.class);
311 
312     static int addWithPrint(int a, int b) {
313         System.out.printf("Adding %d + %d%n", a, b);
314         return a + b;
315     }
316 
317     static CoreOp.FuncOp getFuncOp(String name) {
318         Optional<Method> om = Stream.of(TestLocalTransformationsAdaption.class.getDeclaredMethods())
319                 .filter(m -> m.getName().equals(name))
320                 .findFirst();
321 
322         Method m = om.get();
323         return m.getCodeModel().get();
324     }
325 }