1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package oracle.code.samples;
26
27 import jdk.incubator.code.Reflect;
28 import jdk.incubator.code.Op;
29 import jdk.incubator.code.CodeType;
30 import jdk.incubator.code.Value;
31 import jdk.incubator.code.bytecode.BytecodeGenerator;
32 import jdk.incubator.code.dialect.core.CoreOp;
33 import jdk.incubator.code.dialect.java.JavaOp;
34 import jdk.incubator.code.dialect.java.JavaType;
35 import jdk.incubator.code.dialect.java.MethodRef;
36
37 import java.lang.invoke.MethodHandle;
38 import java.lang.invoke.MethodHandles;
39 import java.lang.reflect.Method;
40 import java.util.ArrayList;
41 import java.util.List;
42 import java.util.Optional;
43 import java.util.stream.Stream;
44
45 import static jdk.incubator.code.CodeTransformer.LOWERING_TRANSFORMER;
46
/**
 * Simple example of how to use the code reflection API.
 *
 * <p>
 * This example replaces calls to the math function {@code Math.pow} with optimized functions,
 * using code transforms from the code-reflection API. Each optimized function can be applied
 * only under certain conditions.
 * </p>
 *
 * <p>
 * Optimizations:
 * 1) Replace pow(x, y) with {@code (1 << y)} when x == 2, but only if the exponent y is an integer.
 * 2) Replace pow(x, y) with {@code (x * x)} when y == 2.
 * </p>
 *
 * <p>
 * Babylon repository: <a href="https://github.com/openjdk/babylon/tree/code-reflection">link</a>
 * </p>
 *
 * <p>
 * How to run?
 * <code>
 * java --add-modules jdk.incubator.code -cp target/crsamples-1.0-SNAPSHOT.jar oracle.code.samples.MathOptimizer
 * </code>
 * </p>
 */
72 public class MathOptimizer {
73
74 @Reflect
75 private static double myFunction(int value) {
76 return Math.pow(2, value);
77 }
78
79 // if pow(2, x), then substitute for this function
80 // We could apply this function if, at runtime, user pass int values to the pow function
81 // Thus, we narrow the result type from 8 bytes (double) to 4 bytes (INT).
82 private static int functionShift(int val) {
83 return 1 << val;
84 }
85
86 // if pow(x, 2) then substitute for this function
87 private static double functionMult(double x) {
88 return x * x;
89 }
90
91 private static final MethodRef MY_SHIFT_FUNCTION = MethodRef.method(MathOptimizer.class, "functionShift", int.class, int.class);
92
93 private static final MethodRef MY_MULT_FUNCTION = MethodRef.method(MathOptimizer.class, "functionMult", double.class, double.class);
94
95 // Analyze type methods: taken from example of String Concat Transformer to traverse the tree.
96 static boolean analyseType(JavaOp.ConvOp convOp, JavaType typeToMatch) {
97 return analyseType(convOp.operands().get(0), typeToMatch);
98 }
99
100 static boolean analyseType(Value v, JavaType typeToMatch) {
101 // Maybe there is a utility already to do tree traversal
102 if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
103 // Node of tree, recursively traverse the operands
104 return analyseType(convOp, typeToMatch);
105 } else {
106 // Leaf of tree: analyze type
107 CodeType type = v.type();
108 return type.equals(typeToMatch);
109 }
110 }
111
112 static void main() {
113
114 Optional<Method> myFunction = Stream.of(MathOptimizer.class.getDeclaredMethods())
115 .filter(m -> m.getName().equals("myFunction"))
116 .findFirst();
117
118 Method myMathMethod = myFunction.get();
119
120 // Obtain the code model for the annotated method
121 CoreOp.FuncOp codeModel = Op.ofMethod(myMathMethod).get();
122 IO.println(codeModel.toText());
123
124 IO.println("\nLet's transform the code");
125 codeModel = codeModel.transform((blockBuilder, op) -> {
126 switch (op) {
127 case JavaOp.InvokeOp invokeOp when whenIsMathPowFunction(invokeOp) -> {
128 // The idea here is to create a new JavaOp.invoke with the optimization and replace it.
129 List<Value> operands = blockBuilder.context().getValues(op.operands());
130
131 // Analyse second operand of the Math.pow(x, y).
132 // if the x == 2, and both are integers, then we can optimize the function using bitwise operations
133 // pow(2, y) replace with (1 << y)
134 Value operand = operands.getFirst(); // obtain the first parameter
135 // inspect if the base (as in pow(base, exp) is value 2
136 boolean canApplyBitShift = inspectParameterRecursive(operand, 2);
137 if (canApplyBitShift) {
138 // We also need to inspect types. We can apply this optimization
139 // if the exp type is also an integer.
140 boolean isIntType = analyseType(operands.get(1), JavaType.INT);
141 if (!isIntType) {
142 canApplyBitShift = false;
143 }
144 }
145
146 // If the conditions to apply the first optimization failed, we try the second optimization
147 // if types are not int, and base is not 2.
148 // pow(x, 2) => replace with x * x
149 boolean canApplyMultiplication = false;
150 if (!canApplyBitShift) {
151 // inspect if exp (as in pow(base, exp) is value 2
152 canApplyMultiplication = inspectParameterRecursive(operands.get(1), 2);
153 }
154
155 if (canApplyBitShift) {
156 // Narrow type from DOUBLE to INT for the input parameter of the new function.
157 Op.Result op2 = blockBuilder.op(JavaOp.conv(JavaType.INT, operands.get(1)));
158 List<Value> newOperandList = new ArrayList<>();
159 newOperandList.add(op2);
160
161 // Create a new invoke with the optimised method
162 JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_SHIFT_FUNCTION, newOperandList);
163 // Copy the original location info to the new invoke
164 newInvoke.setLocation(invokeOp.location());
165
166 // Replace the invoke node with the new optimized invoke
167 Op.Result newResult = blockBuilder.op(newInvoke);
168 // Apply type conversion to double
169 newResult = blockBuilder.op(JavaOp.conv(JavaType.DOUBLE, newResult));
170 // Propagate the new result
171 blockBuilder.context().mapValue(invokeOp.result(), newResult);
172
173 } else if (canApplyMultiplication) {
174 // Adapt the parameters to the new function. We only need the first
175 // parameter from the initial parameter list - pow(x, 2) -
176 // Create a new invoke function with the optimised method
177 JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_MULT_FUNCTION, operands.get(0));
178 // Copy the location info to the new invoke
179 newInvoke.setLocation(invokeOp.location());
180
181 // Replace the invoke node with the new optimized invoke
182 Op.Result newResult = blockBuilder.op(newInvoke);
183 blockBuilder.context().mapValue(invokeOp.result(), newResult);
184
185 } else {
186 // ignore the transformation
187 blockBuilder.op(op);
188 }
189 }
190 default -> blockBuilder.op(op);
191 }
192 return blockBuilder;
193 });
194
195 IO.println("AFTER TRANSFORM: ");
196 IO.println(codeModel.toText());
197 codeModel = codeModel.transform(LOWERING_TRANSFORMER);
198 IO.println("After Lowering: ");
199 IO.println(codeModel.toText());
200
201 // Select invocation calls and display the lines
202 IO.println("\nPlaying with Traverse");
203 codeModel.elements().forEach(e -> {
204 if (e instanceof JavaOp.InvokeOp invokeOp) {
205 IO.println("Function Name: " + invokeOp.invokeReference().name());
206
207 // Maybe Location should throw a new exception instead of the NPE,
208 // since it is possible we don't have a location after a transformation has been done.
209 Op.Location location = invokeOp.location();
210 if (location != null) {
211 int line = location.line();
212 IO.println("Line " + line);
213 IO.println("Class: " + invokeOp.getClass());
214 // Detect Math::pow
215 boolean contains = invokeOp.invokeReference().equals(JAVA_LANG_MATH_POW);
216 if (contains) {
217 System.out.println("Method: " + invokeOp.invokeReference().name());
218 }
219 } else {
220 IO.println("[WARNING] Location is null");
221 }
222 }
223 });
224
225 // In addition, we can generate bytecodes from a new code model that
226 // has been transformed.
227 MethodHandle methodHandle = BytecodeGenerator.generate(MethodHandles.lookup(), codeModel);
228 // And invoke the method handle result
229 try {
230 double result = (double) methodHandle.invoke(10);
231 double checkResult = myFunction(10);
232 IO.println("Result after BC generation: " + result);
233 IO.println("Is correct? " + (checkResult == result));
234 } catch (Throwable e) {
235 throw new RuntimeException(e);
236 }
237 }
238
239 // Inspect a value for a parameter
240 static boolean inspectParameterRecursive(JavaOp.ConvOp convOp, int valToMatch) {
241 return inspectParameterRecursive(convOp.operands().get(0), valToMatch);
242 }
243
244 static boolean inspectParameterRecursive(Value v, int valToMatch) {
245 if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
246 return inspectParameterRecursive(convOp, valToMatch);
247 } else {
248 // Leaf of tree - we want to analyse the value
249 if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.ConstantOp constant) {
250 return constant.value().equals(valToMatch);
251 }
252 return false;
253 }
254 }
255
256 static final MethodRef JAVA_LANG_MATH_POW = MethodRef.method(Math.class, "pow", double.class, double.class, double.class);
257
258 private static boolean whenIsMathPowFunction(JavaOp.InvokeOp invokeOp) {
259 return invokeOp.invokeReference().equals(JAVA_LANG_MATH_POW);
260 }
261 }