1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.codebuilders;
 26 
 27 import hat.dialect.HATF16VarOp;
 28 import hat.dialect.HATMemoryOp;
 29 import hat.dialect.HATVectorVarOp;
 30 import hat.optools.FuncOpParams;
 31 import jdk.incubator.code.Block;
 32 import jdk.incubator.code.Op;
 33 import jdk.incubator.code.Value;
 34 import jdk.incubator.code.dialect.core.CoreOp;
 35 import jdk.incubator.code.dialect.java.JavaOp;
 36 
 37 import java.lang.invoke.MethodHandles;
 38 import java.util.HashMap;
 39 import java.util.Map;
 40 
 41 public class ScopedCodeBuilderContext extends CodeBuilderContext {
 42 
 43     public static class Scope<O extends Op> {
 44         final Scope<?> parent;
 45         final O op;
 46 
 47         public Scope(Scope<?> parent, O op) {
 48             this.parent = parent;
 49             this.op = op;
 50         }
 51 
 52         public Op resolve(Value value) {
 53             if (value instanceof Op.Result result && result.op() instanceof CoreOp.VarOp varOp) {
 54                 return varOp;
 55             }
 56 
 57             if (value instanceof Op.Result result && result.op() instanceof HATMemoryOp hatMemoryOp) {
 58                 return hatMemoryOp;
 59             }
 60 
 61             if (value instanceof Op.Result result && result.op() instanceof HATVectorVarOp hatVectorVarOp) {
 62                 return hatVectorVarOp;
 63             }
 64 
 65             if (value instanceof Op.Result result && result.op() instanceof HATF16VarOp hatf16VarOp) {
 66                 return hatf16VarOp;
 67             }
 68 
 69             if (parent != null) {
 70                 return parent.resolve(value);
 71             }
 72             throw new IllegalStateException("failed to resolve VarOp for value " + value);
 73         }
 74     }
 75 
 76     public static class FuncScope extends Scope<CoreOp.FuncOp> {
 77         final FuncOpParams paramTable;
 78         FuncScope(Scope<?> parent, CoreOp.FuncOp funcOp) {
 79             super(parent, funcOp);
 80             paramTable = new FuncOpParams(funcOp);
 81         }
 82 
 83         @Override
 84         public Op resolve(Value value) {
 85             if (value instanceof Block.Parameter blockParameter) {
 86                 if (paramTable.parameterVarOpMap.containsKey(blockParameter)) {
 87                     return paramTable.parameterVarOpMap.get(blockParameter);
 88                 } else {
 89                     throw new IllegalStateException("what ?");
 90                 }
 91             } else {
 92                 return super.resolve(value);
 93             }
 94         }
 95     }
 96 
 97     public static class ForScope extends Scope<JavaOp.ForOp> {
 98         Map<Block.Parameter, CoreOp.VarOp> blockParamToVarOpMap = new HashMap<>();
 99         ForScope(Scope<?> parent, JavaOp.ForOp forOp) {
100             super(parent, forOp);
101             var loopParams = forOp.loopBody().entryBlock().parameters().toArray(new Block.Parameter[0]);
102             var updateParams = forOp.update().entryBlock().parameters().toArray(new Block.Parameter[0]);
103             var condParams = forOp.cond().entryBlock().parameters().toArray(new Block.Parameter[0]);
104             var lastInitOp = forOp.init().entryBlock().ops().getLast();
105             var lastInitOpOperand0Result = (Op.Result) lastInitOp.operands().getFirst();
106             var lastInitOpOperand0ResultOp = lastInitOpOperand0Result.op();
107             CoreOp.VarOp[] varOps;
108             if (lastInitOpOperand0ResultOp instanceof CoreOp.TupleOp tupleOp) {
109                  /*
110                  for (int j = 1, i=2, k=3; j < size; k+=1,i+=2,j+=3) {
111                     float sum = k+i+j;
112                  }
113                  java.for
114                  ()Tuple<Var<int>, Var<int>, Var<int>> -> {
115                      %0 : int = constant @"1";
116                      %1 : Var<int> = var %0 @"j";
117                      %2 : int = constant @"2";
118                      %3 : Var<int> = var %2 @"i";
119                      %4 : int = constant @"3";
120                      %5 : Var<int> = var %4 @"k";
121                      %6 : Tuple<Var<int>, Var<int>, Var<int>> = tuple %1 %3 %5;
122                      yield %6;
123                  }
124                  (%7 : Var<int>, %8 : Var<int>, %9 : Var<int>)boolean -> {
125                      %10 : int = var.load %7;
126                      %11 : int = var.load %12;
127                      %13 : boolean = lt %10 %11;
128                      yield %13;
129                  }
130                  (%14 : Var<int>, %15 : Var<int>, %16 : Var<int>)void -> {
131                      %17 : int = var.load %16;
132                      %18 : int = constant @"1";
133                      %19 : int = add %17 %18;
134                      var.store %16 %19;
135                      %20 : int = var.load %15;
136                      %21 : int = constant @"2";
137                      %22 : int = add %20 %21;
138                      var.store %15 %22;
139                      %23 : int = var.load %14;
140                      %24 : int = constant @"3";
141                      %25 : int = add %23 %24;
142                      var.store %14 %25;
143                      yield;
144                  }
145                  (%26 : Var<int>, %27 : Var<int>, %28 : Var<int>)void -> {
146                      %29 : int = var.load %28;
147                      %30 : int = var.load %27;
148                      %31 : int = add %29 %30;
149                      %32 : int = var.load %26;
150                      %33 : int = add %31 %32;
151                      %34 : float = conv %33;
152                      %35 : Var<float> = var %34 @"sum";
153                      java.continue;
154                  };
155                  */
156                 varOps = tupleOp.operands().stream().map(operand -> (CoreOp.VarOp) (((Op.Result) operand).op())).toList().toArray(new CoreOp.VarOp[0]);
157             } else {
158                  /*
159                  for (int j = 0; j < size; j+=1) {
160                     float sum = j;
161                  }
162                  java.for
163                     ()Var<int> -> {
164                         %0 : int = constant @"0";
165                         %1 : Var<int> = var %0 @"j";
166                         yield %1;
167                     }
168                     (%2 : Var<int>)boolean -> {
169                         %3 : int = var.load %2;
170                         %4 : int = var.load %5;
171                         %6 : boolean = lt %3 %4;
172                         yield %6;
173                     }
174                     (%7 : Var<int>)void -> {
175                         %8 : int = var.load %7;
176                         %9 : int = constant @"1";
177                         %10 : int = add %8 %9;
178                         var.store %7 %10;
179                         yield;
180                     }
181                     (%11 : Var<int>)void -> {
182                         %12 : int = var.load %11;
183                         %13 : float = conv %12;
184                         %14 : Var<float> = var %13 @"sum";
185                         java.continue;
186                     };
187 
188                  */
189                 varOps = new CoreOp.VarOp[]{(CoreOp.VarOp) lastInitOpOperand0ResultOp};
190             }
191             for (int i = 0; i < varOps.length; i++) {
192                 blockParamToVarOpMap.put(condParams[i], varOps[i]);
193                 blockParamToVarOpMap.put(updateParams[i], varOps[i]);
194                 blockParamToVarOpMap.put(loopParams[i], varOps[i]);
195             }
196         }
197 
198         @Override
199         public Op resolve(Value value) {
200             if (value instanceof Block.Parameter blockParameter) {
201                 CoreOp.VarOp varOp = this.blockParamToVarOpMap.get(blockParameter);
202                 if (varOp != null) {
203                     return varOp;
204                 }
205             }
206             return super.resolve(value);
207         }
208     }
209 
210     public static class IfScope extends Scope<JavaOp.IfOp> {
211         IfScope(Scope<?> parent, JavaOp.IfOp op) {
212             super(parent, op);
213         }
214     }
215 
216     private void popScope() {
217         scope = scope.parent;
218     }
219 
220     public  void ifScope(JavaOp.IfOp ifOp, Runnable r) {
221         scope = new IfScope(scope, ifOp);
222         r.run();
223         popScope();
224     }
225     public  void funcScope(CoreOp.FuncOp funcOp, Runnable r) {
226        scope = new FuncScope(scope,funcOp);
227         r.run();
228         popScope();
229     }
230     public  void forScope(JavaOp.ForOp forOp, Runnable r) {
231         scope = new ForScope(scope,forOp);
232         r.run();
233         popScope();
234     }
235 
236     public Scope<?> scope = null;
237 
238     public ScopedCodeBuilderContext(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
239         super(lookup,funcOp);
240     }
241 
242     private Map<Op.Result, CoreOp.VarOp> finalVarOps = new HashMap<>();
243 
244     public void setFinals(Map<Op.Result, CoreOp.VarOp> finalVars) {
245         this.finalVarOps = finalVars;
246     }
247 
248     public boolean isVarOpFinal(CoreOp.VarOp varOp) {
249         return finalVarOps.containsKey(varOp.result());
250     }
251 }