1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package oracle.code.samples;
26
27 import jdk.incubator.code.CodeTransformer;
28 import jdk.incubator.code.Reflect;
29 import jdk.incubator.code.Op;
30 import jdk.incubator.code.TypeElement;
31 import jdk.incubator.code.Value;
32 import jdk.incubator.code.bytecode.BytecodeGenerator;
33 import jdk.incubator.code.dialect.core.CoreOp;
34 import jdk.incubator.code.dialect.core.Inliner;
35 import jdk.incubator.code.dialect.java.JavaOp;
36 import jdk.incubator.code.dialect.java.JavaType;
37 import jdk.incubator.code.dialect.java.MethodRef;
38
39 import java.lang.invoke.MethodHandle;
40 import java.lang.invoke.MethodHandles;
41 import java.lang.reflect.Method;
42 import java.util.ArrayList;
43 import java.util.List;
44 import java.util.Objects;
45 import java.util.Optional;
46 import java.util.concurrent.atomic.AtomicReference;
47 import java.util.stream.Stream;
48
49 /**
50 * Simple example of how to use the code reflection API.
51 *
52 * <p>
 * This example extends the {@link MathOptimizer} example by inlining the replacement methods
 * it introduces. Here we focus on explaining the inlining step in detail.
55 * </p>
56 *
 * <p>
 * Optimizations:
 * </p>
 * <ul>
 *   <li>Replace {@code Math.pow(x, y)} when x == 2 with {@code (1 << y)}, only if the parameter y is an integer.</li>
 *   <li>Replace {@code Math.pow(x, y)} when y == 2 with {@code (x * x)}.</li>
 *   <li>After a replacement has been done, inline the new invoke into the main code model.</li>
 * </ul>
63 *
64 * <p>
 * In a nutshell, we apply a second transform to perform the inlining. Note that the inlining could
 * also be done within the first transform.
67 * To be able to inline, we need to also annotate the new invoke nodes that were replaced during the first
68 * transform with the <code>@Reflect</code> annotation. In this way, we can build the code models for
69 * each of the methods and apply the inlining directly using the <code>Inliner.inline</code> from the code
70 * reflection API.
71 * </p>
72 *
73 * <p>
 * Babylon repository: <a href="https://github.com/openjdk/babylon/tree/code-reflection">link</a>
75 * </p>
76 *
77 * <p>
78 * How to run?
79 * <code>
80 * java --add-modules jdk.incubator.code -cp target/crsamples-1.0-SNAPSHOT.jar oracle.code.samples.MathOptimizerWithInlining
81 * </code>
 * </p>
83 */
84 public class MathOptimizerWithInlining {
85
86 // New functions are also annotated with code reflection
87 @Reflect
88 private static double myFunction(int value) {
89 return Math.pow(2, value);
90 }
91
92 @Reflect
93 private static int functionShift(int val) {
94 return 1 << val;
95 }
96
97 @Reflect
98 private static double functionMult(double x) {
99 return x * x;
100 }
101
102 private static final MethodRef MY_SHIFT_FUNCTION = MethodRef.method(MathOptimizerWithInlining.class, "functionShift", int.class, int.class);
103
104 private static final MethodRef MY_MULT_FUNCTION = MethodRef.method(MathOptimizerWithInlining.class, "functionMult", double.class, double.class);
105
106 private static CoreOp.FuncOp buildCodeModelForMethod(Class<?> klass, String methodName) {
107 Optional<Method> function = Stream.of(klass.getDeclaredMethods())
108 .filter(m -> m.getName().equals(methodName))
109 .findFirst();
110 Method method = function.get();
111 CoreOp.FuncOp funcOp = Op.ofMethod(method).get();
112 return funcOp;
113 }
114
115 static void main() {
116
117 // Obtain the code model for the annotated method
118 CoreOp.FuncOp codeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "myFunction");
119 System.out.println(codeModel.toText());
120
121 enum FunctionToUse {
122 SHIFT,
123 MULT,
124 GENERIC;
125 }
126
127 AtomicReference<FunctionToUse> replace = new AtomicReference<>(FunctionToUse.GENERIC);
128
129 codeModel = codeModel.transform((blockBuilder, op) -> {
130 // The idea here is to create a new JavaOp.invoke with the optimization and replace it.
131 if (Objects.requireNonNull(op) instanceof JavaOp.InvokeOp invokeOp && whenIsMathPowFunction(invokeOp)) {
132 List<Value> operands = blockBuilder.context().getValues(op.operands());
133
134 // Analyse second operand of the Math.pow(x, y).
135 // if the x == 2, and both are integers, then we can optimize the function using bitwise operations
136 // pow(2, y) replace with (1 << y)
137 Value operand = operands.getFirst(); // obtain the first parameter
138 // inspect if the base (as in pow(base, exp) is value 2
139 boolean canApplyBitShift = inspectParameterRecursive(operand, 2);
140 if (canApplyBitShift) {
141 // We also need to inspect types. We can apply this optimization
142 // if the exp type is also an integer.
143 boolean isIntType = analyseType(operands.get(1), JavaType.INT);
144 if (!isIntType) {
145 canApplyBitShift = false;
146 }
147 }
148
149 // If the conditions to apply the first optimization failed, we try the second optimization
150 // if types are not int, and base is not 2.
151 // pow(x, 2) => replace with x * x
152 boolean canApplyMultiplication = false;
153 if (!canApplyBitShift) {
154 // inspect if exp (as in pow(base, exp) is value 2
155 canApplyMultiplication = inspectParameterRecursive(operands.get(1), 2);
156 }
157
158 if (canApplyBitShift) {
159 // Narrow type from DOUBLE to INT for the input parameter of the new function.
160 Op.Result op2 = blockBuilder.op(JavaOp.conv(JavaType.INT, operands.get(1)));
161 List<Value> newOperandList = new ArrayList<>();
162 newOperandList.add(op2);
163
164 // Create a new invoke with the optimised method
165 JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_SHIFT_FUNCTION, newOperandList);
166 // Copy the original location info to the new invoke
167 newInvoke.setLocation(invokeOp.location());
168
169 // Replace the invoke node with the new optimized invoke
170 Op.Result newResult = blockBuilder.op(newInvoke);
171 // Type conversion to double
172 newResult = blockBuilder.op(JavaOp.conv(JavaType.DOUBLE, newResult));
173 blockBuilder.context().mapValue(invokeOp.result(), newResult);
174
175 replace.set(FunctionToUse.SHIFT);
176
177 } else if (canApplyMultiplication) {
178 // Adapt the parameters to the new function. We only need the first
179 // parameter from the initial parameter list - pow(x, 2) -
180 // Create a new invoke function with the optimised method
181 JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_MULT_FUNCTION, operands.get(0));
182 // Copy the location info to the new invoke
183 newInvoke.setLocation(invokeOp.location());
184
185 // Replace the invoke node with the new optimized invoke
186 Op.Result newResult = blockBuilder.op(newInvoke);
187 blockBuilder.context().mapValue(invokeOp.result(), newResult);
188 replace.set(FunctionToUse.MULT);
189
190 } else {
191 // ignore the transformation
192 blockBuilder.op(op);
193 }
194 } else {
195 blockBuilder.op(op);
196 }
197 return blockBuilder;
198 });
199
200 System.out.println("Code Model after the first transform (replace with a new method): ");
201 System.out.println(codeModel.toText());
202
203 // Let's now apply a second transformation
204 // We want to inline the functions. Note that we can apply this transformation if any of the new functions
205 // have been replaced.
206 System.out.println("Second transform: apply inlining for the new methods into the main code model");
207 if (replace.get() != FunctionToUse.GENERIC) {
208
209 // Build code model for the functions we want to inline.
210 // Since we apply two replacements (depending on the values of the input code), we can apply two different
211 // inline functions.
212 CoreOp.FuncOp shiftCodeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "functionShift");
213 CoreOp.FuncOp multCodeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "functionMult");
214
215 // Apply inlining
216 codeModel = codeModel.transform(codeModel.funcName(),
217 (blockBuilder, op) -> {
218
219 if (op instanceof JavaOp.InvokeOp invokeOp && isMethodWeWantToInline(invokeOp)) {
220 // Let's inline the function
221
222 // 1. Select the function we want to inline.
223 // Since we have two possible replacements, depending on the input code, we need to
224 // apply the corresponding replacement function
225 CoreOp.FuncOp codeModelToInline = isShiftFunction(invokeOp) ? shiftCodeModel : multCodeModel;
226
227 // 2. Apply the inlining
228 Inliner.inline(
229 blockBuilder, // the current block builder
230 codeModelToInline, // the method to inline which we obtained using code reflection too
231 blockBuilder.context().getValues(invokeOp.operands()), // operands to this call. Since we already replace the function,
232 // we can use the same operands as the invoke call
233 (builder, val) -> blockBuilder.context().mapValue(invokeOp.result(), val)); // Propagate the new result
234 } else {
235 // copy the op into the builder if it is not the invoke node we are looking for
236 blockBuilder.op(op);
237 }
238
239 // return new transformed block builder
240 return blockBuilder;
241 });
242
243 System.out.println("After inlining: " + codeModel.toText());
244 }
245
246 codeModel = codeModel.transform(CodeTransformer.LOWERING_TRANSFORMER);
247 System.out.println("After Lowering: ");
248 System.out.println(codeModel.toText());
249
250 System.out.println("\nGenerate bytecode and execute");
251 MethodHandle methodHandle = BytecodeGenerator.generate(MethodHandles.lookup(), codeModel);
252 try {
253 var result = methodHandle.invoke(10);
254 System.out.println(result);
255 } catch (Throwable e) {
256 throw new RuntimeException(e);
257 }
258
259 }
260
261 // Utility methods
262
263 // Analyze type methods: taken from example of String Concat Transformer to traverse the tree.
264 static boolean analyseType(JavaOp.ConvOp convOp, JavaType typeToMatch) {
265 return analyseType(convOp.operands().get(0), typeToMatch);
266 }
267
268 static boolean analyseType(Value v, JavaType typeToMatch) {
269 // Maybe there is a utility already to do tree traversal
270 if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
271 // Node of tree, recursively traverse the operands
272 return analyseType(convOp, typeToMatch);
273 } else {
274 // Leaf of tree: analyze type
275 TypeElement type = v.type();
276 return type.equals(typeToMatch);
277 }
278 }
279
280 // Inspect a value for a parameter
281 static boolean inspectParameterRecursive(JavaOp.ConvOp convOp, int valToMatch) {
282 return inspectParameterRecursive(convOp.operands().get(0), valToMatch);
283 }
284
285 static boolean inspectParameterRecursive(Value v, int valToMatch) {
286 if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
287 return inspectParameterRecursive(convOp, valToMatch);
288 } else {
289 // Leaf of tree - we want to obtain the actual value of the parameter and check
290 if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.ConstantOp constant) {
291 return constant.value().equals(valToMatch);
292 }
293 return false;
294 }
295 }
296
297 static final MethodRef JAVA_LANG_MATH_POW = MethodRef.method(Math.class, "pow", double.class, double.class, double.class);
298
299 private static boolean whenIsMathPowFunction(JavaOp.InvokeOp invokeOp) {
300 return invokeOp.invokeDescriptor().equals(JAVA_LANG_MATH_POW);
301 }
302
303 private static boolean isMethodWeWantToInline(JavaOp.InvokeOp invokeOp) {
304 return (invokeOp.invokeDescriptor().toString().startsWith("oracle.code.samples.MathOptimizerWithInlining::functionShift")
305 || invokeOp.invokeDescriptor().toString().startsWith("oracle.code.samples.MathOptimizerWithInlining::functionMult"));
306 }
307
308 private static boolean isShiftFunction(JavaOp.InvokeOp invokeOp) {
309 return invokeOp.invokeDescriptor().toString().contains("functionShift");
310 }
311 }