1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.codebuilders;
26
27 import hat.dialect.HATF16VarOp;
28 import hat.dialect.HATMemoryOp;
29 import hat.dialect.HATVectorVarOp;
30 import hat.optools.FuncOpParams;
31 import jdk.incubator.code.Block;
32 import jdk.incubator.code.Op;
33 import jdk.incubator.code.Value;
34 import jdk.incubator.code.dialect.core.CoreOp;
35 import jdk.incubator.code.dialect.java.JavaOp;
36
37 import java.lang.invoke.MethodHandles;
38 import java.util.HashMap;
39 import java.util.Map;
40
41 public class ScopedCodeBuilderContext extends CodeBuilderContext {
42
43 public static class Scope<O extends Op> {
44 final Scope<?> parent;
45 final O op;
46
47 public Scope(Scope<?> parent, O op) {
48 this.parent = parent;
49 this.op = op;
50 }
51
52 public Op resolve(Value value) {
53 if (value instanceof Op.Result result && result.op() instanceof CoreOp.VarOp varOp) {
54 return varOp;
55 }
56
57 if (value instanceof Op.Result result && result.op() instanceof HATMemoryOp hatMemoryOp) {
58 return hatMemoryOp;
59 }
60
61 if (value instanceof Op.Result result && result.op() instanceof HATVectorVarOp hatVectorVarOp) {
62 return hatVectorVarOp;
63 }
64
65 if (value instanceof Op.Result result && result.op() instanceof HATF16VarOp hatf16VarOp) {
66 return hatf16VarOp;
67 }
68
69 if (parent != null) {
70 return parent.resolve(value);
71 }
72 throw new IllegalStateException("failed to resolve VarOp for value " + value);
73 }
74 }
75
76 public static class FuncScope extends Scope<CoreOp.FuncOp> {
77 final FuncOpParams paramTable;
78 FuncScope(Scope<?> parent, CoreOp.FuncOp funcOp) {
79 super(parent, funcOp);
80 paramTable = new FuncOpParams(funcOp);
81 }
82
83 @Override
84 public Op resolve(Value value) {
85 if (value instanceof Block.Parameter blockParameter) {
86 if (paramTable.parameterVarOpMap.containsKey(blockParameter)) {
87 return paramTable.parameterVarOpMap.get(blockParameter);
88 } else {
89 throw new IllegalStateException("what ?");
90 }
91 } else {
92 return super.resolve(value);
93 }
94 }
95 }
96
97 public static class ForScope extends Scope<JavaOp.ForOp> {
98 Map<Block.Parameter, CoreOp.VarOp> blockParamToVarOpMap = new HashMap<>();
99 ForScope(Scope<?> parent, JavaOp.ForOp forOp) {
100 super(parent, forOp);
101 var loopParams = forOp.loopBody().entryBlock().parameters().toArray(new Block.Parameter[0]);
102 var updateParams = forOp.update().entryBlock().parameters().toArray(new Block.Parameter[0]);
103 var condParams = forOp.cond().entryBlock().parameters().toArray(new Block.Parameter[0]);
104 var lastInitOp = forOp.init().entryBlock().ops().getLast();
105 var lastInitOpOperand0Result = (Op.Result) lastInitOp.operands().getFirst();
106 var lastInitOpOperand0ResultOp = lastInitOpOperand0Result.op();
107 CoreOp.VarOp[] varOps;
108 if (lastInitOpOperand0ResultOp instanceof CoreOp.TupleOp tupleOp) {
109 /*
110 for (int j = 1, i=2, k=3; j < size; k+=1,i+=2,j+=3) {
111 float sum = k+i+j;
112 }
113 java.for
114 ()Tuple<Var<int>, Var<int>, Var<int>> -> {
115 %0 : int = constant @"1";
116 %1 : Var<int> = var %0 @"j";
117 %2 : int = constant @"2";
118 %3 : Var<int> = var %2 @"i";
119 %4 : int = constant @"3";
120 %5 : Var<int> = var %4 @"k";
121 %6 : Tuple<Var<int>, Var<int>, Var<int>> = tuple %1 %3 %5;
122 yield %6;
123 }
124 (%7 : Var<int>, %8 : Var<int>, %9 : Var<int>)boolean -> {
125 %10 : int = var.load %7;
126 %11 : int = var.load %12;
127 %13 : boolean = lt %10 %11;
128 yield %13;
129 }
130 (%14 : Var<int>, %15 : Var<int>, %16 : Var<int>)void -> {
131 %17 : int = var.load %16;
132 %18 : int = constant @"1";
133 %19 : int = add %17 %18;
134 var.store %16 %19;
135 %20 : int = var.load %15;
136 %21 : int = constant @"2";
137 %22 : int = add %20 %21;
138 var.store %15 %22;
139 %23 : int = var.load %14;
140 %24 : int = constant @"3";
141 %25 : int = add %23 %24;
142 var.store %14 %25;
143 yield;
144 }
145 (%26 : Var<int>, %27 : Var<int>, %28 : Var<int>)void -> {
146 %29 : int = var.load %28;
147 %30 : int = var.load %27;
148 %31 : int = add %29 %30;
149 %32 : int = var.load %26;
150 %33 : int = add %31 %32;
151 %34 : float = conv %33;
152 %35 : Var<float> = var %34 @"sum";
153 java.continue;
154 };
155 */
156 varOps = tupleOp.operands().stream().map(operand -> (CoreOp.VarOp) (((Op.Result) operand).op())).toList().toArray(new CoreOp.VarOp[0]);
157 } else {
158 /*
159 for (int j = 0; j < size; j+=1) {
160 float sum = j;
161 }
162 java.for
163 ()Var<int> -> {
164 %0 : int = constant @"0";
165 %1 : Var<int> = var %0 @"j";
166 yield %1;
167 }
168 (%2 : Var<int>)boolean -> {
169 %3 : int = var.load %2;
170 %4 : int = var.load %5;
171 %6 : boolean = lt %3 %4;
172 yield %6;
173 }
174 (%7 : Var<int>)void -> {
175 %8 : int = var.load %7;
176 %9 : int = constant @"1";
177 %10 : int = add %8 %9;
178 var.store %7 %10;
179 yield;
180 }
181 (%11 : Var<int>)void -> {
182 %12 : int = var.load %11;
183 %13 : float = conv %12;
184 %14 : Var<float> = var %13 @"sum";
185 java.continue;
186 };
187
188 */
189 varOps = new CoreOp.VarOp[]{(CoreOp.VarOp) lastInitOpOperand0ResultOp};
190 }
191 for (int i = 0; i < varOps.length; i++) {
192 blockParamToVarOpMap.put(condParams[i], varOps[i]);
193 blockParamToVarOpMap.put(updateParams[i], varOps[i]);
194 blockParamToVarOpMap.put(loopParams[i], varOps[i]);
195 }
196 }
197
198 @Override
199 public Op resolve(Value value) {
200 if (value instanceof Block.Parameter blockParameter) {
201 CoreOp.VarOp varOp = this.blockParamToVarOpMap.get(blockParameter);
202 if (varOp != null) {
203 return varOp;
204 }
205 }
206 return super.resolve(value);
207 }
208 }
209
210 public static class IfScope extends Scope<JavaOp.IfOp> {
211 IfScope(Scope<?> parent, JavaOp.IfOp op) {
212 super(parent, op);
213 }
214 }
215
216 private void popScope() {
217 scope = scope.parent;
218 }
219
220 public void ifScope(JavaOp.IfOp ifOp, Runnable r) {
221 scope = new IfScope(scope, ifOp);
222 r.run();
223 popScope();
224 }
225 public void funcScope(CoreOp.FuncOp funcOp, Runnable r) {
226 scope = new FuncScope(scope,funcOp);
227 r.run();
228 popScope();
229 }
230 public void forScope(JavaOp.ForOp forOp, Runnable r) {
231 scope = new ForScope(scope,forOp);
232 r.run();
233 popScope();
234 }
235
236 public Scope<?> scope = null;
237
238 public ScopedCodeBuilderContext(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
239 super(lookup,funcOp);
240 }
241
242 private Map<Op.Result, CoreOp.VarOp> finalVarOps = new HashMap<>();
243
244 public void setFinals(Map<Op.Result, CoreOp.VarOp> finalVars) {
245 this.finalVarOps = finalVars;
246 }
247
248 public boolean isVarOpFinal(CoreOp.VarOp varOp) {
249 return finalVarOps.containsKey(varOp.result());
250 }
251 }