/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
*/
package experiments;

import jdk.incubator.code.*;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.JavaOp;

import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.stream.Stream;

/**
 * Experiments in carrying side-state across a code-model transformation.
 * <p>
 * Each experiment transforms the code model of {@link #threeSum}. The add operation is
 * rewritten into a subtraction and every other operation is copied. The models are printed
 * before and after. The experiments differ in how they observe the old-to-new mapping:
 * <ul>
 *   <li>{@link #testOpToOp} records an op-to-op map inside the transformer;</li>
 *   <li>{@link #testDelegate} wraps the transformer with
 *       {@link #trackingValueDelegatingTransformer};</li>
 *   <li>{@link #testAndThen} chains a tracking step after it with
 *       {@link #trackingValueAndThenTransformer}.</li>
 * </ul>
 */
public class TransformState {

    /** Source of the code model transformed by the experiments: {@code a + b * c}. */
    @CodeReflection
    static int threeSum(int a, int b, int c) {
        return a + b * c;
    }

    //@Test
    /**
     * Transforms {@code threeSum} and records, for every visited op, the op that stands for
     * it in the output model. It then prints the models before and after.
     */
    public static void testOpToOp() {
        CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
        Map<Op, Op> oldOpToNewOpMap = new HashMap<>();
        OpTransformer opTracker = (block, op) -> {
            // add -> the freshly built sub op; any other op -> its copy (op -> op')
            oldOpToNewOpMap.put(op, replaceAddWithSub(block, op));
            return block;
        };

        System.out.println(threeSumFuncOp.toText());
        CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(opTracker);
        System.out.println(threeSumFuncOp1.toText());
    }

    //@Test
    /**
     * Applies the add-to-sub rewrite through {@link #trackingValueDelegatingTransformer}.
     * State seeded on the add and mul results is carried over to the values that replace
     * them. The models and the state maps are printed before and after.
     */
    public static void testDelegate() {
        CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
        Map<Value, String> mapState = seedState(threeSumFuncOp);
        Map<Value, String> transformedMapState = new HashMap<>();

        OpTransformer t = trackingValueDelegatingTransformer(
                (block, op) -> {
                    replaceAddWithSub(block, op);
                    return block;
                },
                copyState(mapState, transformedMapState));

        System.out.println(threeSumFuncOp.toText());
        print(mapState);
        // Transform with the tracking transformer. Without it, transformedMapState stays empty.
        CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(t);
        System.out.println(threeSumFuncOp1.toText());
        print(transformedMapState);
    }

    /**
     * Prints one line per entry.
     * <p>
     * Each line shows the value, the op that produced it when the value is an op result, and
     * its state.
     */
    static void print(Map<Value, String> mappedState) {
        mappedState.forEach((v, s) ->
                System.out.println(v + "[" + (v instanceof Op.Result r ? r.op() : "") + "] -> " + s));
    }

    /**
     * Wraps {@code t} so that {@code mapAction} is told what now stands for each transformed
     * op's result.
     * <p>
     * {@code mapAction} receives the op's result as {@code in}. It receives the value mapped
     * to it in the builder's copy context as {@code out}, or {@code null} if none is mapped.
     * <p>
     * The mapping is reported only after {@code t} completes normally. If {@code t} throws,
     * the exception propagates unchanged and {@code mapAction} is not invoked. This means an
     * aborted transformation is never reported, and an exception from {@code mapAction}
     * cannot mask the original failure.
     */
    static OpTransformer trackingValueDelegatingTransformer(
            BiFunction<Block.Builder, Op, Block.Builder> t,
            BiConsumer<Value, Value> mapAction) {
        return (block, op) -> {
            Block.Builder next = t.apply(block, op);
            Value in = op.result();
            Value out = block.context().getValueOrDefault(in, null);
            mapAction.accept(in, out);
            return next;
        };
    }

    //@Test
    /**
     * Applies the add-to-sub rewrite through {@link #trackingValueAndThenTransformer}.
     * State seeded on the add and mul results is carried over to the values that replace
     * them. The models and the state maps are printed before and after.
     */
    public static void testAndThen() {
        CoreOp.FuncOp threeSumFuncOp = getFuncOp("threeSum");
        Map<Value, String> mapState = seedState(threeSumFuncOp);
        Map<Value, String> transformedMapState = new HashMap<>();

        OpTransformer t = trackingValueAndThenTransformer(
                (block, op) -> {
                    replaceAddWithSub(block, op);
                    return block;
                },
                copyState(mapState, transformedMapState));

        System.out.println(threeSumFuncOp.toText());
        print(mapState);
        CoreOp.FuncOp threeSumFuncOp1 = threeSumFuncOp.transform(t);
        System.out.println(threeSumFuncOp1.toText());
        print(transformedMapState);
    }

    /**
     * Chains a tracking step after {@code t} using {@code OpTransformer.andThen}.
     * <p>
     * The tracking step passes {@code mapAction} the op's result and the value mapped to it
     * in its builder's copy context, or {@code null} if none is mapped.
     */
    static OpTransformer trackingValueAndThenTransformer(
            OpTransformer t,
            BiConsumer<Value, Value> mapAction) {
        return OpTransformer.andThen(t, (block, op) -> {
            Value in = op.result();
            Value out = block.context().getValueOrDefault(in, null);
            mapAction.accept(in, out);
            return block;
        });
    }

    /**
     * Returns the code model of the method of this class named {@code name}.
     *
     * @throws IllegalArgumentException if this class declares no method with that name
     * @throws IllegalStateException if the method has no code model, for example because it
     *         is not annotated with {@code @CodeReflection}
     */
    static CoreOp.FuncOp getFuncOp(String name) {
        Method m = Stream.of(TransformState.class.getDeclaredMethods())
                .filter(method -> method.getName().equals(name))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException(
                        "No method named " + name + " in TransformState"));
        return Op.ofMethod(m)
                .orElseThrow(() -> new IllegalStateException(
                        "Method " + name + " has no code model; is it annotated with @CodeReflection?"));
    }

    /**
     * Emits the rewritten form of {@code op} into {@code block}.
     * <p>
     * An add is replaced by a subtraction of the same (mapped) operands, and the add's result
     * is mapped to the new result. Any other op is copied.
     *
     * @return the op now standing for {@code op} in the output: the new sub op, or the copy
     */
    private static Op replaceAddWithSub(Block.Builder block, Op op) {
        if (op instanceof JavaOp.AddOp) {
            CopyContext cc = block.context();
            Op newSubOp = JavaOp.sub(cc.getValue(op.operands().get(0)), cc.getValue(op.operands().get(1)));
            Op.Result result = block.op(newSubOp);
            cc.mapValue(op.result(), result);
            return newSubOp;
        }
        return block.op(op).op();
    }

    /** Seeds demo state on the results of the first add ("STATE 1") and first mul ("STATE 2"). */
    private static Map<Value, String> seedState(CoreOp.FuncOp funcOp) {
        JavaOp.AddOp addOp = firstOp(funcOp, JavaOp.AddOp.class);
        JavaOp.MulOp mulOp = firstOp(funcOp, JavaOp.MulOp.class);
        Map<Value, String> state = new HashMap<>();
        state.put(addOp.result(), "STATE 1");
        state.put(mulOp.result(), "STATE 2");
        return state;
    }

    /** Returns the first element of {@code funcOp} of the given op type; throws if there is none. */
    private static <T extends Op> T firstOp(CoreOp.FuncOp funcOp, Class<T> type) {
        return funcOp.elements()
                .filter(type::isInstance)
                .map(type::cast)
                .findFirst()
                .orElseThrow();
    }

    /**
     * Returns a mapping callback for the tracking transformers.
     * <p>
     * The callback copies the state of an input value in {@code from} onto its replacement in
     * {@code to}. It skips inputs with no state and inputs with no replacement.
     */
    private static BiConsumer<Value, Value> copyState(Map<Value, String> from, Map<Value, String> to) {
        return (vIn, vOut) -> {
            if (from.containsKey(vIn) && vOut != null) {
                to.put(vOut, from.get(vIn));
            }
        };
    }

    /** Runs the op-to-op experiment; the other experiments are currently disabled. */
    public static void main(String[] args) {
        testOpToOp();
        //testDelegate();
        //testAndThen();
    }
}