1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package oracle.code.samples;
26
27 import jdk.incubator.code.CodeReflection;
28 import jdk.incubator.code.CopyContext;
29 import jdk.incubator.code.Op;
30 import jdk.incubator.code.OpTransformer;
31 import jdk.incubator.code.TypeElement;
32 import jdk.incubator.code.Value;
33 import jdk.incubator.code.analysis.Inliner;
34 import jdk.incubator.code.dialect.core.CoreOp;
35 import jdk.incubator.code.dialect.java.JavaOp;
36 import jdk.incubator.code.dialect.java.JavaType;
37 import jdk.incubator.code.dialect.java.MethodRef;
38 import jdk.incubator.code.extern.OpWriter;
39 import jdk.incubator.code.interpreter.Interpreter;
40
41 import java.lang.invoke.MethodHandles;
42 import java.lang.reflect.Method;
43 import java.util.ArrayList;
44 import java.util.List;
45 import java.util.Objects;
46 import java.util.Optional;
47 import java.util.concurrent.atomic.AtomicReference;
48 import java.util.stream.Stream;
49
50 /**
51 * Simple example of how to use the code reflection API.
52 *
53 * <p>
54 * This example is based on the {@link MathOptimizer} example to add inlining for the new replaced
55 * methods. In this example, we focus on explaining in detail the inlining component.
56 * </p>
57 *
58 * <p>
59 * Optimizations:
60 * <ol>Replace Pow(x, y) when x == 2 to (1 << y), if only if the parameter y is an integer.</ol>
61 * <ol>Replace Pow(x, y) when y == 2 to (x * x).</ol>
62 * <ol>After a replacement has been done, we inline the new invoke into the main code model.</ol>
63 * </p>
64 *
65 * <p>
66 * In a nutshell, we apply a second transform to perform the inlining. Note that the inlining could be done
67 * also in within the first transform.
68 * To be able to inline, we need to also annotate the new invoke nodes that were replaced during the first
69 * transform with the <code>@CodeReflection</code> annotation. In this way, we can build the code models for
70 * each of the methods and apply the inlining directly using the <code>Inliner.inline</code> from the code
71 * reflection API.
72 * </p>
73 *
74 * <p>
75 * Babylon repository: {@see <a href="https://github.com/openjdk/babylon/tree/code-reflection">link</a>}
76 * </p>
77 *
78 * <p>
79 * How to run?
80 * <code>
81 * java --add-modules jdk.incubator.code -cp target/crsamples-1.0-SNAPSHOT.jar oracle.code.samples.MathOptimizerWithInlining
82 * </code>
83 * </p>:
84 */
85 public class MathOptimizerWithInlining {
86
87 // New functions are also annotated with code reflection
88 @CodeReflection
89 private static double myFunction(int value) {
90 return Math.pow(2, value);
91 }
92
93 @CodeReflection
94 private static int functionShift(int val) {
95 return 1 << val;
96 }
97
98 @CodeReflection
99 private static double functionMult(double x) {
100 return x * x;
101 }
102
103 private static final MethodRef MY_SHIFT_FUNCTION = MethodRef.method(MathOptimizerWithInlining.class, "functionShift", int.class, int.class);
104
105 private static final MethodRef MY_MULT_FUNCTION = MethodRef.method(MathOptimizerWithInlining.class, "functionMult", double.class, double.class);
106
107 private static CoreOp.FuncOp buildCodeModelForMethod(Class<?> klass, String methodName) {
108 Optional<Method> function = Stream.of(klass.getDeclaredMethods())
109 .filter(m -> m.getName().equals(methodName))
110 .findFirst();
111 Method method = function.get();
112 CoreOp.FuncOp funcOp = Op.ofMethod(method).get();
113 return funcOp;
114 }
115
116 static void main(String[] args) {
117
118 // Obtain the code model for the annotated method
119 CoreOp.FuncOp codeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "myFunction");
120 System.out.println(codeModel.toText());
121
122 enum FunctionToUse {
123 SHIFT,
124 MULT,
125 GENERIC;
126 }
127
128 AtomicReference<FunctionToUse> replace = new AtomicReference<>(FunctionToUse.GENERIC);
129
130 codeModel = codeModel.transform(CopyContext.create(), (blockBuilder, op) -> {
131 // The idea here is to create a new JavaOp.invoke with the optimization and replace it.
132 if (Objects.requireNonNull(op) instanceof JavaOp.InvokeOp invokeOp && whenIsMathPowFunction(invokeOp)) {
133 List<Value> operands = blockBuilder.context().getValues(op.operands());
134
135 // Analyse second operand of the Math.pow(x, y).
136 // if the x == 2, and both are integers, then we can optimize the function using bitwise operations
137 // pow(2, y) replace with (1 << y)
138 Value operand = operands.getFirst(); // obtain the first parameter
139 // inspect if the base (as in pow(base, exp) is value 2
140 boolean canApplyBitShift = inspectParameterRecursive(operand, 2);
141 if (canApplyBitShift) {
142 // We also need to inspect types. We can apply this optimization
143 // if the exp type is also an integer.
144 boolean isIntType = analyseType(operands.get(1), JavaType.INT);
145 if (!isIntType) {
146 canApplyBitShift = false;
147 }
148 }
149
150 // If the conditions to apply the first optimization failed, we try the second optimization
151 // if types are not int, and base is not 2.
152 // pow(x, 2) => replace with x * x
153 boolean canApplyMultiplication = false;
154 if (!canApplyBitShift) {
155 // inspect if exp (as in pow(base, exp) is value 2
156 canApplyMultiplication = inspectParameterRecursive(operands.get(1), 2);
157 }
158
159 if (canApplyBitShift) {
160 // Narrow type from DOUBLE to INT for the input parameter of the new function.
161 Op.Result op2 = blockBuilder.op(JavaOp.conv(JavaType.INT, operands.get(1)));
162 List<Value> newOperandList = new ArrayList<>();
163 newOperandList.add(op2);
164
165 // Create a new invoke with the optimised method
166 JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_SHIFT_FUNCTION, newOperandList);
167 // Copy the original location info to the new invoke
168 newInvoke.setLocation(invokeOp.location());
169
170 // Replace the invoke node with the new optimized invoke
171 Op.Result newResult = blockBuilder.op(newInvoke);
172 // Type conversion to double
173 newResult = blockBuilder.op(JavaOp.conv(JavaType.DOUBLE, newResult));
174 blockBuilder.context().mapValue(invokeOp.result(), newResult);
175
176 replace.set(FunctionToUse.SHIFT);
177
178 } else if (canApplyMultiplication) {
179 // Adapt the parameters to the new function. We only need the first
180 // parameter from the initial parameter list - pow(x, 2) -
181 // Create a new invoke function with the optimised method
182 JavaOp.InvokeOp newInvoke = JavaOp.invoke(MY_MULT_FUNCTION, operands.get(0));
183 // Copy the location info to the new invoke
184 newInvoke.setLocation(invokeOp.location());
185
186 // Replace the invoke node with the new optimized invoke
187 Op.Result newResult = blockBuilder.op(newInvoke);
188 blockBuilder.context().mapValue(invokeOp.result(), newResult);
189 replace.set(FunctionToUse.MULT);
190
191 } else {
192 // ignore the transformation
193 blockBuilder.op(op);
194 }
195 } else {
196 blockBuilder.op(op);
197 }
198 return blockBuilder;
199 });
200
201 System.out.println("Code Model after the first transform (replace with a new method): ");
202 System.out.println(codeModel.toText());
203
204 // Let's now apply a second transformation
205 // We want to inline the functions. Note that we can apply this transformation if any of the new functions
206 // have been replaced.
207 System.out.println("Second transform: apply inlining for the new methods into the main code model");
208 if (replace.get() != FunctionToUse.GENERIC) {
209
210 // Build code model for the functions we want to inline.
211 // Since we apply two replacements (depending on the values of the input code), we can apply two different
212 // inline functions.
213 CoreOp.FuncOp shiftCodeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "functionShift");
214 CoreOp.FuncOp multCodeModel = buildCodeModelForMethod(MathOptimizerWithInlining.class, "functionMult");
215
216 // Apply inlining
217 codeModel = codeModel.transform(codeModel.funcName(),
218 (blockBuilder, op) -> {
219
220 if (op instanceof JavaOp.InvokeOp invokeOp && isMethodWeWantToInline(invokeOp)) {
221 // Let's inline the function
222
223 // 1. Select the function we want to inline.
224 // Since we have two possible replacements, depending on the input code, we need to
225 // apply the corresponding replacement function
226 CoreOp.FuncOp codeModelToInline = isShiftFunction(invokeOp) ? shiftCodeModel : multCodeModel;
227
228 // 2. Apply the inlining
229 Inliner.inline(
230 blockBuilder, // the current block builder
231 codeModelToInline, // the method to inline which we obtained using code reflection too
232 blockBuilder.context().getValues(invokeOp.operands()), // operands to this call. Since we already replace the function,
233 // we can use the same operands as the invoke call
234 (builder, val) -> blockBuilder.context().mapValue(invokeOp.result(), val)); // Propagate the new result
235 } else {
236 // copy the op into the builder if it is not the invoke node we are looking for
237 blockBuilder.op(op);
238 }
239
240 // return new transformed block builder
241 return blockBuilder;
242 });
243
244 System.out.println("After inlining: " + codeModel.toText());
245 }
246
247 codeModel = codeModel.transform(OpTransformer.LOWERING_TRANSFORMER);
248 System.out.println("After Lowering: ");
249 System.out.println(codeModel.toText());
250
251 System.out.println("\nEvaluate with Interpreter.invoke");
252 // The Interpreter Invoke should launch new exceptions
253 var result = Interpreter.invoke(MethodHandles.lookup(), codeModel, 10);
254 System.out.println(result);
255 }
256
257 // Utility methods
258
259 // Analyze type methods: taken from example of String Concat Transformer to traverse the tree.
260 static boolean analyseType(JavaOp.ConvOp convOp, JavaType typeToMatch) {
261 return analyseType(convOp.operands().get(0), typeToMatch);
262 }
263
264 static boolean analyseType(Value v, JavaType typeToMatch) {
265 // Maybe there is a utility already to do tree traversal
266 if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
267 // Node of tree, recursively traverse the operands
268 return analyseType(convOp, typeToMatch);
269 } else {
270 // Leaf of tree: analyze type
271 TypeElement type = v.type();
272 return type.equals(typeToMatch);
273 }
274 }
275
276 // Inspect a value for a parameter
277 static boolean inspectParameterRecursive(JavaOp.ConvOp convOp, int valToMatch) {
278 return inspectParameterRecursive(convOp.operands().get(0), valToMatch);
279 }
280
281 static boolean inspectParameterRecursive(Value v, int valToMatch) {
282 if (v instanceof Op.Result r && r.op() instanceof JavaOp.ConvOp convOp) {
283 return inspectParameterRecursive(convOp, valToMatch);
284 } else {
285 // Leaf of tree - we want to obtain the actual value of the parameter and check
286 if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.ConstantOp constant) {
287 return constant.value().equals(valToMatch);
288 }
289 return false;
290 }
291 }
292
293 static final MethodRef JAVA_LANG_MATH_POW = MethodRef.method(Math.class, "pow", double.class, double.class, double.class);
294
295 private static boolean whenIsMathPowFunction(JavaOp.InvokeOp invokeOp) {
296 return invokeOp.invokeDescriptor().equals(JAVA_LANG_MATH_POW);
297 }
298
299 private static boolean isMethodWeWantToInline(JavaOp.InvokeOp invokeOp) {
300 return (invokeOp.invokeDescriptor().toString().startsWith("oracle.code.samples.MathOptimizerWithInlining::functionShift")
301 || invokeOp.invokeDescriptor().toString().startsWith("oracle.code.samples.MathOptimizerWithInlining::functionMult"));
302 }
303
304 private static boolean isShiftFunction(JavaOp.InvokeOp invokeOp) {
305 return invokeOp.invokeDescriptor().toString().contains("functionShift");
306 }
307 }