1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package oracle.code.samples;
 26 
 27 import jdk.incubator.code.Reflect;
 28 import jdk.incubator.code.Op;
 29 import jdk.incubator.code.CodeType;
 30 import jdk.incubator.code.Value;
 31 import jdk.incubator.code.bytecode.BytecodeGenerator;
 32 import jdk.incubator.code.dialect.core.CoreOp;
 33 import jdk.incubator.code.dialect.java.JavaOp;
 34 import jdk.incubator.code.dialect.java.JavaType;
 35 import jdk.incubator.code.dialect.java.MethodRef;
 36 
 37 import java.lang.invoke.MethodHandle;
 38 import java.lang.invoke.MethodHandles;
 39 import java.lang.reflect.Method;
 40 import java.util.ArrayList;
 41 import java.util.List;
 42 import java.util.Optional;
 43 import java.util.stream.Stream;
 44 
 45 import static jdk.incubator.code.CodeTransformer.LOWERING_TRANSFORMER;
 46 
 47 /**
 48  * Simple example of how to use the code reflection API.
 49  *
 50  * <p>
 51  * This example replaces a math function Math.pow with an optimized function using code transforms
 52  * from the code-reflection API. The optimized function can be applied only under certain conditions.
 53  * </p>
 54  *
 55  * <p>
 56  * Optimizations:
 57  * 1) Replace Pow(x, y) when x == 2 to (1 << y), if only if the parameter y is an integer.
 58  * 2) Replace Pow(x, y) when y == 2 to (x * x).
 59  * </p>
 60  *
 61  * <p>
 62  *     Babylon repository: {@see <a href="https://github.com/openjdk/babylon/tree/code-reflection">link</a>}
 63  * </p>
 64  *
 65  * <p>
 66  *     How to run?
 67  *     <code>
 68  *         java --add-modules jdk.incubator.code -cp target/crsamples-1.0-SNAPSHOT.jar oracle.code.samples.MathOptimizer
 69  *     </code>
 70  * </p>:
 71  */
 72 public class MathOptimizer {
 73 
 74     @Reflect
 75     private static double myFunction(int value) {
 76         return Math.pow(2, value);
 77     }
 78 
 79     // if pow(2, x), then substitute for this function
 80     // We could apply this function if, at runtime, user pass int values to the pow function
 81     // Thus, we narrow the result type from 8 bytes (double) to 4 bytes (INT).
 82     private static int functionShift(int val) {
 83         return 1 << val;
 84     }
 85 
 86     // if pow(x, 2) then substitute for this function
 87     private static double functionMult(double x) {
 88         return x * x;
 89     }
 90 
 91     private static final MethodRef MY_SHIFT_FUNCTION = MethodRef.method(MathOptimizer.class, "functionShift", int.class, int.class);
 92 
 93     private static final MethodRef MY_MULT_FUNCTION = MethodRef.method(MathOptimizer.class, "functionMult", double.class, double.class);
 94 
 95     // Analyze type methods: taken from example of String Concat Transformer to traverse the tree.
 96     static boolean analyseType(JavaOp.ConvOp convOp, JavaType typeToMatch) {
 97         return analyseType(convOp.operands().get(0), typeToMatch);
 98     }
 99 
100     static boolean analyseType(Value v, JavaType typeToMatch) {
101         // Maybe there is a utility already to do tree traversal
102         if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
103             // Node of tree, recursively traverse the operands
104             return analyseType(convOp, typeToMatch);
105         } else {
106             // Leaf of tree: analyze type
107             CodeType type = v.type();
108             return type.equals(typeToMatch);
109         }
110     }
111 
112     static void main() {
113 
114         Optional<Method> myFunction = Stream.of(MathOptimizer.class.getDeclaredMethods())
115                 .filter(m -> m.getName().equals("myFunction"))
116                 .findFirst();
117 
118         Method myMathMethod = myFunction.get();
119 
120         // Obtain the code model for the annotated method
121         CoreOp.FuncOp codeModel = Op.ofMethod(myMathMethod).get();
122         IO.println(codeModel.toText());
123 
124         IO.println("\nLet's transform the code");
125         codeModel = codeModel.transform((blockBuilder, op) -> {
126             switch (op) {
127                 case JavaOp.InvokeOp invokeOp when whenIsMathPowFunction(invokeOp) -> {
128                     // The idea here is to create a new JavaOp.invoke with the optimization and replace it.
129                     List<Value> operands = blockBuilder.context().getValues(op.operands());
130 
131                     // Analyse second operand of the Math.pow(x, y).
132                     // if the x == 2, and both are integers, then we can optimize the function using bitwise operations
133                     // pow(2, y) replace with (1 << y)
134                     Value operand = operands.getFirst();  // obtain the first parameter
135                     // inspect if the base (as in pow(base, exp) is value 2
136                     boolean canApplyBitShift = inspectParameterRecursive(operand, 2);
137                     if (canApplyBitShift) {
138                         // We also need to inspect types. We can apply this optimization
139                         // if the exp type is also an integer.
140                         boolean isIntType = analyseType(operands.get(1), JavaType.INT);
141                         if (!isIntType) {
142                             canApplyBitShift = false;
143                         }
144                     }
145 
146                     // If the conditions to apply the first optimization failed, we try the second optimization
147                     // if types are not int, and base is not 2.
148                     // pow(x, 2) => replace with x * x
149                     boolean canApplyMultiplication = false;
150                     if (!canApplyBitShift) {
151                         // inspect if exp (as in pow(base, exp) is value 2
152                         canApplyMultiplication = inspectParameterRecursive(operands.get(1), 2);
153                     }
154 
155                     if (canApplyBitShift) {
156                         // Narrow type from DOUBLE to INT for the input parameter of the new function.
157                         Op.Result op2 = blockBuilder.op(JavaOp.conv(JavaType.INT, operands.get(1)));
158                         List<Value> newOperandList = new ArrayList<>();
159                         newOperandList.add(op2);
160 
161                         // Create a new invoke with the optimised method
162                         JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_SHIFT_FUNCTION, newOperandList);
163                         // Copy the original location info to the new invoke
164                         newInvoke.setLocation(invokeOp.location());
165 
166                         // Replace the invoke node with the new optimized invoke
167                         Op.Result newResult = blockBuilder.op(newInvoke);
168                         // Apply type conversion to double
169                         newResult = blockBuilder.op(JavaOp.conv(JavaType.DOUBLE, newResult));
170                         // Propagate the new result
171                         blockBuilder.context().mapValue(invokeOp.result(), newResult);
172 
173                     } else if (canApplyMultiplication) {
174                         // Adapt the parameters to the new function. We only need the first
175                         // parameter from the initial parameter list  - pow(x, 2) -
176                         // Create a new invoke function with the optimised method
177                         JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_MULT_FUNCTION, operands.get(0));
178                         // Copy the location info to the new invoke
179                         newInvoke.setLocation(invokeOp.location());
180 
181                         // Replace the invoke node with the new optimized invoke
182                         Op.Result newResult = blockBuilder.op(newInvoke);
183                         blockBuilder.context().mapValue(invokeOp.result(), newResult);
184 
185                     } else {
186                         // ignore the transformation
187                         blockBuilder.op(op);
188                     }
189                 }
190                 default -> blockBuilder.op(op);
191             }
192             return blockBuilder;
193         });
194 
195         IO.println("AFTER TRANSFORM: ");
196         IO.println(codeModel.toText());
197         codeModel = codeModel.transform(LOWERING_TRANSFORMER);
198         IO.println("After Lowering: ");
199         IO.println(codeModel.toText());
200 
201         // Select invocation calls and display the lines
202         IO.println("\nPlaying with Traverse");
203         codeModel.elements().forEach(e -> {
204             if (e instanceof JavaOp.InvokeOp invokeOp) {
205                 IO.println("Function Name: " + invokeOp.invokeReference().name());
206 
207                 // Maybe Location should throw a new exception instead of the NPE,
208                 // since it is possible we don't have a location after a transformation has been done.
209                 Op.Location location = invokeOp.location();
210                 if (location != null) {
211                     int line = location.line();
212                     IO.println("Line " + line);
213                     IO.println("Class: " + invokeOp.getClass());
214                     // Detect Math::pow
215                     boolean contains = invokeOp.invokeReference().equals(JAVA_LANG_MATH_POW);
216                     if (contains) {
217                         System.out.println("Method: " + invokeOp.invokeReference().name());
218                     }
219                 } else {
220                     IO.println("[WARNING] Location is null");
221                 }
222             }
223         });
224 
225         // In addition, we can generate bytecodes from a new code model that
226         // has been transformed.
227         MethodHandle methodHandle = BytecodeGenerator.generate(MethodHandles.lookup(), codeModel);
228         // And invoke the method handle result
229         try {
230             double result = (double) methodHandle.invoke(10);
231             double checkResult = myFunction(10);
232             IO.println("Result after BC generation: " + result);
233             IO.println("Is correct? " + (checkResult == result));
234         } catch (Throwable e) {
235             throw new RuntimeException(e);
236         }
237     }
238 
239     // Inspect a value for a parameter
240     static boolean inspectParameterRecursive(JavaOp.ConvOp convOp, int valToMatch) {
241         return inspectParameterRecursive(convOp.operands().get(0), valToMatch);
242     }
243 
244     static boolean inspectParameterRecursive(Value v, int valToMatch) {
245         if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
246             return inspectParameterRecursive(convOp, valToMatch);
247         } else {
248             // Leaf of tree - we want to analyse the value
249             if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.ConstantOp constant) {
250                 return constant.value().equals(valToMatch);
251             }
252             return false;
253         }
254     }
255 
256     static final MethodRef JAVA_LANG_MATH_POW = MethodRef.method(Math.class, "pow", double.class, double.class, double.class);
257 
258     private static boolean whenIsMathPowFunction(JavaOp.InvokeOp invokeOp) {
259         return invokeOp.invokeReference().equals(JAVA_LANG_MATH_POW);
260     }
261 }