1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package oracle.code.samples;
26
27 import jdk.incubator.code.CodeReflection;
28 import jdk.incubator.code.CopyContext;
29 import jdk.incubator.code.Location;
30 import jdk.incubator.code.Op;
31 import jdk.incubator.code.OpTransformer;
32 import jdk.incubator.code.TypeElement;
33 import jdk.incubator.code.Value;
34 import jdk.incubator.code.analysis.SSA;
35 import jdk.incubator.code.bytecode.BytecodeGenerator;
36 import jdk.incubator.code.dialect.core.CoreOp;
37 import jdk.incubator.code.dialect.java.JavaOp;
38 import jdk.incubator.code.dialect.java.JavaType;
39 import jdk.incubator.code.dialect.java.MethodRef;
40 import jdk.incubator.code.interpreter.Interpreter;
41
42 import java.lang.invoke.MethodHandle;
43 import java.lang.invoke.MethodHandles;
44 import java.lang.reflect.Method;
45 import java.util.ArrayList;
46 import java.util.List;
47 import java.util.Optional;
48 import java.util.stream.Stream;
49
/**
 * Simple example of how to use the code reflection API.
 *
 * <p>
 * This example uses code transforms from the code reflection API to replace calls to
 * {@code Math.pow} with optimized functions. Each optimization is applied only when
 * specific conditions on the call's arguments hold.
 * </p>
 *
 * <p>
 * Optimizations:
 * </p>
 * <ol>
 *   <li>Replace {@code Math.pow(x, y)} with a left bit shift of 1 by {@code y} when
 *       {@code x == 2}, if and only if the argument {@code y} is an {@code int}.</li>
 *   <li>Replace {@code Math.pow(x, y)} with {@code x * x} when {@code y == 2}.</li>
 * </ol>
 *
 * <p>
 * Babylon repository: <a href="https://github.com/openjdk/babylon/tree/code-reflection">link</a>
 * </p>
 *
 * <p>
 * How to run:
 * </p>
 * <pre>{@code
 * java --add-modules jdk.incubator.code -cp target/crsamples-1.0-SNAPSHOT.jar oracle.code.samples.MathOptimizer
 * }</pre>
 */
75 public class MathOptimizer {
76
77 @CodeReflection
78 private static double myFunction(int value) {
79 return Math.pow(2, value);
80 }
81
82 // if pow(2, x), then substitute for this function
83 // We could apply this function if, at runtime, user pass int values to the pow function
84 // Thus, we narrow the result type from 8 bytes (double) to 4 bytes (INT).
85 private static int functionShift(int val) {
86 return 1 << val;
87 }
88
89 // if pow(x, 2) then substitute for this function
90 private static double functionMult(double x) {
91 return x * x;
92 }
93
94 private static final MethodRef MY_SHIFT_FUNCTION = MethodRef.method(MathOptimizer.class, "functionShift", int.class, int.class);
95
96 private static final MethodRef MY_MULT_FUNCTION = MethodRef.method(MathOptimizer.class, "functionMult", double.class, double.class);
97
98 // Analyze type methods: taken from example of String Concat Transformer to traverse the tree.
99 static boolean analyseType(JavaOp.ConvOp convOp, JavaType typeToMatch) {
100 return analyseType(convOp.operands().get(0), typeToMatch);
101 }
102
103 static boolean analyseType(Value v, JavaType typeToMatch) {
104 // Maybe there is a utility already to do tree traversal
105 if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
106 // Node of tree, recursively traverse the operands
107 return analyseType(convOp, typeToMatch);
108 } else {
109 // Leaf of tree: analyze type
110 TypeElement type = v.type();
111 return type.equals(typeToMatch);
112 }
113 }
114
115 static void main(String[] args) throws Throwable {
116
117 Optional<Method> myFunction = Stream.of(MathOptimizer.class.getDeclaredMethods())
118 .filter(m -> m.getName().equals("myFunction"))
119 .findFirst();
120
121 Method myMathMethod = myFunction.get();
122
123 // Obtain the code model for the annotated method
124 CoreOp.FuncOp codeModel = Op.ofMethod(myMathMethod).get();
125 System.out.println(codeModel.toText());
126
127 // In addition, we can generate bytecodes from a new code model that
128 // has been transformed.
129 MethodHandle mhNewTransform = BytecodeGenerator.generate(MethodHandles.lookup(), codeModel);
130 // And invoke the method handle result
131 var resultBC = mhNewTransform.invoke( 10);
132 System.out.println("Result after BC generation: " + resultBC);
133
134 System.out.println("\nLet's transform the code");
135 codeModel = codeModel.transform(CopyContext.create(), (blockBuilder, op) -> {
136 switch (op) {
137 case JavaOp.InvokeOp invokeOp when whenIsMathPowFunction(invokeOp) -> {
138 // The idea here is to create a new JavaOp.invoke with the optimization and replace it.
139 List<Value> operands = blockBuilder.context().getValues(op.operands());
140
141 // Analyse second operand of the Math.pow(x, y).
142 // if the x == 2, and both are integers, then we can optimize the function using bitwise operations
143 // pow(2, y) replace with (1 << y)
144 Value operand = operands.getFirst(); // obtain the first parameter
145 // inspect if the base (as in pow(base, exp) is value 2
146 boolean canApplyBitShift = inspectParameterRecursive(operand, 2);
147 if (canApplyBitShift) {
148 // We also need to inspect types. We can apply this optimization
149 // if the exp type is also an integer.
150 boolean isIntType = analyseType(operands.get(1), JavaType.INT);
151 if (!isIntType) {
152 canApplyBitShift = false;
153 }
154 }
155
156 // If the conditions to apply the first optimization failed, we try the second optimization
157 // if types are not int, and base is not 2.
158 // pow(x, 2) => replace with x * x
159 boolean canApplyMultiplication = false;
160 if (!canApplyBitShift) {
161 // inspect if exp (as in pow(base, exp) is value 2
162 canApplyMultiplication = inspectParameterRecursive(operands.get(1), 2);
163 }
164
165 if (canApplyBitShift) {
166 // Narrow type from DOUBLE to INT for the input parameter of the new function.
167 Op.Result op2 = blockBuilder.op(JavaOp.conv(JavaType.INT, operands.get(1)));
168 List<Value> newOperandList = new ArrayList<>();
169 newOperandList.add(op2);
170
171 // Create a new invoke with the optimised method
172 JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_SHIFT_FUNCTION, newOperandList);
173 // Copy the original location info to the new invoke
174 newInvoke.setLocation(invokeOp.location());
175
176 // Replace the invoke node with the new optimized invoke
177 Op.Result newResult = blockBuilder.op(newInvoke);
178 // Apply type conversion to double
179 newResult = blockBuilder.op(JavaOp.conv(JavaType.DOUBLE, newResult));
180 // Propagate the new result
181 blockBuilder.context().mapValue(invokeOp.result(), newResult);
182
183 } else if (canApplyMultiplication) {
184 // Adapt the parameters to the new function. We only need the first
185 // parameter from the initial parameter list - pow(x, 2) -
186 // Create a new invoke function with the optimised method
187 JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_MULT_FUNCTION, operands.get(0));
188 // Copy the location info to the new invoke
189 newInvoke.setLocation(invokeOp.location());
190
191 // Replace the invoke node with the new optimized invoke
192 Op.Result newResult = blockBuilder.op(newInvoke);
193 blockBuilder.context().mapValue(invokeOp.result(), newResult);
194
195 } else {
196 // ignore the transformation
197 blockBuilder.op(op);
198 }
199 }
200 default -> blockBuilder.op(op);
201 }
202 return blockBuilder;
203 });
204
205 System.out.println("AFTER TRANSFORM: ");
206 System.out.println(codeModel.toText());
207 codeModel = codeModel.transform(OpTransformer.LOWERING_TRANSFORMER);
208 System.out.println("After Lowering: ");
209 System.out.println(codeModel.toText());
210
211 System.out.println("\nEvaluate");
212 // The Interpreter Invoke should launch new exceptions
213 var result = Interpreter.invoke(MethodHandles.lookup(), codeModel, 10);
214 System.out.println(result);
215
216 // Select invocation calls and display the lines
217 System.out.println("\nPlaying with Traverse");
218 codeModel.traverse(null, (map, op) -> {
219 if (op instanceof JavaOp.InvokeOp invokeOp) {
220 System.out.println("Function Name: " + invokeOp.invokeDescriptor().name());
221
222 // Maybe Location should throw a new exception instead of the NPE,
223 // since it is possible we don't have a location after a transformation has been done.
224 Location location = invokeOp.location();
225 if (location != null) {
226 int line = location.line();
227 System.out.println("Line " + line);
228 System.out.println("Class: " + invokeOp.getClass());
229 // Detect Math::pow
230 boolean contains = invokeOp.invokeDescriptor().equals(JAVA_LANG_MATH_POW);
231 if (contains) {
232 System.out.println("Method: " + invokeOp.invokeDescriptor().name());
233 }
234 } else {
235 System.out.println("[WARNING] Location is null");
236 }
237 }
238 return map;
239 });
240
241 // In addition, we can generate bytecodes from a new code model that
242 // has been transformed.
243 MethodHandle methodHandle = BytecodeGenerator.generate(MethodHandles.lookup(), codeModel);
244 // And invoke the method handle result
245 var resultBC2 = methodHandle.invoke( 10);
246 System.out.println("Result after BC generation: " + resultBC2);
247 }
248
249 // Goal: obtain and check the value of the function parameters.
250 // The function inspectParameterRecursive implements this method
251 // in a much simpler and shorter manner. We can keep this first
252 // implementation as a reference.
253 private static boolean inspectParameter(Value operand, final int value) {
254 final Boolean[] isMultipliedByTwo = new Boolean[] { false };
255 if (operand instanceof Op.Result res) {
256 if (res.op() instanceof JavaOp.ConvOp convOp) {
257 convOp.operands().forEach(v -> {
258 if (v instanceof Op.Result res2) {
259 if (res2.op() instanceof CoreOp.ConstantOp constantOp) {
260 if (constantOp.value() instanceof Integer parameter) {
261 if (parameter.intValue() == value) {
262 // Transformation is valid
263 isMultipliedByTwo[0] = true;
264 }
265 }
266 }
267 }
268 });
269 }
270 }
271 return isMultipliedByTwo[0];
272 }
273
274 // Inspect a value for a parameter
275 static boolean inspectParameterRecursive(JavaOp.ConvOp convOp, int valToMatch) {
276 return inspectParameterRecursive(convOp.operands().get(0), valToMatch);
277 }
278
279 static boolean inspectParameterRecursive(Value v, int valToMatch) {
280 if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
281 return inspectParameterRecursive(convOp, valToMatch);
282 } else {
283 // Leaf of tree - we want to analyse the value
284 if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.ConstantOp constant) {
285 return constant.value().equals(valToMatch);
286 }
287 return false;
288 }
289 }
290
291 static final MethodRef JAVA_LANG_MATH_POW = MethodRef.method(Math.class, "pow", double.class, double.class, double.class);
292
293 private static boolean whenIsMathPowFunction(JavaOp.InvokeOp invokeOp) {
294 return invokeOp.invokeDescriptor().equals(JAVA_LANG_MATH_POW);
295 }
296 }