1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package experiments;
26
27 import jdk.incubator.code.*;
28 import jdk.incubator.code.dialect.core.CoreOp;
29 import jdk.incubator.code.dialect.java.JavaOp;
30
31 import java.lang.reflect.Method;
32 import java.util.HashMap;
33 import java.util.Map;
34 import java.util.Optional;
35 import java.util.function.BiConsumer;
36 import java.util.function.BiFunction;
37 import java.util.stream.Stream;
38
39 public class TransformState {
40
41 @Reflect
42 static int threeSum(int a, int b, int c) {
43 return a + b * c;
44 }
45
46 //@Test
47 public static void testOpToOp() {
48 CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
49 Map<Op, Op> oldOpToNewOpMap = new HashMap<>();
50 CodeTransformer opTracker = (block, op) -> {
51 if (op instanceof JavaOp.AddOp) {
52 CodeContext cc = block.context();
53 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
54 Op.Result result = block.op(newSubOp);
55 cc.mapValue(op.result(), result);
56 oldOpToNewOpMap.put(op,newSubOp);// <-- this maps op -> new subOp
57 } else {
58 var result = block.op(op);
59 oldOpToNewOpMap.put(op,result.op()); //<-- this maps op -> op'
60 }
61 return block;
62 };
63
64 System.out.println(threeSumFuncOp.toText());
65 CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(opTracker);
66 System.out.println(threeSumFuncOp1.toText());
67 }
68
69 //@Test
70 public static void testDelegate() {
71 CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
72 JavaOp.AddOp addOp = (JavaOp.AddOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.AddOp).findFirst().orElseThrow();
73 JavaOp.MulOp mulOp = (JavaOp.MulOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.MulOp).findFirst().orElseThrow();
74 Map<Value, String> mapState = new HashMap<>();
75 mapState.put(addOp.result(), "STATE 1");
76 mapState.put(mulOp.result(), "STATE 2");
77
78 Map<Value, String> transformedMapState = new HashMap<>();
79 Map<Op, Op> transformedOpMapState = new HashMap<>();
80
81 CodeTransformer opTracker = (block, op) -> {
82 if (op instanceof JavaOp.AddOp) {
83 CodeContext cc = block.context();
84 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
85 Op.Result result = block.op(newSubOp);
86 cc.mapValue(op.result(), result);
87 transformedOpMapState.put(op,newSubOp);// <-- this maps op -> new subOp
88 } else {
89 var result = block.op(op);
90 transformedOpMapState.put(op,result.op()); //<-- this maps op -> op'
91 }
92 return block;
93 };
94 CodeTransformer t = trackingValueDelegatingTransformer(
95 (block, op) -> {
96 if (op instanceof JavaOp.AddOp) {
97 CodeContext cc = block.context();
98 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
99 Op.Result r = block.op(newSubOp);
100 cc.mapValue(op.result(), r);
101 transformedOpMapState.put(op,newSubOp);
102 } else {
103 var r = block.op(op);
104 transformedOpMapState.put(op,r.op());
105 }
106 return block;
107 },
108 (vIn, vOut) -> {
109 if (mapState.containsKey(vIn) && vOut != null) {
110 transformedMapState.put(vOut, mapState.get(vIn));
111 }
112 });
113
114 System.out.println(threeSumFuncOp.toText());
115 print(mapState);
116 CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(opTracker);
117 System.out.println(threeSumFuncOp1.toText());
118 print(transformedMapState);
119 }
120
121 static void print(Map<Value, String> mappedStated) {
122 mappedStated.forEach((v, s) -> {
123 System.out.println(v + "[" + ((v instanceof Op.Result r ? r.op() : "") + "] -> " + s));
124 });
125 }
126
127 static CodeTransformer trackingValueDelegatingTransformer(
128 BiFunction<Block.Builder, Op, Block.Builder> t,
129 BiConsumer<Value, Value> mapAction) {
130 return (block, op) -> {
131 try {
132 return t.apply(block, op);
133 } finally {
134 Value in = op.result();
135 Value out = block.context().getValueOrDefault(in, null);
136 mapAction.accept(in, out);
137 }
138 };
139 }
140
141
142 //@Test
143 static public void testAndThen() {
144 CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
145 JavaOp.AddOp addOp = (JavaOp.AddOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.AddOp).findFirst().orElseThrow();
146 JavaOp.MulOp mulOp = (JavaOp.MulOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.MulOp).findFirst().orElseThrow();
147 Map<Value, String> mapState = new HashMap<>();
148 mapState.put(addOp.result(), "STATE 1");
149 mapState.put(mulOp.result(), "STATE 2");
150
151 Map<Value, String> transformedMapState = new HashMap<>();
152
153 CodeTransformer t = trackingValueAndThenTransformer(
154 (block, op) -> {
155 if (op instanceof JavaOp.AddOp) {
156 CodeContext cc = block.context();
157 Op.Result r = block.op(JavaOp.sub(
158 cc.getValue(op.operands().get(0)),
159 cc.getValue(op.operands().get(1))));
160 cc.mapValue(op.result(), r);
161 } else {
162 block.op(op);
163 }
164 return block;
165 },
166 (vIn, vOut) -> {
167 if (mapState.containsKey(vIn) && vOut != null) {
168 transformedMapState.put(vOut, mapState.get(vIn));
169 }
170 });
171
172 System.out.println(threeSumFuncOp.toText());
173 print(mapState);
174 CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(t);
175 System.out.println(threeSumFuncOp1.toText());
176 print(transformedMapState);
177 }
178
179 static CodeTransformer trackingValueAndThenTransformer(
180 CodeTransformer t,
181 BiConsumer<Value, Value> mapAction) {
182 return CodeTransformer.andThen(t, (block, op) -> {
183 Value in = op.result();
184 Value out = block.context().getValueOrDefault(in, null);
185 mapAction.accept(in, out);
186 return block;
187 });
188 }
189
190
191 static CoreOp.FuncOp getFuncOp(String name) {
192 Optional<Method> om = Stream.of(TransformState.class.getDeclaredMethods())
193 .filter(m -> m.getName().equals(name))
194 .findFirst();
195
196 Method m = om.get();
197 return Op.ofMethod(m).get();
198 }
199
200 static public void main(String[] args) {
201 testOpToOp();
202 //testDelegate();
203 //testAndThen();
204 }
205 }