1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package experiments;
 26 
import jdk.incubator.code.*;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.JavaOp;

import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.stream.Stream;
 39 
 40 public class TransformState {
 41 
 42     @Reflect
 43     static int threeSum(int a, int b, int c) {
 44         return a + b * c;
 45     }
 46 
 47     //@Test
 48     public static void testOpToOp() {
 49         CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
 50         Map<Op, Op> oldOpToNewOpMap = new HashMap<>();
 51         CodeTransformer opTracker = (block, op) -> {
 52             if (op instanceof JavaOp.AddOp) {
 53                 CodeContext cc = block.context();
 54                 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
 55                 Op.Result result = block.op(newSubOp);
 56                 cc.mapValue(op.result(), result);
 57                 oldOpToNewOpMap.put(op,newSubOp);// <-- this maps op -> new subOp
 58             } else {
 59                 var result = block.op(op);
 60                 oldOpToNewOpMap.put(op,result.op()); //<-- this maps op ->  op'
 61             }
 62             return block;
 63         };
 64 
 65         System.out.println(threeSumFuncOp.toText());
 66         CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(opTracker);
 67         System.out.println(threeSumFuncOp1.toText());
 68     }
 69 
 70     //@Test
 71     public static void testDelegate() {
 72         CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
 73         JavaOp.AddOp addOp = (JavaOp.AddOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.AddOp).findFirst().orElseThrow();
 74         JavaOp.MulOp mulOp = (JavaOp.MulOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.MulOp).findFirst().orElseThrow();
 75         Map<Value, String> mapState = new HashMap<>();
 76         mapState.put(addOp.result(), "STATE 1");
 77         mapState.put(mulOp.result(), "STATE 2");
 78 
 79         Map<Value, String> transformedMapState = new HashMap<>();
 80         Map<Op, Op> transformedOpMapState = new HashMap<>();
 81 
 82         CodeTransformer opTracker = (block, op) -> {
 83             if (op instanceof JavaOp.AddOp) {
 84                 CodeContext cc = block.context();
 85                 var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
 86                 Op.Result result = block.op(newSubOp);
 87                 cc.mapValue(op.result(), result);
 88                 transformedOpMapState.put(op,newSubOp);// <-- this maps op -> new subOp
 89             } else {
 90                 var result = block.op(op);
 91                 transformedOpMapState.put(op,result.op()); //<-- this maps op ->  op'
 92             }
 93             return block;
 94         };
 95         CodeTransformer t = trackingValueDelegatingTransformer(
 96                 (block, op) -> {
 97                     if (op instanceof JavaOp.AddOp) {
 98                         CodeContext cc = block.context();
 99                         var newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
100                         Op.Result r = block.op(newSubOp);
101                         cc.mapValue(op.result(), r);
102                         transformedOpMapState.put(op,newSubOp);
103                     } else {
104                         var r = block.op(op);
105                         transformedOpMapState.put(op,r.op());
106                     }
107                     return block;
108                 },
109                 (vIn, vOut) -> {
110                     if (mapState.containsKey(vIn) && vOut != null) {
111                         transformedMapState.put(vOut, mapState.get(vIn));
112                     }
113                 });
114 
115         System.out.println(threeSumFuncOp.toText());
116         print(mapState);
117         CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(opTracker);
118         System.out.println(threeSumFuncOp1.toText());
119         print(transformedMapState);
120     }
121 
122     static void print(Map<Value, String> mappedStated) {
123         mappedStated.forEach((v, s) -> {
124             System.out.println(v + "[" + ((v instanceof Op.Result r ? r.op() : "") + "] -> " + s));
125         });
126     }
127 
128     static CodeTransformer trackingValueDelegatingTransformer(
129             BiFunction<Block.Builder, Op, Block.Builder> t,
130             BiConsumer<Value, Value> mapAction) {
131         return (block, op) -> {
132             try {
133                 return t.apply(block, op);
134             } finally {
135                 Value in = op.result();
136                 Value out = block.context().queryValue(in).orElse(null);
137                 mapAction.accept(in, out);
138             }
139         };
140     }
141 
142 
143     //@Test
144     static public void testAndThen() {
145         CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
146         JavaOp.AddOp addOp = (JavaOp.AddOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.AddOp).findFirst().orElseThrow();
147         JavaOp.MulOp mulOp = (JavaOp.MulOp) threeSumFuncOp.elements().filter(e -> e instanceof JavaOp.MulOp).findFirst().orElseThrow();
148         Map<Value, String> mapState = new HashMap<>();
149         mapState.put(addOp.result(), "STATE 1");
150         mapState.put(mulOp.result(), "STATE 2");
151 
152         Map<Value, String> transformedMapState = new HashMap<>();
153 
154         CodeTransformer t = trackingValueAndThenTransformer(
155                 (block, op) -> {
156                     if (op instanceof JavaOp.AddOp) {
157                         CodeContext cc = block.context();
158                         Op.Result r = block.op(JavaOp.sub(
159                                 cc.getValue(op.operands().get(0)),
160                                 cc.getValue(op.operands().get(1))));
161                         cc.mapValue(op.result(), r);
162                     } else {
163                         block.op(op);
164                     }
165                     return block;
166                 },
167                 (vIn, vOut) -> {
168                     if (mapState.containsKey(vIn) && vOut != null) {
169                         transformedMapState.put(vOut, mapState.get(vIn));
170                     }
171                 });
172 
173         System.out.println(threeSumFuncOp.toText());
174         print(mapState);
175         CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(t);
176         System.out.println(threeSumFuncOp1.toText());
177         print(transformedMapState);
178     }
179 
180     static CodeTransformer trackingValueAndThenTransformer(
181             CodeTransformer t,
182             BiConsumer<Value, Value> mapAction) {
183 
184         // This only composes CodeTransformer.acceptOp.
185         // If the given code transformer overrides acceptBody or acceptBlock,
186         // that behavior is not preserved
187         return ((builder, op) -> {
188             builder = t.acceptOp(builder, op);
189 
190             Value in = op.result();
191             Value out = builder.context().queryValue(in).orElse(null);
192             mapAction.accept(in, out);
193 
194             return builder;
195         });
196     }
197 
198     static CoreOp.FuncOp getFuncOp(String name) {
199         Optional<Method> om = Stream.of(TransformState.class.getDeclaredMethods())
200                 .filter(m -> m.getName().equals(name))
201                 .findFirst();
202 
203         Method m = om.get();
204         return Op.ofMethod(m).get();
205     }
206 
207     static public  void main(String[] args) {
208         testOpToOp();
209         //testDelegate();
210         //testAndThen();
211     }
212 }