1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package oracle.code.samples;
 26 
 27 import jdk.incubator.code.CodeReflection;
 28 import jdk.incubator.code.CopyContext;
 29 import jdk.incubator.code.Location;
 30 import jdk.incubator.code.Op;
 31 import jdk.incubator.code.OpTransformer;
 32 import jdk.incubator.code.TypeElement;
 33 import jdk.incubator.code.Value;
 34 import jdk.incubator.code.analysis.SSA;
 35 import jdk.incubator.code.bytecode.BytecodeGenerator;
 36 import jdk.incubator.code.dialect.core.CoreOp;
 37 import jdk.incubator.code.dialect.java.JavaOp;
 38 import jdk.incubator.code.dialect.java.JavaType;
 39 import jdk.incubator.code.dialect.java.MethodRef;
 40 import jdk.incubator.code.interpreter.Interpreter;
 41 
 42 import java.lang.invoke.MethodHandle;
 43 import java.lang.invoke.MethodHandles;
 44 import java.lang.reflect.Method;
 45 import java.util.ArrayList;
 46 import java.util.List;
 47 import java.util.Optional;
 48 import java.util.stream.Stream;
 49 
 50 /**
 51  * Simple example of how to use the code reflection API.
 52  *
 53  * <p>
 54  * This example replaces a math function Math.pow with an optimized function using code transforms
 55  * from the code-reflection API. The optimized function can be applied only under certain conditions.
 56  * </p>
 57  *
 58  * <p>
 59  * Optimizations:
 60  * 1) Replace Pow(x, y) when x == 2 to (1 << y), if only if the parameter y is an integer.
 61  * 2) Replace Pow(x, y) when y == 2 to (x * x).
 62  * </p>
 63  *
 64  * <p>
 65  *     Babylon repository: {@see <a href="https://github.com/openjdk/babylon/tree/code-reflection">link</a>}
 66  * </p>
 67  *
 68  * <p>
 69  *     How to run?
 70  *     <code>
 71  *         java --add-modules jdk.incubator.code -cp target/crsamples-1.0-SNAPSHOT.jar oracle.code.samples.MathOptimizer
 72  *     </code>
 73  * </p>:
 74  */
 75 public class MathOptimizer {
 76 
 77     @CodeReflection
 78     private static double myFunction(int value) {
 79         return Math.pow(2, value);
 80     }
 81 
 82     // if pow(2, x), then substitute for this function
 83     // We could apply this function if, at runtime, user pass int values to the pow function
 84     // Thus, we narrow the result type from 8 bytes (double) to 4 bytes (INT).
 85     private static int functionShift(int val) {
 86         return 1 << val;
 87     }
 88 
 89     // if pow(x, 2) then substitute for this function
 90     private static double functionMult(double x) {
 91         return x * x;
 92     }
 93 
 94     private static final MethodRef MY_SHIFT_FUNCTION = MethodRef.method(MathOptimizer.class, "functionShift", int.class, int.class);
 95 
 96     private static final MethodRef MY_MULT_FUNCTION = MethodRef.method(MathOptimizer.class, "functionMult", double.class, double.class);
 97 
 98     // Analyze type methods: taken from example of String Concat Transformer to traverse the tree.
 99     static boolean analyseType(JavaOp.ConvOp convOp, JavaType typeToMatch) {
100         return analyseType(convOp.operands().get(0), typeToMatch);
101     }
102 
103     static boolean analyseType(Value v, JavaType typeToMatch) {
104         // Maybe there is a utility already to do tree traversal
105         if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
106             // Node of tree, recursively traverse the operands
107             return analyseType(convOp, typeToMatch);
108         } else {
109             // Leaf of tree: analyze type
110             TypeElement type = v.type();
111             return type.equals(typeToMatch);
112         }
113     }
114 
115     static void main(String[] args) throws Throwable {
116 
117         Optional<Method> myFunction = Stream.of(MathOptimizer.class.getDeclaredMethods())
118                 .filter(m -> m.getName().equals("myFunction"))
119                 .findFirst();
120 
121         Method myMathMethod = myFunction.get();
122 
123         // Obtain the code model for the annotated method
124         CoreOp.FuncOp codeModel = Op.ofMethod(myMathMethod).get();
125         System.out.println(codeModel.toText());
126 
127         // In addition, we can generate bytecodes from a new code model that
128         // has been transformed.
129         MethodHandle mhNewTransform = BytecodeGenerator.generate(MethodHandles.lookup(), codeModel);
130         // And invoke the method handle result
131         var resultBC = mhNewTransform.invoke( 10);
132         System.out.println("Result after BC generation: " + resultBC);
133 
134         System.out.println("\nLet's transform the code");
135         codeModel = codeModel.transform(CopyContext.create(), (blockBuilder, op) -> {
136             switch (op) {
137                 case JavaOp.InvokeOp invokeOp when whenIsMathPowFunction(invokeOp) -> {
138                     // The idea here is to create a new JavaOp.invoke with the optimization and replace it.
139                     List<Value> operands = blockBuilder.context().getValues(op.operands());
140 
141                     // Analyse second operand of the Math.pow(x, y).
142                     // if the x == 2, and both are integers, then we can optimize the function using bitwise operations
143                     // pow(2, y) replace with (1 << y)
144                     Value operand = operands.getFirst();  // obtain the first parameter
145                     // inspect if the base (as in pow(base, exp) is value 2
146                     boolean canApplyBitShift = inspectParameterRecursive(operand, 2);
147                     if (canApplyBitShift) {
148                         // We also need to inspect types. We can apply this optimization
149                         // if the exp type is also an integer.
150                         boolean isIntType = analyseType(operands.get(1), JavaType.INT);
151                         if (!isIntType) {
152                             canApplyBitShift = false;
153                         }
154                     }
155 
156                     // If the conditions to apply the first optimization failed, we try the second optimization
157                     // if types are not int, and base is not 2.
158                     // pow(x, 2) => replace with x * x
159                     boolean canApplyMultiplication = false;
160                     if (!canApplyBitShift) {
161                         // inspect if exp (as in pow(base, exp) is value 2
162                         canApplyMultiplication = inspectParameterRecursive(operands.get(1), 2);
163                     }
164 
165                     if (canApplyBitShift) {
166                         // Narrow type from DOUBLE to INT for the input parameter of the new function.
167                         Op.Result op2 = blockBuilder.op(JavaOp.conv(JavaType.INT, operands.get(1)));
168                         List<Value> newOperandList = new ArrayList<>();
169                         newOperandList.add(op2);
170 
171                         // Create a new invoke with the optimised method
172                         JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_SHIFT_FUNCTION, newOperandList);
173                         // Copy the original location info to the new invoke
174                         newInvoke.setLocation(invokeOp.location());
175 
176                         // Replace the invoke node with the new optimized invoke
177                         Op.Result newResult = blockBuilder.op(newInvoke);
178                         // Apply type conversion to double
179                         newResult = blockBuilder.op(JavaOp.conv(JavaType.DOUBLE, newResult));
180                         // Propagate the new result
181                         blockBuilder.context().mapValue(invokeOp.result(), newResult);
182 
183                     } else if (canApplyMultiplication) {
184                         // Adapt the parameters to the new function. We only need the first
185                         // parameter from the initial parameter list  - pow(x, 2) -
186                         // Create a new invoke function with the optimised method
187                         JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_MULT_FUNCTION, operands.get(0));
188                         // Copy the location info to the new invoke
189                         newInvoke.setLocation(invokeOp.location());
190 
191                         // Replace the invoke node with the new optimized invoke
192                         Op.Result newResult = blockBuilder.op(newInvoke);
193                         blockBuilder.context().mapValue(invokeOp.result(), newResult);
194 
195                     } else {
196                         // ignore the transformation
197                         blockBuilder.op(op);
198                     }
199                 }
200                 default -> blockBuilder.op(op);
201             }
202             return blockBuilder;
203         });
204 
205         System.out.println("AFTER TRANSFORM: ");
206         System.out.println(codeModel.toText());
207         codeModel = codeModel.transform(OpTransformer.LOWERING_TRANSFORMER);
208         System.out.println("After Lowering: ");
209         System.out.println(codeModel.toText());
210 
211         System.out.println("\nEvaluate");
212         // The Interpreter Invoke should launch new exceptions
213         var result = Interpreter.invoke(MethodHandles.lookup(), codeModel, 10);
214         System.out.println(result);
215 
216         // Select invocation calls and display the lines
217         System.out.println("\nPlaying with Traverse");
218         codeModel.traverse(null, (map, op) -> {
219             if (op instanceof JavaOp.InvokeOp invokeOp) {
220                 System.out.println("Function Name: " + invokeOp.invokeDescriptor().name());
221 
222                 // Maybe Location should throw a new exception instead of the NPE,
223                 // since it is possible we don't have a location after a transformation has been done.
224                 Location location = invokeOp.location();
225                 if (location != null) {
226                     int line = location.line();
227                     System.out.println("Line " + line);
228                     System.out.println("Class: " + invokeOp.getClass());
229                     // Detect Math::pow
230                     boolean contains = invokeOp.invokeDescriptor().equals(JAVA_LANG_MATH_POW);
231                     if (contains) {
232                         System.out.println("Method: " + invokeOp.invokeDescriptor().name());
233                     }
234                 } else {
235                     System.out.println("[WARNING] Location is null");
236                 }
237             }
238             return map;
239         });
240 
241         // In addition, we can generate bytecodes from a new code model that
242         // has been transformed.
243         MethodHandle methodHandle = BytecodeGenerator.generate(MethodHandles.lookup(), codeModel);
244         // And invoke the method handle result
245         var resultBC2 = methodHandle.invoke( 10);
246         System.out.println("Result after BC generation: " + resultBC2);
247     }
248 
249     // Goal: obtain and check the value of the function parameters.
250     // The function inspectParameterRecursive implements this method
251     // in a much simpler and shorter manner. We can keep this first
252     // implementation as a reference.
253     private static boolean inspectParameter(Value operand, final int value) {
254         final Boolean[] isMultipliedByTwo = new Boolean[] { false };
255         if (operand instanceof Op.Result res) {
256             if (res.op() instanceof JavaOp.ConvOp convOp) {
257                 convOp.operands().forEach(v -> {
258                     if (v instanceof Op.Result res2) {
259                         if (res2.op() instanceof CoreOp.ConstantOp constantOp) {
260                             if (constantOp.value() instanceof Integer parameter) {
261                                 if (parameter.intValue() == value) {
262                                     // Transformation is valid
263                                     isMultipliedByTwo[0] = true;
264                                 }
265                             }
266                         }
267                     }
268                 });
269             }
270         }
271         return isMultipliedByTwo[0];
272     }
273 
274     // Inspect a value for a parameter
275     static boolean inspectParameterRecursive(JavaOp.ConvOp convOp, int valToMatch) {
276         return inspectParameterRecursive(convOp.operands().get(0), valToMatch);
277     }
278 
279     static boolean inspectParameterRecursive(Value v, int valToMatch) {
280         if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
281             return inspectParameterRecursive(convOp, valToMatch);
282         } else {
283             // Leaf of tree - we want to analyse the value
284             if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.ConstantOp constant) {
285                 return constant.value().equals(valToMatch);
286             }
287             return false;
288         }
289     }
290 
291     static final MethodRef JAVA_LANG_MATH_POW = MethodRef.method(Math.class, "pow", double.class, double.class, double.class);
292 
293     private static boolean whenIsMathPowFunction(JavaOp.InvokeOp invokeOp) {
294         return invokeOp.invokeDescriptor().equals(JAVA_LANG_MATH_POW);
295     }
296 }