1 import jdk.incubator.code.*;
2 import jdk.incubator.code.dialect.core.CoreOp;
3 import jdk.incubator.code.dialect.java.JavaOp;
4 import jdk.incubator.code.dialect.java.JavaType;
5 import jdk.incubator.code.dialect.java.MethodRef;
6 import jdk.incubator.code.extern.OpWriter;
7 import org.junit.jupiter.api.Assertions;
8 import org.junit.jupiter.api.Test;
9
10 import java.io.StringWriter;
11 import java.lang.reflect.Method;
12 import java.util.List;
13 import java.util.Set;
14 import java.util.function.IntBinaryOperator;
15
16 /*
17 * @test
18 * @modules jdk.incubator.code
19 * @modules java.base/java.lang.invoke:open
20 * @run junit TestTransform
21 */
22 public class TestTransform {
23
24 @Reflect
25 static int f() {
26 IntBinaryOperator o = (a, b) -> a + b;
27 int sum = 0;
28 for (int i = 0; i < 10; i++) {
29 sum += o.applyAsInt(i, i);
30 }
31 return sum + 42;
32 }
33
34 static int add(int a, int b) {
35 return a + b;
36 }
37
38 @Reflect
39 static int fWithAddMethod() {
40 IntBinaryOperator o = (a, b) -> add(a, b);
41 int sum = 0;
42 for (int i = 0; i < 10; i = add(i, 1)) {
43 sum = add(sum, o.applyAsInt(i, i));
44 }
45 return add(sum, 42);
46 }
47
48 @Test
49 public void testOpTransformer_fWithAddMethod() throws Exception {
50 Method fAdd = this.getClass().getDeclaredMethod("add", int.class, int.class);
51 var codeTransformer = CodeTransformer.opTransformer((builder, op, operands) -> {
52 switch (op) {
53 case JavaOp.AddOp _ ->
54 builder.apply(JavaOp.invoke(MethodRef.method(fAdd), operands));
55 default ->
56 builder.apply(op);
57 }
58 });
59
60 testTransformer("f", "fWithAddMethod", codeTransformer);
61 }
62
63 @Test
64 public void testCodeTransformer_fWithAddMethod() throws Exception {
65 Method fAdd = this.getClass().getDeclaredMethod("add", int.class, int.class);
66 CodeTransformer codeTransformer = (builder, op) -> {
67 switch (op) {
68 case JavaOp.AddOp _ -> {
69 List<Value> values = builder.context().getValues(op.operands());
70 Op.Result r = builder.op(JavaOp.invoke(MethodRef.method(fAdd), values));
71 builder.context().mapValue(op.result(), r);
72 }
73 default ->
74 builder.op(op);
75 }
76 return builder;
77 };
78
79 testTransformer("f", "fWithAddMethod", codeTransformer);
80 }
81
82
83 @Reflect
84 static int fAddToSubNeg() {
85 IntBinaryOperator o = (a, b) -> a - -b;
86 int sum = 0;
87 for (int i = 0; i < 10; i = i - -1) {
88 sum = sum - -o.applyAsInt(i, i);
89 }
90 return sum - -42;
91 }
92
93 @Test
94 public void testOpTransformer_fAddToSubNeg() throws Exception {
95 var codeTransformer = CodeTransformer.opTransformer((builder, op, operands) -> {
96 switch (op) {
97 case CoreOp.ConstantOp _ -> {
98 Set<Op.Result> uses = op.result().uses();
99 // add(x, constant(C))
100 if (uses.size() == 1 && uses.iterator().next().op() instanceof JavaOp.AddOp) {
101 // Drop, this will be replaced later
102 } else {
103 builder.apply(op);
104 }
105 }
106 case JavaOp.AddOp _ -> {
107 // add(x, constant(C)) -> sub(x, constant(-C))
108 // add(x, y) -> sub(x, neg(y))
109 Op.Result rhs;
110 if (op.operands().get(1) instanceof Op.Result r && r.op() instanceof CoreOp.ConstantOp cop) {
111 // There is no mapping to the second operand, since it was associated
112 // with the constant op which was dropped
113 Assertions.assertNull(operands.get(1));
114
115 rhs = builder.apply(CoreOp.constant(JavaType.INT, -(int) cop.value()));
116 } else {
117 Assertions.assertNotNull(operands.get(1));
118
119 rhs = builder.apply(JavaOp.neg(operands.get(1)));
120 }
121 builder.apply(JavaOp.sub(operands.get(0), rhs));
122 }
123 default ->
124 builder.apply(op);
125 }
126 });
127
128 testTransformer("f", "fAddToSubNeg", codeTransformer);
129 }
130
131 @Test
132 public void testCodeTransformer_fAddToSubNeg() throws Exception {
133 CodeTransformer codeTransformer = (builder, op) -> {
134 switch (op) {
135 case CoreOp.ConstantOp _ -> {
136 Set<Op.Result> uses = op.result().uses();
137 // add(x, constant(C))
138 if (uses.size() == 1 && uses.iterator().next().op() instanceof JavaOp.AddOp) {
139 // Drop, this will be replaced later
140 } else {
141 builder.op(op);
142 }
143 }
144 case JavaOp.AddOp _ -> {
145 // add(x, constant(C)) -> sub(x, constant(-C))
146 // add(x, y) -> sub(x, neg(y))
147 List<Value> operands = op.operands().stream()
148 .map(v -> builder.context().getValueOrDefault(v, null)).toList();
149 Op.Result rhs;
150 if (op.operands().get(1) instanceof Op.Result r && r.op() instanceof CoreOp.ConstantOp cop) {
151 // There is no mapping to the second operand, since it was associated
152 // with the constant op which was dropped
153 Assertions.assertNull(operands.get(1));
154
155 rhs = builder.op(CoreOp.constant(JavaType.INT, - (int) cop.value()));
156 } else {
157 Assertions.assertNotNull(operands.get(1));
158
159 rhs = builder.op(JavaOp.neg(operands.get(1)));
160 }
161 Op.Result result = builder.op(JavaOp.sub(operands.get(0), rhs));
162 builder.context().mapValue(op.result(), result);
163 }
164 default ->
165 builder.op(op);
166 }
167 return builder;
168 };
169
170 testTransformer("f", "fAddToSubNeg", codeTransformer);
171 }
172
173
174 void testTransformer(String methodName, String transformedMethodName, CodeTransformer codeTransformer) throws Exception {
175 Method fMethod = this.getClass().getDeclaredMethod(methodName);
176 var fModel = Op.ofMethod(fMethod).orElseThrow();
177
178 var fTransformed = fModel.transform(codeTransformer);
179
180 Method fTransformedMethod = this.getClass().getDeclaredMethod(transformedMethodName);
181 var fTransformedModel = Op.ofMethod(fTransformedMethod).orElseThrow();
182
183 assertEqual(fTransformedModel, fTransformed, methodName, transformedMethodName);
184 }
185
186 static void assertEqual(Op expected, Op actual,
187 String methodName, String transformedMethodName) {
188 Assertions.assertEquals(serialize(expected).replace(transformedMethodName, methodName), serialize(actual));
189 }
190
191 static String serialize(Op o) {
192 StringWriter w = new StringWriter();
193 OpWriter.writeTo(w, o, OpWriter.LocationOption.DROP_LOCATION);
194 return w.toString();
195 }
196
197 }