1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package experiments;
 26 
import jdk.incubator.code.*;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.JavaOp;

import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.stream.Stream;
 38 
 39 public class TransformState {
 40 
 41     @CodeReflection
 42     static int threeSum(int a, int b, int c) {
 43         return a + b * c;
 44     }
 45 
 46     //@Test
 47     public static void testOpToOp() {
 48         CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
 49         Map<Op, Op> oldOpToNewOpMap = new HashMap<>();
 50         OpTransformer opTracker = (block, op) -> {
 51             if (op instanceof JavaOp.AddOp) {
 52                 CopyContext cc = block.context();
 53                 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
 54                 Op.Result result = block.op(newSubOp);
 55                 cc.mapValue(op.result(), result);
 56                 oldOpToNewOpMap.put(op,newSubOp);// <-- this maps op -> new subOp
 57             } else {
 58                 var result = block.op(op);
 59                 oldOpToNewOpMap.put(op,result.op()); //<-- this maps op ->  op'
 60             }
 61             return block;
 62         };
 63 
 64         System.out.println(threeSumFuncOp.toText());
 65         CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(opTracker);
 66         System.out.println(threeSumFuncOp1.toText());
 67     }
 68 
 69     //@Test
 70     public static void testDelegate() {
 71         CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
 72         JavaOp.AddOp addOp = (JavaOp.AddOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.AddOp).findFirst().orElseThrow();
 73         JavaOp.MulOp mulOp = (JavaOp.MulOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.MulOp).findFirst().orElseThrow();
 74         Map<Value, String> mapState = new HashMap<>();
 75         mapState.put(addOp.result(), "STATE 1");
 76         mapState.put(mulOp.result(), "STATE 2");
 77 
 78         Map<Value, String> transformedMapState = new HashMap<>();
 79         Map<Op, Op> transformedOpMapState = new HashMap<>();
 80 
 81         OpTransformer opTracker = (block, op) -> {
 82             if (op instanceof JavaOp.AddOp) {
 83                 CopyContext cc = block.context();
 84                 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
 85                 Op.Result result = block.op(newSubOp);
 86                 cc.mapValue(op.result(), result);
 87                 transformedOpMapState.put(op,newSubOp);// <-- this maps op -> new subOp
 88             } else {
 89                 var result = block.op(op);
 90                 transformedOpMapState.put(op,result.op()); //<-- this maps op ->  op'
 91             }
 92             return block;
 93         };
 94         OpTransformer t = trackingValueDelegatingTransformer(
 95                 (block, op) -> {
 96                     if (op instanceof JavaOp.AddOp) {
 97                         CopyContext cc = block.context();
 98                         var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
 99                         Op.Result r = block.op(newSubOp);
100                         cc.mapValue(op.result(), r);
101                         transformedOpMapState.put(op,newSubOp);
102                     } else {
103                         var r = block.op(op);
104                         transformedOpMapState.put(op,r.op());
105                     }
106                     return block;
107                 },
108                 (vIn, vOut) -> {
109                     if (mapState.containsKey(vIn) && vOut != null) {
110                         transformedMapState.put(vOut, mapState.get(vIn));
111                     }
112                 });
113 
114         System.out.println(threeSumFuncOp.toText());
115         print(mapState);
116         CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(opTracker);
117         System.out.println(threeSumFuncOp1.toText());
118         print(transformedMapState);
119     }
120 
121     static void print(Map<Value, String> mappedStated) {
122         mappedStated.forEach((v, s) -> {
123             System.out.println(v + "[" + ((v instanceof Op.Result r ? r.op() : "") + "] -> " + s));
124         });
125     }
126 
127     static OpTransformer trackingValueDelegatingTransformer(
128             BiFunction<Block.Builder, Op, Block.Builder> t,
129             BiConsumer<Value, Value> mapAction) {
130         return (block, op) -> {
131             try {
132                 return t.apply(block, op);
133             } finally {
134                 Value in = op.result();
135                 Value out = block.context().getValueOrDefault(in, null);
136                 mapAction.accept(in, out);
137             }
138         };
139     }
140 
141 
142     //@Test
143     static public void testAndThen() {
144         CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
145         JavaOp.AddOp addOp = (JavaOp.AddOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.AddOp).findFirst().orElseThrow();
146         JavaOp.MulOp mulOp = (JavaOp.MulOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.MulOp).findFirst().orElseThrow();
147         Map<Value, String> mapState = new HashMap<>();
148         mapState.put(addOp.result(), "STATE 1");
149         mapState.put(mulOp.result(), "STATE 2");
150 
151         Map<Value, String> transformedMapState = new HashMap<>();
152 
153         OpTransformer t = trackingValueAndThenTransformer(
154                 (block, op) -> {
155                     if (op instanceof JavaOp.AddOp) {
156                         CopyContext cc = block.context();
157                         Op.Result r = block.op(JavaOp.sub(
158                                 cc.getValue(op.operands().get(0)),
159                                 cc.getValue(op.operands().get(1))));
160                         cc.mapValue(op.result(), r);
161                     } else {
162                         block.op(op);
163                     }
164                     return block;
165                 },
166                 (vIn, vOut) -> {
167                     if (mapState.containsKey(vIn) && vOut != null) {
168                         transformedMapState.put(vOut, mapState.get(vIn));
169                     }
170                 });
171 
172         System.out.println(threeSumFuncOp.toText());
173         print(mapState);
174         CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(t);
175         System.out.println(threeSumFuncOp1.toText());
176         print(transformedMapState);
177     }
178 
179     static OpTransformer trackingValueAndThenTransformer(
180             OpTransformer t,
181             BiConsumer<Value, Value> mapAction) {
182         return OpTransformer.andThen(t, (block, op) -> {
183             Value in = op.result();
184             Value out = block.context().getValueOrDefault(in, null);
185             mapAction.accept(in, out);
186             return block;
187         });
188     }
189 
190 
191     static CoreOp.FuncOp getFuncOp(String name) {
192         Optional<Method> om = Stream.of(TransformState.class.getDeclaredMethods())
193                 .filter(m -> m.getName().equals(name))
194                 .findFirst();
195 
196         Method m = om.get();
197         return Op.ofMethod(m).get();
198     }
199 
200     static public  void main(String[] args) {
201         testOpToOp();
202         //testDelegate();
203         //testAndThen();
204     }
205 }