1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package experiments;
26
import jdk.incubator.code.*;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.JavaOp;

import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.stream.Stream;
39
40 public class TransformState {
41
42 @Reflect
43 static int threeSum(int a, int b, int c) {
44 return a + b * c;
45 }
46
47 //@Test
48 public static void testOpToOp() {
49 CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
50 Map<Op, Op> oldOpToNewOpMap = new HashMap<>();
51 CodeTransformer opTracker = (block, op) -> {
52 if (op instanceof JavaOp.AddOp) {
53 CodeContext cc = block.context();
54 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
55 Op.Result result = block.op(newSubOp);
56 cc.mapValue(op.result(), result);
57 oldOpToNewOpMap.put(op,newSubOp);// <-- this maps op -> new subOp
58 } else {
59 var result = block.op(op);
60 oldOpToNewOpMap.put(op,result.op()); //<-- this maps op -> op'
61 }
62 return block;
63 };
64
65 System.out.println(threeSumFuncOp.toText());
66 CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(opTracker);
67 System.out.println(threeSumFuncOp1.toText());
68 }
69
70 //@Test
71 public static void testDelegate() {
72 CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
73 JavaOp.AddOp addOp = (JavaOp.AddOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.AddOp).findFirst().orElseThrow();
74 JavaOp.MulOp mulOp = (JavaOp.MulOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.MulOp).findFirst().orElseThrow();
75 Map<Value, String> mapState = new HashMap<>();
76 mapState.put(addOp.result(), "STATE 1");
77 mapState.put(mulOp.result(), "STATE 2");
78
79 Map<Value, String> transformedMapState = new HashMap<>();
80 Map<Op, Op> transformedOpMapState = new HashMap<>();
81
82 CodeTransformer opTracker = (block, op) -> {
83 if (op instanceof JavaOp.AddOp) {
84 CodeContext cc = block.context();
85 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
86 Op.Result result = block.op(newSubOp);
87 cc.mapValue(op.result(), result);
88 transformedOpMapState.put(op,newSubOp);// <-- this maps op -> new subOp
89 } else {
90 var result = block.op(op);
91 transformedOpMapState.put(op,result.op()); //<-- this maps op -> op'
92 }
93 return block;
94 };
95 CodeTransformer t = trackingValueDelegatingTransformer(
96 (block, op) -> {
97 if (op instanceof JavaOp.AddOp) {
98 CodeContext cc = block.context();
99 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
100 Op.Result r = block.op(newSubOp);
101 cc.mapValue(op.result(), r);
102 transformedOpMapState.put(op,newSubOp);
103 } else {
104 var r = block.op(op);
105 transformedOpMapState.put(op,r.op());
106 }
107 return block;
108 },
109 (vIn, vOut) -> {
110 if (mapState.containsKey(vIn) && vOut != null) {
111 transformedMapState.put(vOut, mapState.get(vIn));
112 }
113 });
114
115 System.out.println(threeSumFuncOp.toText());
116 print(mapState);
117 CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(opTracker);
118 System.out.println(threeSumFuncOp1.toText());
119 print(transformedMapState);
120 }
121
122 static void print(Map<Value, String> mappedStated) {
123 mappedStated.forEach((v, s) -> {
124 System.out.println(v + "[" + ((v instanceof Op.Result r ? r.op() : "") + "] -> " + s));
125 });
126 }
127
128 static CodeTransformer trackingValueDelegatingTransformer(
129 BiFunction<Block.Builder, Op, Block.Builder> t,
130 BiConsumer<Value, Value> mapAction) {
131 return (block, op) -> {
132 try {
133 return t.apply(block, op);
134 } finally {
135 Value in = op.result();
136 Value out = block.context().queryValue(in).orElse(null);
137 mapAction.accept(in, out);
138 }
139 };
140 }
141
142
143 //@Test
144 static public void testAndThen() {
145 CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
146 JavaOp.AddOp addOp = (JavaOp.AddOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.AddOp).findFirst().orElseThrow();
147 JavaOp.MulOp mulOp = (JavaOp.MulOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.MulOp).findFirst().orElseThrow();
148 Map<Value, String> mapState = new HashMap<>();
149 mapState.put(addOp.result(), "STATE 1");
150 mapState.put(mulOp.result(), "STATE 2");
151
152 Map<Value, String> transformedMapState = new HashMap<>();
153
154 CodeTransformer t = trackingValueAndThenTransformer(
155 (block, op) -> {
156 if (op instanceof JavaOp.AddOp) {
157 CodeContext cc = block.context();
158 Op.Result r = block.op(JavaOp.sub(
159 cc.getValue(op.operands().get(0)),
160 cc.getValue(op.operands().get(1))));
161 cc.mapValue(op.result(), r);
162 } else {
163 block.op(op);
164 }
165 return block;
166 },
167 (vIn, vOut) -> {
168 if (mapState.containsKey(vIn) && vOut != null) {
169 transformedMapState.put(vOut, mapState.get(vIn));
170 }
171 });
172
173 System.out.println(threeSumFuncOp.toText());
174 print(mapState);
175 CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(t);
176 System.out.println(threeSumFuncOp1.toText());
177 print(transformedMapState);
178 }
179
180 static CodeTransformer trackingValueAndThenTransformer(
181 CodeTransformer t,
182 BiConsumer<Value, Value> mapAction) {
183
184 // This only composes CodeTransformer.acceptOp.
185 // If the given code transformer overrides acceptBody or acceptBlock,
186 // that behavior is not preserved
187 return ((builder, op) -> {
188 builder = t.acceptOp(builder, op);
189
190 Value in = op.result();
191 Value out = builder.context().queryValue(in).orElse(null);
192 mapAction.accept(in, out);
193
194 return builder;
195 });
196 }
197
198 static CoreOp.FuncOp getFuncOp(String name) {
199 Optional<Method> om = Stream.of(TransformState.class.getDeclaredMethods())
200 .filter(m -> m.getName().equals(name))
201 .findFirst();
202
203 Method m = om.get();
204 return Op.ofMethod(m).get();
205 }
206
207 static public void main(String[] args) {
208 testOpToOp();
209 //testDelegate();
210 //testAndThen();
211 }
212 }