1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package oracle.code.samples;
 26 
 27 import jdk.incubator.code.CodeTransformer;
 28 import jdk.incubator.code.Reflect;
 29 import jdk.incubator.code.Op;
 30 import jdk.incubator.code.TypeElement;
 31 import jdk.incubator.code.Value;
 32 import jdk.incubator.code.bytecode.BytecodeGenerator;
 33 import jdk.incubator.code.dialect.core.CoreOp;
 34 import jdk.incubator.code.dialect.core.Inliner;
 35 import jdk.incubator.code.dialect.java.JavaOp;
 36 import jdk.incubator.code.dialect.java.JavaType;
 37 import jdk.incubator.code.dialect.java.MethodRef;
 38 
 39 import java.lang.invoke.MethodHandle;
 40 import java.lang.invoke.MethodHandles;
 41 import java.lang.reflect.Method;
 42 import java.util.ArrayList;
 43 import java.util.List;
 44 import java.util.Objects;
 45 import java.util.Optional;
 46 import java.util.concurrent.atomic.AtomicReference;
 47 import java.util.stream.Stream;
 48 
/**
 * Simple example of how to use the code reflection API.
 *
 * <p>
 * This example extends the {@link MathOptimizer} example by inlining the newly introduced
 * replacement methods. Here we focus on explaining the inlining step in detail.
 * </p>
 *
 * <p>
 * Optimizations:
 * </p>
 * <ul>
 *     <li>Replace {@code Math.pow(x, y)} with {@code (1 << y)} when {@code x == 2}, if and only if the parameter {@code y} is an integer.</li>
 *     <li>Replace {@code Math.pow(x, y)} with {@code (x * x)} when {@code y == 2}.</li>
 *     <li>After a replacement has been done, inline the new invoke into the main code model.</li>
 * </ul>
 *
 * <p>
 *     In a nutshell, we apply a second transform to perform the inlining. Note that the inlining could also
 *     be done within the first transform.
 *     To be able to inline, the methods called by the new invoke nodes introduced during the first
 *     transform must also be annotated with <code>@Reflect</code>. This way, we can build the code model of
 *     each of those methods and inline it directly using <code>Inliner.inline</code> from the code
 *     reflection API.
 * </p>
 *
 * <p>
 *     Babylon repository: <a href="https://github.com/openjdk/babylon/tree/code-reflection">link</a>
 * </p>
 *
 * <p>
 *     How to run?
 * </p>
 * <pre>
 *     java --add-modules jdk.incubator.code -cp target/crsamples-1.0-SNAPSHOT.jar oracle.code.samples.MathOptimizerWithInlining
 * </pre>
 */
 84 public class MathOptimizerWithInlining {
 85 
 86     // New functions are also annotated with code reflection
 87     @Reflect
 88     private static double myFunction(int value) {
 89         return Math.pow(2, value);
 90     }
 91 
 92     @Reflect
 93     private static int functionShift(int val) {
 94         return 1 << val;
 95     }
 96 
 97     @Reflect
 98     private static double functionMult(double x) {
 99         return x * x;
100     }
101 
102     private static final MethodRef MY_SHIFT_FUNCTION = MethodRef.method(MathOptimizerWithInlining.class, "functionShift", int.class, int.class);
103 
104     private static final MethodRef MY_MULT_FUNCTION = MethodRef.method(MathOptimizerWithInlining.class, "functionMult", double.class, double.class);
105 
106     private static CoreOp.FuncOp buildCodeModelForMethod(Class<?> klass, String methodName) {
107         Optional<Method> function = Stream.of(klass.getDeclaredMethods())
108                 .filter(m -> m.getName().equals(methodName))
109                 .findFirst();
110         Method method = function.get();
111         CoreOp.FuncOp funcOp = Op.ofMethod(method).get();
112         return funcOp;
113     }
114 
115     static void main() {
116 
117         // Obtain the code model for the annotated method
118         CoreOp.FuncOp codeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "myFunction");
119         System.out.println(codeModel.toText());
120 
121         enum FunctionToUse {
122             SHIFT,
123             MULT,
124             GENERIC;
125         }
126 
127         AtomicReference<FunctionToUse> replace = new AtomicReference<>(FunctionToUse.GENERIC);
128 
129         codeModel = codeModel.transform((blockBuilder, op) -> {
130             // The idea here is to create a new JavaOp.invoke with the optimization and replace it.
131             if (Objects.requireNonNull(op) instanceof JavaOp.InvokeOp invokeOp && whenIsMathPowFunction(invokeOp)) {
132                 List<Value> operands = blockBuilder.context().getValues(op.operands());
133 
134                 // Analyse second operand of the Math.pow(x, y).
135                 // if the x == 2, and both are integers, then we can optimize the function using bitwise operations
136                 // pow(2, y) replace with (1 << y)
137                 Value operand = operands.getFirst();  // obtain the first parameter
138                 // inspect if the base (as in pow(base, exp) is value 2
139                 boolean canApplyBitShift = inspectParameterRecursive(operand, 2);
140                 if (canApplyBitShift) {
141                     // We also need to inspect types. We can apply this optimization
142                     // if the exp type is also an integer.
143                     boolean isIntType = analyseType(operands.get(1), JavaType.INT);
144                     if (!isIntType) {
145                         canApplyBitShift = false;
146                     }
147                 }
148 
149                 // If the conditions to apply the first optimization failed, we try the second optimization
150                 // if types are not int, and base is not 2.
151                 // pow(x, 2) => replace with x * x
152                 boolean canApplyMultiplication = false;
153                 if (!canApplyBitShift) {
154                     // inspect if exp (as in pow(base, exp) is value 2
155                     canApplyMultiplication = inspectParameterRecursive(operands.get(1), 2);
156                 }
157 
158                 if (canApplyBitShift) {
159                     // Narrow type from DOUBLE to INT for the input parameter of the new function.
160                     Op.Result op2 = blockBuilder.op(JavaOp.conv(JavaType.INT, operands.get(1)));
161                     List<Value> newOperandList = new ArrayList<>();
162                     newOperandList.add(op2);
163 
164                     // Create a new invoke with the optimised method
165                     JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_SHIFT_FUNCTION, newOperandList);
166                     // Copy the original location info to the new invoke
167                     newInvoke.setLocation(invokeOp.location());
168 
169                     // Replace the invoke node with the new optimized invoke
170                     Op.Result newResult = blockBuilder.op(newInvoke);
171                     // Type conversion to double
172                     newResult = blockBuilder.op(JavaOp.conv(JavaType.DOUBLE, newResult));
173                     blockBuilder.context().mapValue(invokeOp.result(), newResult);
174 
175                     replace.set(FunctionToUse.SHIFT);
176 
177                 } else if (canApplyMultiplication) {
178                     // Adapt the parameters to the new function. We only need the first
179                     // parameter from the initial parameter list  - pow(x, 2) -
180                     // Create a new invoke function with the optimised method
181                     JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_MULT_FUNCTION, operands.get(0));
182                     // Copy the location info to the new invoke
183                     newInvoke.setLocation(invokeOp.location());
184 
185                     // Replace the invoke node with the new optimized invoke
186                     Op.Result newResult = blockBuilder.op(newInvoke);
187                     blockBuilder.context().mapValue(invokeOp.result(), newResult);
188                     replace.set(FunctionToUse.MULT);
189 
190                 } else {
191                     // ignore the transformation
192                     blockBuilder.op(op);
193                 }
194             } else {
195                 blockBuilder.op(op);
196             }
197             return blockBuilder;
198         });
199 
200         System.out.println("Code Model after the first transform (replace with a new method): ");
201         System.out.println(codeModel.toText());
202 
203         // Let's now apply a second transformation
204         // We want to inline the functions. Note that we can apply this transformation if any of the new functions
205         // have been replaced.
206         System.out.println("Second transform: apply inlining for the new methods into the main code model");
207         if (replace.get() != FunctionToUse.GENERIC) {
208 
209             // Build code model for the functions we want to inline.
210             // Since we apply two replacements (depending on the values of the input code), we can apply two different
211             // inline functions.
212             CoreOp.FuncOp shiftCodeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "functionShift");
213             CoreOp.FuncOp multCodeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "functionMult");
214 
215             // Apply inlining
216             codeModel = codeModel.transform(codeModel.funcName(),
217                                             (blockBuilder, op) -> {
218 
219                 if (op instanceof JavaOp.InvokeOp invokeOp && isMethodWeWantToInline(invokeOp)) {
220                     // Let's inline the function
221 
222                     // 1. Select the function we want to inline.
223                     // Since we have two possible replacements, depending on the input code, we need to
224                     // apply the corresponding replacement function
225                     CoreOp.FuncOp codeModelToInline = isShiftFunction(invokeOp) ? shiftCodeModel : multCodeModel;
226 
227                     // 2. Apply the inlining
228                     Inliner.inline(
229                             blockBuilder,   // the current block builder
230                             codeModelToInline,  // the method to inline which we obtained using code reflection too
231                             blockBuilder.context().getValues(invokeOp.operands()),  // operands to this call. Since we already replace the function,
232                             // we can use the same operands as the invoke call
233                             (builder, val) -> blockBuilder.context().mapValue(invokeOp.result(), val)); // Propagate the new result
234                 } else {
235                     // copy the op into the builder if it is not the invoke node we are looking for
236                     blockBuilder.op(op);
237                 }
238 
239                 // return new transformed block builder
240                 return blockBuilder;
241             });
242 
243             System.out.println("After inlining: " + codeModel.toText());
244         }
245 
246         codeModel = codeModel.transform(CodeTransformer.LOWERING_TRANSFORMER);
247         System.out.println("After Lowering: ");
248         System.out.println(codeModel.toText());
249 
250         System.out.println("\nGenerate bytecode and execute");
251         MethodHandle methodHandle = BytecodeGenerator.generate(MethodHandles.lookup(), codeModel);
252         try {
253             var result = methodHandle.invoke(10);
254             System.out.println(result);
255         } catch (Throwable e) {
256             throw new RuntimeException(e);
257         }
258 
259     }
260 
261     // Utility methods
262 
263     // Analyze type methods: taken from example of String Concat Transformer to traverse the tree.
264     static boolean analyseType(JavaOp.ConvOp convOp, JavaType typeToMatch) {
265         return analyseType(convOp.operands().get(0), typeToMatch);
266     }
267 
268     static boolean analyseType(Value v, JavaType typeToMatch) {
269         // Maybe there is a utility already to do tree traversal
270         if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
271             // Node of tree, recursively traverse the operands
272             return analyseType(convOp, typeToMatch);
273         } else {
274             // Leaf of tree: analyze type
275             TypeElement type = v.type();
276             return type.equals(typeToMatch);
277         }
278     }
279 
280     // Inspect a value for a parameter
281     static boolean inspectParameterRecursive(JavaOp.ConvOp convOp, int valToMatch) {
282         return inspectParameterRecursive(convOp.operands().get(0), valToMatch);
283     }
284 
285     static boolean inspectParameterRecursive(Value v, int valToMatch) {
286         if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
287             return inspectParameterRecursive(convOp, valToMatch);
288         } else {
289             // Leaf of tree - we want to obtain the actual value of the parameter and check
290             if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.ConstantOp constant) {
291                 return constant.value().equals(valToMatch);
292             }
293             return false;
294         }
295     }
296 
297     static final MethodRef JAVA_LANG_MATH_POW = MethodRef.method(Math.class, "pow", double.class, double.class, double.class);
298 
299     private static boolean whenIsMathPowFunction(JavaOp.InvokeOp invokeOp) {
300         return invokeOp.invokeDescriptor().equals(JAVA_LANG_MATH_POW);
301     }
302 
303     private static boolean isMethodWeWantToInline(JavaOp.InvokeOp invokeOp) {
304         return (invokeOp.invokeDescriptor().toString().startsWith("oracle.code.samples.MathOptimizerWithInlining::functionShift")
305                 || invokeOp.invokeDescriptor().toString().startsWith("oracle.code.samples.MathOptimizerWithInlining::functionMult"));
306     }
307 
308     private static boolean isShiftFunction(JavaOp.InvokeOp invokeOp) {
309         return invokeOp.invokeDescriptor().toString().contains("functionShift");
310     }
311 }