1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.dialect.core.CoreOp;
 26 import jdk.incubator.code.dialect.core.CoreType;
 27 
 28 import java.util.*;
 29 import java.util.function.Function;
 30 
 31 public class AnfTransformer {
 32     final CoreOp.FuncOp sourceOp;
 33     final Map<Block, Function<Body.Builder, AnfDialect.AnfFuncOp>> fBuilders = new HashMap<>();
 34     final Body.Builder outerBodyBuilder;
 35     final ImmediateDominatorMap idomMap;
 36     final Map<Block, Value> funMap = new HashMap<>();
 37     final Map<Block, Value> funMap2 = new HashMap<>();
 38 
 39     public AnfTransformer(CoreOp.FuncOp funcOp) {
 40         sourceOp = funcOp;
 41         outerBodyBuilder = Body.Builder.of(null, CoreType.functionType(funcOp.body().yieldType()));
 42         idomMap = new ImmediateDominatorMap(funcOp.body());
 43     }
 44 
 45     public AnfDialect.AnfFuncOp transform() {
 46         return transformOuterBody(sourceOp.body());
 47     }
 48 
 49     //Outer body corresponds to outermost letrec
 50     public AnfDialect.AnfFuncOp transformOuterBody(Body b) {
 51         var entry = b.entryBlock();
 52 
 53         var builderEntry = outerBodyBuilder.entryBlock();
 54 
 55         var selfRefP = builderEntry.parameter(((CoreOp.FuncOp) b.parentOp()).invokableType());
 56         funMap.put(entry, selfRefP);
 57 
 58         for (Block.Parameter p : entry.parameters()) {
 59             var newP = builderEntry.parameter(p.type());
 60             builderEntry.context().mapValue(p,newP);
 61         }
 62 
 63         var outerLetRecBody = Body.Builder.of(outerBodyBuilder, CoreType.functionType(b.yieldType(), List.of()), CopyContext.create(builderEntry.context()));
 64 
 65         List<Block> dominatedBlocks = idomMap.idominates(entry);
 66         List<AnfDialect.AnfFuncOp> funs = dominatedBlocks.stream().map(block -> transformBlock(block, outerLetRecBody)).toList();
 67 
 68         var res = transformBlock(entry, outerLetRecBody);
 69         return res;
 70 
 71     }
 72 
 73     public AnfDialect.AnfFuncOp transformBlock(Block b, Body.Builder bodyBuilder) {
 74         if (idomMap.idominates(b).isEmpty()) {
 75             return transformLeafBlock(b, bodyBuilder);
 76         }
 77         return transformDomBlock(b, bodyBuilder);
 78     }
 79 
 80     //"Leaf" in this case is a leaf of the dominator tree
 81     public AnfDialect.AnfFuncOp transformLeafBlock(Block b, Body.Builder ancestorBodyBuilder) {
 82         var blockReturnType = getBlockReturnType(b);
 83         var blockFType = CoreType.functionType(blockReturnType);
 84 
 85         List<TypeElement> synthParamTypes = new ArrayList<>();
 86         synthParamTypes.add(blockFType);
 87 
 88         var blockFTypeSynth = CoreType.functionType(blockReturnType, synthParamTypes);
 89 
 90         Body.Builder newBodyBuilder = Body.Builder.of(ancestorBodyBuilder, blockFTypeSynth, CopyContext.create(ancestorBodyBuilder.entryBlock().context()));
 91 
 92         var selfRefParam = newBodyBuilder.entryBlock().parameters().get(0);
 93         funMap.put(b, selfRefParam);
 94 
 95         for (Block.Parameter param : b.parameters()) {
 96             var p = newBodyBuilder.entryBlock().parameter(param.type());
 97             newBodyBuilder.entryBlock().context().mapValue(param, p);
 98         }
 99 
100         var letBody = Body.Builder.of(newBodyBuilder, CoreType.functionType(blockReturnType, List.of()), CopyContext.create(newBodyBuilder.entryBlock().context()));
101 
102         AnfDialect.AnfLetOp let = transformOps(b, letBody);
103         newBodyBuilder.entryBlock().op(let);
104         return AnfDialect.func(b.toString(), newBodyBuilder);
105     }
106 
107     //Non leaf nodes of the dominator tree
108     public AnfDialect.AnfFuncOp transformDomBlock(Block b, Body.Builder ancestorBodyBuilder) {
109         var blockReturnType = getBlockReturnType(b);
110         var blockFType = CoreType.functionType(blockReturnType);
111 
112         List<TypeElement> synthParamTypes = new ArrayList<>();
113         synthParamTypes.add(blockFType);
114 
115         var blockFTypeSynth = CoreType.functionType(blockReturnType, synthParamTypes);
116 
117         //Function body contains letrec and its bodies
118         Body.Builder funcBodyBuilder = Body.Builder.of(ancestorBodyBuilder, blockFTypeSynth, CopyContext.create(ancestorBodyBuilder.entryBlock().context()));
119 
120         //Self param
121         var selfRefParam = funcBodyBuilder.entryBlock().parameters().get(0);
122         funMap.put(b, selfRefParam);
123 
124         for (Block.Parameter param : b.parameters()) {
125             var p = funcBodyBuilder.entryBlock().parameter(param.type());
126             funcBodyBuilder.entryBlock().context().mapValue(param, p);
127         }
128 
129         //letrec inner body
130         Body.Builder letrecBody = Body.Builder.of(funcBodyBuilder, CoreType.functionType(blockReturnType, List.of()), CopyContext.create(funcBodyBuilder.entryBlock().context()));
131 
132         List<Block> dominates = idomMap.idominates(b);
133         for (Block dblock : dominates) {
134             var res = transformDomBlock(dblock, letrecBody);
135             var fval = letrecBody.entryBlock().op(res);
136             funMap2.put(dblock, fval);
137         }
138 
139         var letBody = Body.Builder.of(letrecBody, letrecBody.bodyType(), CopyContext.create(letrecBody.entryBlock().context()));
140         transformBlockOps(b, letBody.entryBlock());
141         var let = AnfDialect.let(letBody);
142 
143         letrecBody.entryBlock().op(let);
144 
145         var letrec = AnfDialect.letrec(letrecBody);
146         funcBodyBuilder.entryBlock().op(letrec);
147         return AnfDialect.func(b.toString(), funcBodyBuilder);
148 
149     }
150 
151     private TypeElement getBlockReturnType(Block b) {
152         var op = b.ops().getLast();
153         if (op instanceof Op.Terminating) {
154             List<Block.Reference> destBlocks = new ArrayList<>();
155             if (op instanceof CoreOp.ReturnOp ro) {
156                 return ro.returnValue().type();
157             } else if (op instanceof CoreOp.YieldOp yo) {
158                 return yo.yieldValue().type();
159             } else if (op instanceof CoreOp.BranchOp bop) {
160                 destBlocks.addAll(bop.successors());
161             } else if (op instanceof CoreOp.ConditionalBranchOp cbop) {
162                 destBlocks.addAll(cbop.successors());
163             }
164             //Traverse until we find a yield or return type, TODO: not going to try to unify types
165 
166             Set<Block> visitedBlocks = new HashSet<>();
167             visitedBlocks.add(b);
168 
169             while (!destBlocks.isEmpty()) {
170                 var block = destBlocks.removeFirst().targetBlock();
171                 if (visitedBlocks.contains(block)) {
172                     continue;
173                 }
174 
175                 //Discovered a terminator with a return value, use its type
176                 if (block.successors().isEmpty()) {
177                     var o = block.ops().getLast();
178                     if (o instanceof CoreOp.ReturnOp ro) {
179                         return ro.returnValue().type();
180                     } else if (o instanceof CoreOp.YieldOp yo) {
181                         return yo.yieldValue().type();
182                     } else {
183                         throw new UnsupportedOperationException("Unsupported terminator encountered: " + o.opName());
184                     }
185                 } else {
186                     visitedBlocks.add(block);
187                     var newDests = block.successors().stream().filter((s) -> !visitedBlocks.contains(s.targetBlock())).toList();
188                     destBlocks.addAll(newDests);
189                 }
190             }
191 
192         }
193 
194         throw new RuntimeException("Encountered Block with no return " + op.opName());
195     }
196 
197     private Block.Builder transformEndOp(Block.Builder b, Op op) {
198         if (op instanceof Op.Terminating t) {
199             switch (t) {
200                 case CoreOp.ConditionalBranchOp c -> {
201                     var tbranch_args = c.trueBranch().arguments();
202                     tbranch_args = tbranch_args.stream().map(b.context()::getValue).toList();
203                     var fbranch_args = c.falseBranch().arguments();
204                     fbranch_args = fbranch_args.stream().map(b.context()::getValue).toList();
205 
206                     List<Value> trueArgs = new ArrayList<>();
207                     trueArgs.addAll(tbranch_args);
208 
209                     List<Value> falseArgs = new ArrayList<>();
210                     falseArgs.addAll(fbranch_args);
211 
212 
213                     var ifExp = AnfDialect.if_(b.parentBody(),
214                                     getBlockReturnType(c.trueBranch().targetBlock()),
215                                     b.context().getValue(c.predicate()))
216                             .if_((bodyBuilder) -> bindFunApp(bodyBuilder, trueArgs, c.trueBranch().targetBlock()))
217                             .else_((bodyBuilder) -> bindFunApp(bodyBuilder, falseArgs, c.falseBranch().targetBlock()));
218 
219                     b.op(ifExp);
220 
221                     return b;
222                 }
223                 case CoreOp.BranchOp br -> {
224                     var args = br.branch().arguments();
225                     args = args.stream().map(b.context()::getValue).toList();
226 
227                     List<Value> funcArgs = new ArrayList<>();
228                     funcArgs.addAll(args);
229                     bindFunApp(b, funcArgs, br.branch().targetBlock());
230 
231                     return b;
232                 }
233                 case CoreOp.ReturnOp ro -> {
234                     var rval = b.context().getValue(ro.returnValue());
235                     b.op(CoreOp.core_yield(rval));
236                     return b;
237                 }
238                 case CoreOp.YieldOp y ->  {
239                     var rval = b.context().getValue(y.yieldValue());
240                     b.op(CoreOp.core_yield(rval));
241                     return b;
242                 }
243                 default -> {
244                     throw new UnsupportedOperationException("Unsupported terminating op encountered: " + op);
245                 }
246             }
247         } else {
248             b.op(op);
249             return b;
250         }
251     }
252 
253 
254     private void bindFunApp(Block.Builder b, List<Value> args, Block target) {
255 
256         List<Value> synthArgs = new ArrayList<>();
257         synthArgs.addAll(args);
258         synthArgs.addFirst(funMap.get(target));
259         try {
260             b.op(AnfDialect.apply(synthArgs));
261             return;
262         } catch (IllegalStateException e) {}
263 
264         synthArgs.removeFirst();
265         synthArgs.addFirst(funMap2.get(target));
266 
267         try {
268             b.op(AnfDialect.apply(synthArgs));
269         } catch (IllegalStateException e) {
270             throw new IllegalStateException("No valid mapping to FuncOp for apply");
271         }
272 
273     }
274 
275 
276     public AnfDialect.AnfLetOp transformOps(Block b, Body.Builder bodyBuilder) {
277         Block.Builder blockb = bodyBuilder.entryBlock();
278         return transformOps(b, blockb);
279     }
280 
281     public AnfDialect.AnfLetOp transformOps(Block b, Block.Builder blockBuilder) {
282         transformBlockOps(b, blockBuilder);
283         return AnfDialect.let(blockBuilder.parentBody());
284     }
285 
286     public void transformBlockOps(Block b, Block.Builder blockBuilder) {
287         for (var op : b.ops()) {
288             transformEndOp(blockBuilder, op);
289         }
290     }
291 
292     static class ImmediateDominatorMap {
293 
294         private final Map <Block, List<Block>> dominatesMap;
295         private final Map <Block, Block> dominatorsMap;
296 
297         public ImmediateDominatorMap(Body b) {
298             dominatorsMap = b.immediateDominators();
299             dominatesMap = new HashMap<>();
300 
301             //Reverse the idom relation
302             b.immediateDominators().forEach((dominated, dominator) -> {
303                 if (!dominated.equals(dominator)) {
304                     dominatesMap.compute(dominator, (k, v) -> {
305                         if (v == null) {
306                             var newList = new ArrayList<Block>();
307                             newList.add(dominated);
308                             return newList;
309                         } else {
310                             v.add(dominated);
311                             return v;
312                         }
313                     });
314                 }
315             });
316 
317         }
318 
319         //Looks "down" the dominator tree toward leaves
320         public List<Block> idominates(Block b) {
321             return dominatesMap.getOrDefault(b, List.of());
322         }
323 
324         //Looks "up" the dominator tree toward start node
325         public Block idominatedBy(Block b) {
326             return dominatorsMap.get(b);
327         }
328     }
329 }