1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package optkl.codebuilders;
 26 
 27 
 28 import optkl.util.carriers.LookupCarrier;
 29 import optkl.util.ops.VarLikeOp;
 30 
 31 import optkl.FuncOpParams;
 32 import jdk.incubator.code.Block;
 33 import jdk.incubator.code.Op;
 34 import jdk.incubator.code.Value;
 35 import jdk.incubator.code.dialect.core.CoreOp;
 36 import jdk.incubator.code.dialect.java.JavaOp;
 37 
 38 import java.lang.invoke.MethodHandles;
 39 import java.util.HashMap;
 40 import java.util.Map;
 41 
 42 public class ScopedCodeBuilderContext implements LookupCarrier {
 43     final public FuncOpParams paramTable;
 44 
 45     public boolean isInFor() {
 46         return (scope != null && scope.parent instanceof ForScope);
 47     }
 48 
 49     public static sealed abstract class Scope<O extends Op> permits ForScope, FuncScope, IfScope, LambdaScope, RootScope {
 50         public final Scope<?> parent;
 51         final O op;
 52 
 53         public Scope(Scope<?> parent, O op) {
 54             this.parent = parent;
 55             this.op = op;
 56         }
 57 
 58         public Op resolve(Value value) {
 59             if (value instanceof Op.Result result && result.op() instanceof CoreOp.VarOp varOp) {
 60                 return varOp;
 61             }
 62             if (value instanceof Op.Result result && result.op() instanceof VarLikeOp varOp) {
 63                 return (Op) varOp;
 64             }
 65             if (parent != null) {
 66                 return parent.resolve(value);
 67             }
 68             if (value instanceof Block.Parameter parameter){
 69                 return parameter.uses().iterator().next().op();
 70             }
 71             throw new IllegalStateException("failed to resolve VarOp for value " + value);
 72         }
 73     }
 74 
 75     public static final class FuncScope extends Scope<CoreOp.FuncOp> {
 76         final FuncOpParams paramTable;
 77         FuncScope(Scope<?> parent, CoreOp.FuncOp funcOp) {
 78             super(parent, funcOp);
 79             paramTable = new FuncOpParams(funcOp);
 80         }
 81 
 82         @Override
 83         public Op resolve(Value value) {
 84             if (value instanceof Block.Parameter blockParameter) {
 85                 if (paramTable.parameterVarOpMap.containsFrom(blockParameter)) {
 86                     return paramTable.parameterVarOpMap.getTo(blockParameter);
 87                 } else {
 88                     return super.resolve(value);
 89                 }
 90             } else {
 91                 return super.resolve(value);
 92             }
 93         }
 94     }
 95 
 96     public static final class ForScope extends Scope<JavaOp.ForOp> {
 97         Map<Block.Parameter, CoreOp.VarOp> blockParamToVarOpMap = new HashMap<>();
 98         ForScope(Scope<?> parent, JavaOp.ForOp forOp) {
 99             super(parent, forOp);
100             var loopParams = forOp.loopBody().entryBlock().parameters().toArray(new Block.Parameter[0]);
101             var updateParams = forOp.update().entryBlock().parameters().toArray(new Block.Parameter[0]);
102             var condParams = forOp.cond().entryBlock().parameters().toArray(new Block.Parameter[0]);
103             var lastInitOp = forOp.init().entryBlock().ops().getLast();
104             var lastInitOpOperand0Result = (Op.Result) lastInitOp.operands().getFirst();
105             var lastInitOpOperand0ResultOp = lastInitOpOperand0Result.op();
106             CoreOp.VarOp[] varOps;
107             if (lastInitOpOperand0ResultOp instanceof CoreOp.TupleOp tupleOp) {
108                  /*
109                  for (int j = 1, i=2, k=3; j < size; k+=1,i+=2,j+=3) {
110                     float sum = k+i+j;
111                  }
112                  java.for
113                  ()Tuple<Var<int>, Var<int>, Var<int>> -> {
114                      %0 : int = constant @"1";
115                      %1 : Var<int> = var %0 @"j";
116                      %2 : int = constant @"2";
117                      %3 : Var<int> = var %2 @"i";
118                      %4 : int = constant @"3";
119                      %5 : Var<int> = var %4 @"k";
120                      %6 : Tuple<Var<int>, Var<int>, Var<int>> = tuple %1 %3 %5;
121                      yield %6;
122                  }
123                  (%7 : Var<int>, %8 : Var<int>, %9 : Var<int>)boolean -> {
124                      %10 : int = var.load %7;
125                      %11 : int = var.load %12;
126                      %13 : boolean = lt %10 %11;
127                      yield %13;
128                  }
129                  (%14 : Var<int>, %15 : Var<int>, %16 : Var<int>)void -> {
130                      %17 : int = var.load %16;
131                      %18 : int = constant @"1";
132                      %19 : int = add %17 %18;
133                      var.store %16 %19;
134                      %20 : int = var.load %15;
135                      %21 : int = constant @"2";
136                      %22 : int = add %20 %21;
137                      var.store %15 %22;
138                      %23 : int = var.load %14;
139                      %24 : int = constant @"3";
140                      %25 : int = add %23 %24;
141                      var.store %14 %25;
142                      yield;
143                  }
144                  (%26 : Var<int>, %27 : Var<int>, %28 : Var<int>)void -> {
145                      %29 : int = var.load %28;
146                      %30 : int = var.load %27;
147                      %31 : int = add %29 %30;
148                      %32 : int = var.load %26;
149                      %33 : int = add %31 %32;
150                      %34 : float = conv %33;
151                      %35 : Var<float> = var %34 @"sum";
152                      java.continue;
153                  };
154                  */
155                 varOps = tupleOp.operands().stream().map(operand -> (CoreOp.VarOp) (((Op.Result) operand).op())).toList().toArray(new CoreOp.VarOp[0]);
156             } else {
157                  /*
158                  for (int j = 0; j < size; j+=1) {
159                     float sum = j;
160                  }
161                  java.for
162                     ()Var<int> -> {
163                         %0 : int = constant @"0";
164                         %1 : Var<int> = var %0 @"j";
165                         yield %1;
166                     }
167                     (%2 : Var<int>)boolean -> {
168                         %3 : int = var.load %2;
169                         %4 : int = var.load %5;
170                         %6 : boolean = lt %3 %4;
171                         yield %6;
172                     }
173                     (%7 : Var<int>)void -> {
174                         %8 : int = var.load %7;
175                         %9 : int = constant @"1";
176                         %10 : int = add %8 %9;
177                         var.store %7 %10;
178                         yield;
179                     }
180                     (%11 : Var<int>)void -> {
181                         %12 : int = var.load %11;
182                         %13 : float = conv %12;
183                         %14 : Var<float> = var %13 @"sum";
184                         java.continue;
185                     };
186 
187                  */
188                 varOps = new CoreOp.VarOp[]{(CoreOp.VarOp) lastInitOpOperand0ResultOp};
189             }
190             for (int i = 0; i < varOps.length; i++) {
191                 blockParamToVarOpMap.put(condParams[i], varOps[i]);
192                 blockParamToVarOpMap.put(updateParams[i], varOps[i]);
193                 blockParamToVarOpMap.put(loopParams[i], varOps[i]);
194             }
195         }
196 
197         @Override
198         public Op resolve(Value value) {
199             if (value instanceof Block.Parameter blockParameter) {
200                 CoreOp.VarOp varOp = this.blockParamToVarOpMap.get(blockParameter);
201                 if (varOp != null) {
202                     return varOp;
203                 }
204             }
205             return super.resolve(value);
206         }
207     }
208 
209     public static final class IfScope extends Scope<JavaOp.IfOp> {
210         IfScope(Scope<?> parent, JavaOp.IfOp op) {
211             super(parent, op);
212         }
213     }
214 
215     public static final class RootScope extends Scope<Op> {
216         RootScope() {
217             super(null,null);
218         }
219     }
220     public static final class LambdaScope extends Scope<JavaOp.LambdaOp> {
221         LambdaScope(Scope<?> parent, JavaOp.LambdaOp lambdaOp) {
222             super(parent,lambdaOp);
223         }
224         @Override public Op resolve(Value value){
225             return super.resolve(value);
226         }
227     }
228 
229     private void popScope() {
230         scope = scope.parent;
231     }
232 
233     public  void ifScope(JavaOp.IfOp ifOp, Runnable r) {
234         scope = new IfScope(scope, ifOp);
235         r.run();
236         popScope();
237     }
238     public  void lambdaScope(JavaOp.LambdaOp lambdaOp, Runnable r) {
239         scope = new LambdaScope(scope, lambdaOp);
240         r.run();
241         popScope();
242     }
243 
244     public  void funcScope(CoreOp.FuncOp funcOp, Runnable r) {
245        scope = new FuncScope(scope,funcOp);
246         r.run();
247         popScope();
248     }
249 
250     public  void forScope(JavaOp.ForOp forOp, Runnable r) {
251         scope = new ForScope(scope,forOp);
252         r.run();
253         popScope();
254     }
255 
256     private final  MethodHandles.Lookup lookup;
257     private final CoreOp.FuncOp funcOp;
258     private  Scope<?> scope = new RootScope();
259     @Override public MethodHandles.Lookup lookup(){
260         return lookup;
261     }
262 
263     public Op resolve(Value value){
264         return scope.resolve(value);
265     }
266     public CoreOp.FuncOp funcOp(){
267         return funcOp;
268     }
269     public ScopedCodeBuilderContext(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
270         this.lookup = lookup;
271         this.funcOp= funcOp;
272         this.paramTable = new FuncOpParams(funcOp);
273     }
274 
275     private Map<Op.Result, CoreOp.VarOp> finalVarOps = new HashMap<>();
276 
277     public void setFinals(Map<Op.Result, CoreOp.VarOp> finalVars) {
278         this.finalVarOps = finalVars;
279     }
280 
281     public boolean isVarOpFinal(CoreOp.VarOp varOp) {
282         return finalVarOps.containsKey(varOp.result());
283     }
284 }