1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package oracle.code.samples;
 26 
 27 import jdk.incubator.code.CodeReflection;
 28 import jdk.incubator.code.CopyContext;
 29 import jdk.incubator.code.Op;
 30 import jdk.incubator.code.OpTransformer;
 31 import jdk.incubator.code.TypeElement;
 32 import jdk.incubator.code.Value;
 33 import jdk.incubator.code.analysis.Inliner;
 34 import jdk.incubator.code.dialect.core.CoreOp;
 35 import jdk.incubator.code.dialect.java.JavaOp;
 36 import jdk.incubator.code.dialect.java.JavaType;
 37 import jdk.incubator.code.dialect.java.MethodRef;
 38 import jdk.incubator.code.extern.OpWriter;
 39 import jdk.incubator.code.interpreter.Interpreter;
 40 
 41 import java.lang.invoke.MethodHandles;
 42 import java.lang.reflect.Method;
 43 import java.util.ArrayList;
 44 import java.util.List;
 45 import java.util.Objects;
 46 import java.util.Optional;
 47 import java.util.concurrent.atomic.AtomicReference;
 48 import java.util.stream.Stream;
 49 
 50 /**
 51  * Simple example of how to use the code reflection API.
 52  *
 53  * <p>
 54  * This example is based on the {@link MathOptimizer} example to add inlining for the new replaced
 55  * methods. In this example, we focus on explaining in detail the inlining component.
 56  * </p>
 57  *
 58  * <p>
 59  * Optimizations:
 60  *     <ol>Replace Pow(x, y) when x == 2 to (1 << y), if only if the parameter y is an integer.</ol>
 61  *     <ol>Replace Pow(x, y) when y == 2 to (x * x).</ol>
 62  *     <ol>After a replacement has been done, we inline the new invoke into the main code model.</ol>
 63  * </p>
 64  *
 65  * <p>
 66  *     In a nutshell, we apply a second transform to perform the inlining. Note that the inlining could be done
 67  *     also in within the first transform.
 68  *     To be able to inline, we need to also annotate the new invoke nodes that were replaced during the first
 69  *     transform with the <code>@CodeReflection</code> annotation. In this way, we can build the code models for
 70  *     each of the methods and apply the inlining directly using the <code>Inliner.inline</code> from the code
 71  *     reflection API.
 72  * </p>
 73  *
 74  * <p>
 75  *     Babylon repository: {@see <a href="https://github.com/openjdk/babylon/tree/code-reflection">link</a>}
 76  * </p>
 77  *
 78  * <p>
 79  *     How to run?
 80  *     <code>
 81  *         java --add-modules jdk.incubator.code -cp target/crsamples-1.0-SNAPSHOT.jar oracle.code.samples.MathOptimizerWithInlining
 82  *     </code>
 83  * </p>:
 84  */
 85 public class MathOptimizerWithInlining {
 86 
 87     // New functions are also annotated with code reflection
 88     @CodeReflection
 89     private static double myFunction(int value) {
 90         return Math.pow(2, value);
 91     }
 92 
 93     @CodeReflection
 94     private static int functionShift(int val) {
 95         return 1 << val;
 96     }
 97 
 98     @CodeReflection
 99     private static double functionMult(double x) {
100         return x * x;
101     }
102 
103     private static final MethodRef MY_SHIFT_FUNCTION = MethodRef.method(MathOptimizerWithInlining.class, "functionShift", int.class, int.class);
104 
105     private static final MethodRef MY_MULT_FUNCTION = MethodRef.method(MathOptimizerWithInlining.class, "functionMult", double.class, double.class);
106 
107     private static CoreOp.FuncOp buildCodeModelForMethod(Class<?> klass, String methodName) {
108         Optional<Method> function = Stream.of(klass.getDeclaredMethods())
109                 .filter(m -> m.getName().equals(methodName))
110                 .findFirst();
111         Method method = function.get();
112         CoreOp.FuncOp funcOp = Op.ofMethod(method).get();
113         return funcOp;
114     }
115 
116     static void main(String[] args) {
117 
118         // Obtain the code model for the annotated method
119         CoreOp.FuncOp codeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "myFunction");
120         System.out.println(codeModel.toText());
121 
122         enum FunctionToUse {
123             SHIFT,
124             MULT,
125             GENERIC;
126         }
127 
128         AtomicReference<FunctionToUse> replace = new AtomicReference<>(FunctionToUse.GENERIC);
129 
130         codeModel = codeModel.transform(CopyContext.create(), (blockBuilder, op) -> {
131             // The idea here is to create a new JavaOp.invoke with the optimization and replace it.
132             if (Objects.requireNonNull(op) instanceof JavaOp.InvokeOp invokeOp && whenIsMathPowFunction(invokeOp)) {
133                 List<Value> operands = blockBuilder.context().getValues(op.operands());
134 
135                 // Analyse second operand of the Math.pow(x, y).
136                 // if the x == 2, and both are integers, then we can optimize the function using bitwise operations
137                 // pow(2, y) replace with (1 << y)
138                 Value operand = operands.getFirst();  // obtain the first parameter
139                 // inspect if the base (as in pow(base, exp) is value 2
140                 boolean canApplyBitShift = inspectParameterRecursive(operand, 2);
141                 if (canApplyBitShift) {
142                     // We also need to inspect types. We can apply this optimization
143                     // if the exp type is also an integer.
144                     boolean isIntType = analyseType(operands.get(1), JavaType.INT);
145                     if (!isIntType) {
146                         canApplyBitShift = false;
147                     }
148                 }
149 
150                 // If the conditions to apply the first optimization failed, we try the second optimization
151                 // if types are not int, and base is not 2.
152                 // pow(x, 2) => replace with x * x
153                 boolean canApplyMultiplication = false;
154                 if (!canApplyBitShift) {
155                     // inspect if exp (as in pow(base, exp) is value 2
156                     canApplyMultiplication = inspectParameterRecursive(operands.get(1), 2);
157                 }
158 
159                 if (canApplyBitShift) {
160                     // Narrow type from DOUBLE to INT for the input parameter of the new function.
161                     Op.Result op2 = blockBuilder.op(JavaOp.conv(JavaType.INT, operands.get(1)));
162                     List<Value> newOperandList = new ArrayList<>();
163                     newOperandList.add(op2);
164 
165                     // Create a new invoke with the optimised method
166                     JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_SHIFT_FUNCTION, newOperandList);
167                     // Copy the original location info to the new invoke
168                     newInvoke.setLocation(invokeOp.location());
169 
170                     // Replace the invoke node with the new optimized invoke
171                     Op.Result newResult = blockBuilder.op(newInvoke);
172                     // Type conversion to double
173                     newResult = blockBuilder.op(JavaOp.conv(JavaType.DOUBLE, newResult));
174                     blockBuilder.context().mapValue(invokeOp.result(), newResult);
175 
176                     replace.set(FunctionToUse.SHIFT);
177 
178                 } else if (canApplyMultiplication) {
179                     // Adapt the parameters to the new function. We only need the first
180                     // parameter from the initial parameter list  - pow(x, 2) -
181                     // Create a new invoke function with the optimised method
182                     JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_MULT_FUNCTION, operands.get(0));
183                     // Copy the location info to the new invoke
184                     newInvoke.setLocation(invokeOp.location());
185 
186                     // Replace the invoke node with the new optimized invoke
187                     Op.Result newResult = blockBuilder.op(newInvoke);
188                     blockBuilder.context().mapValue(invokeOp.result(), newResult);
189                     replace.set(FunctionToUse.MULT);
190 
191                 } else {
192                     // ignore the transformation
193                     blockBuilder.op(op);
194                 }
195             } else {
196                 blockBuilder.op(op);
197             }
198             return blockBuilder;
199         });
200 
201         System.out.println("Code Model after the first transform (replace with a new method): ");
202         System.out.println(codeModel.toText());
203 
204         // Let's now apply a second transformation
205         // We want to inline the functions. Note that we can apply this transformation if any of the new functions
206         // have been replaced.
207         System.out.println("Second transform: apply inlining for the new methods into the main code model");
208         if (replace.get() != FunctionToUse.GENERIC) {
209 
210             // Build code model for the functions we want to inline.
211             // Since we apply two replacements (depending on the values of the input code), we can apply two different
212             // inline functions.
213             CoreOp.FuncOp shiftCodeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "functionShift");
214             CoreOp.FuncOp multCodeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "functionMult");
215 
216             // Apply inlining
217             codeModel = codeModel.transform(codeModel.funcName(),
218                                             (blockBuilder, op) -> {
219 
220                 if (op instanceof JavaOp.InvokeOp invokeOp && isMethodWeWantToInline(invokeOp)) {
221                     // Let's inline the function
222 
223                     // 1. Select the function we want to inline.
224                     // Since we have two possible replacements, depending on the input code, we need to
225                     // apply the corresponding replacement function
226                     CoreOp.FuncOp codeModelToInline = isShiftFunction(invokeOp) ? shiftCodeModel : multCodeModel;
227 
228                     // 2. Apply the inlining
229                     Inliner.inline(
230                             blockBuilder,   // the current block builder
231                             codeModelToInline,  // the method to inline which we obtained using code reflection too
232                             blockBuilder.context().getValues(invokeOp.operands()),  // operands to this call. Since we already replace the function,
233                             // we can use the same operands as the invoke call
234                             (builder, val) -> blockBuilder.context().mapValue(invokeOp.result(), val)); // Propagate the new result
235                 } else {
236                     // copy the op into the builder if it is not the invoke node we are looking for
237                     blockBuilder.op(op);
238                 }
239 
240                 // return new transformed block builder
241                 return blockBuilder;
242             });
243 
244             System.out.println("After inlining: " + codeModel.toText());
245         }
246 
247         codeModel = codeModel.transform(OpTransformer.LOWERING_TRANSFORMER);
248         System.out.println("After Lowering: ");
249         System.out.println(codeModel.toText());
250 
251         System.out.println("\nEvaluate with Interpreter.invoke");
252         // The Interpreter Invoke should launch new exceptions
253         var result = Interpreter.invoke(MethodHandles.lookup(), codeModel, 10);
254         System.out.println(result);
255     }
256 
257     // Utility methods
258 
259     // Analyze type methods: taken from example of String Concat Transformer to traverse the tree.
260     static boolean analyseType(JavaOp.ConvOp convOp, JavaType typeToMatch) {
261         return analyseType(convOp.operands().get(0), typeToMatch);
262     }
263 
264     static boolean analyseType(Value v, JavaType typeToMatch) {
265         // Maybe there is a utility already to do tree traversal
266         if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
267             // Node of tree, recursively traverse the operands
268             return analyseType(convOp, typeToMatch);
269         } else {
270             // Leaf of tree: analyze type
271             TypeElement type = v.type();
272             return type.equals(typeToMatch);
273         }
274     }
275 
276     // Inspect a value for a parameter
277     static boolean inspectParameterRecursive(JavaOp.ConvOp convOp, int valToMatch) {
278         return inspectParameterRecursive(convOp.operands().get(0), valToMatch);
279     }
280 
281     static boolean inspectParameterRecursive(Value v, int valToMatch) {
282         if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
283             return inspectParameterRecursive(convOp, valToMatch);
284         } else {
285             // Leaf of tree - we want to obtain the actual value of the parameter and check
286             if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.ConstantOp constant) {
287                 return constant.value().equals(valToMatch);
288             }
289             return false;
290         }
291     }
292 
293     static final MethodRef JAVA_LANG_MATH_POW = MethodRef.method(Math.class, "pow", double.class, double.class, double.class);
294 
295     private static boolean whenIsMathPowFunction(JavaOp.InvokeOp invokeOp) {
296         return invokeOp.invokeDescriptor().equals(JAVA_LANG_MATH_POW);
297     }
298 
299     private static boolean isMethodWeWantToInline(JavaOp.InvokeOp invokeOp) {
300         return (invokeOp.invokeDescriptor().toString().startsWith("oracle.code.samples.MathOptimizerWithInlining::functionShift")
301                 || invokeOp.invokeDescriptor().toString().startsWith("oracle.code.samples.MathOptimizerWithInlining::functionMult"));
302     }
303 
304     private static boolean isShiftFunction(JavaOp.InvokeOp invokeOp) {
305         return invokeOp.invokeDescriptor().toString().contains("functionShift");
306     }
307 }