1 /*
2 * Copyright (c) 2024 Intel Corporation. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package intel.code.spirv;
27
28 import java.util.List;
29 import java.util.ArrayList;
30 import java.util.Map;
31 import java.util.HashMap;
32 import jdk.incubator.code.Block;
33 import jdk.incubator.code.Body;
34 import jdk.incubator.code.CodeTransformer;
35 import jdk.incubator.code.Op;
36 import jdk.incubator.code.Value;
37 import jdk.incubator.code.CodeType;
38 import jdk.incubator.code.dialect.core.CoreOp;
39 import jdk.incubator.code.dialect.java.JavaOp;
40 import jdk.incubator.code.dialect.java.JavaType;
41
42 public class TranslateToSpirvModel {
43 private Map<Block, Block.Builder> blockMap; // Java block to spirv block builder
44 private Map<Value, Value> valueMap; // Java model Value to Spirv model Value
45
46 public static SpirvOps.FuncOp translateFunction(CoreOp.FuncOp func) {
47 CoreOp.FuncOp lowFunc = lowerMethod(func);
48 TranslateToSpirvModel instance = new TranslateToSpirvModel();
49 Body.Builder bodyBuilder = instance.translateBody(lowFunc.body(), lowFunc, null);
50 return new SpirvOps.FuncOp(lowFunc.funcName(), lowFunc.invokableSignature(), bodyBuilder);
51 }
52
53 public TranslateToSpirvModel() {
54 blockMap = new HashMap<>();
55 valueMap = new HashMap<>();
56 }
57
58 private Body.Builder translateBody(Body body, Op parentOp, Body.Builder parentBody) {
59 Body.Builder bodyBuilder = Body.Builder.of(parentBody, body.bodySignature());
60 Block.Builder spirvBlock = bodyBuilder.entryBlock();
61 blockMap.put(body.entryBlock(), spirvBlock);
62 List<Block> blocks = body.blocks();
63 // map Java blocks to spirv blocks
64 for (Block b : blocks.subList(1, blocks.size())) {
65 Block.Builder loweredBlock = spirvBlock.block(b.parameterTypes());
66 blockMap.put(b, loweredBlock);
67 spirvBlock = loweredBlock;
68 }
69 // map entry block parameters to spirv function parameter
70 spirvBlock = bodyBuilder.entryBlock();
71 List<SpirvOp> paramOps = new ArrayList<>();
72 List<SpirvOps.VariableOp> varOps = new ArrayList<>();
73 Block entryBlock = body.entryBlock();
74 int paramCount = entryBlock.parameters().size();
75 for (int i = 0; i < paramCount; i++) {
76 Block.Parameter bp = entryBlock.parameters().get(i);
77 assert entryBlock.ops().get(i) instanceof CoreOp.VarOp;
78 SpirvOp funcParam = new SpirvOps.FunctionParameterOp(bp.type(), List.of());
79 spirvBlock.op(funcParam);
80 valueMap.put(bp, funcParam.result());
81 paramOps.add(funcParam);
82 }
83 // SPIR-V Variable ops must be the first ops in a function's entry block and do not include initialization.
84 // Emit all SPIR-V Variable ops first and emit initializing stores afterward, at the CR model VarOp position.
85 for (int i = 0; i < paramCount; i++) {
86 CoreOp.VarOp jvop = (CoreOp.VarOp)entryBlock.ops().get(i);
87 CodeType resultType = new PointerType(jvop.varValueType(), StorageType.CROSSWORKGROUP);
88 SpirvOps.VariableOp svop = new SpirvOps.VariableOp((String)jvop.externalize().get(""), resultType, jvop.varValueType());
89 spirvBlock.op(svop);
90 valueMap.put(jvop.result(), svop.result());
91 varOps.add(svop);
92 }
93 // add non-function-parameter variables
94 for (int bi = 0; bi < body.blocks().size(); bi++) {
95 Block block = body.blocks().get(bi);
96 spirvBlock = blockMap.get(block);
97 List<Op> ops = block.ops();
98 for (int i = (bi == 0 ? paramCount : 0); i < ops.size(); i++) {
99 if (bi > 0) spirvBlock = blockMap.get(block);
100 Op op = ops.get(i);
101 if (op instanceof CoreOp.VarOp jvop) {
102 CodeType resultType = new PointerType(jvop.varValueType(), StorageType.CROSSWORKGROUP);
103 SpirvOps.VariableOp svop = new SpirvOps.VariableOp((String)jvop.externalize().get(""), resultType, jvop.varValueType());
104 bodyBuilder.entryBlock().op(svop);
105 valueMap.put(jvop.result(), svop.result());
106 varOps.add(svop);
107 }
108 }
109 }
110 for (int bi = 0; bi < body.blocks().size(); bi++) {
111 Block block = body.blocks().get(bi);
112 spirvBlock = blockMap.get(block);
113 for (Op op : block.ops()) {
114 switch (op) {
115 case CoreOp.ReturnOp rop -> {
116 spirvBlock.op(new SpirvOps.ReturnOp(rop.resultType(), mapOperands(rop)));
117 }
118 case CoreOp.VarOp vop -> {
119 Value dest = valueMap.get(vop.result());
120 Value value = valueMap.get(vop.operands().get(0));
121 // init variable here; declaration has been moved to top of function
122 spirvBlock.op(new SpirvOps.StoreOp(dest, value));
123 }
124 case CoreOp.VarAccessOp.VarLoadOp vlo -> {
125 List<Value> operands = mapOperands(vlo);
126 SpirvOps.LoadOp load = new SpirvOps.LoadOp(vlo.resultType(), operands);
127 spirvBlock.op(load);
128 valueMap.put(vlo.result(), load.result());
129 }
130 case CoreOp.VarAccessOp.VarStoreOp vso -> {
131 Value dest = valueMap.get(vso.varOp().result());
132 Value value = valueMap.get(vso.operands().get(1));
133 spirvBlock.op(new SpirvOps.StoreOp(dest, value));
134 }
135 case JavaOp.ArrayAccessOp.ArrayLoadOp alo -> {
136 Value array = valueMap.get(alo.operands().get(0));
137 Value index = valueMap.get(alo.operands().get(1));
138 CodeType arrayType = array.type();
139 SpirvOps.ConvertOp convert = new SpirvOps.ConvertOp(JavaType.type(long.class), List.of(index));
140 spirvBlock.op(new SpirvOps.LoadOp(arrayType, List.of(array)));
141 spirvBlock.op(convert);
142 SpirvOp ibac = new SpirvOps.InBoundAccessChainOp(arrayType, List.of(array, convert.result()));
143 spirvBlock.op(ibac);
144 SpirvOp load = new SpirvOps.LoadOp(alo.resultType(), List.of(ibac.result()));
145 spirvBlock.op(load);
146 valueMap.put(alo.result(), load.result());
147 }
148 case JavaOp.ArrayAccessOp.ArrayStoreOp aso -> {
149 Value array = valueMap.get(aso.operands().get(0));
150 Value index = valueMap.get(aso.operands().get(1));
151 CodeType arrayType = array.type();
152 SpirvOp ibac = new SpirvOps.InBoundAccessChainOp(arrayType, List.of(array, index));
153 spirvBlock.op(ibac);
154 spirvBlock.op(new SpirvOps.StoreOp(ibac.result(), valueMap.get(aso.operands().get(2))));
155 }
156 case JavaOp.ArrayLengthOp alo -> {
157 Op len = new SpirvOps.ArrayLengthOp(JavaType.INT, List.of(valueMap.get(alo.operands().get(0))));
158 spirvBlock.op(len);
159 valueMap.put(alo.result(), len.result());
160 }
161 case JavaOp.AddOp aop -> {
162 CodeType type = aop.operands().get(0).type();
163 List<Value> operands = mapOperands(aop);
164 SpirvOp addOp;
165 if (isIntegerType(type)) addOp = new SpirvOps.IAddOp(type, operands);
166 else if (isFloatType(type)) addOp = new SpirvOps.FAddOp(type, operands);
167 else throw unsupported("type", type);
168 spirvBlock.op(addOp);
169 valueMap.put(aop.result(), addOp.result());
170 }
171 case JavaOp.SubOp sop -> {
172 CodeType type = sop.operands().get(0).type();
173 List<Value> operands = mapOperands(sop);
174 SpirvOp subOp;
175 if (isIntegerType(type)) subOp = new SpirvOps.ISubOp(type, operands);
176 else if (isFloatType(type)) subOp = new SpirvOps.FSubOp(type, operands);
177 else throw unsupported("type", type);
178 spirvBlock.op(subOp);
179 valueMap.put(sop.result(), subOp.result());
180 }
181 case JavaOp.MulOp mop -> {
182 CodeType type = mop.operands().get(0).type();
183 List<Value> operands = mapOperands(mop);
184 SpirvOp mulOp;
185 if (isIntegerType(type)) mulOp = new SpirvOps.IMulOp(type, operands);
186 else if (isFloatType(type)) mulOp = new SpirvOps.FMulOp(type, operands);
187 else throw unsupported("type", type);
188 spirvBlock.op(mulOp);
189 valueMap.put(mop.result(), mulOp.result());
190 }
191 case JavaOp.DivOp dop -> {
192 CodeType type = dop.operands().get(0).type();
193 List<Value> operands = mapOperands(dop);
194 SpirvOp divOp;
195 if (isIntegerType(type)) divOp = new SpirvOps.IDivOp(type, operands);
196 else if (isFloatType(type)) divOp = new SpirvOps.FDivOp(type, operands);
197 else throw unsupported("type", type);
198 spirvBlock.op(divOp);
199 valueMap.put(dop.result(), divOp.result());
200 }
201 case JavaOp.ModOp mop -> {
202 CodeType type = mop.operands().get(0).type();
203 List<Value> operands = mapOperands(mop);
204 SpirvOp modOp = new SpirvOps.ModOp(type, operands);
205 spirvBlock.op(modOp);
206 valueMap.put(mop.result(), modOp.result());
207 }
208 case JavaOp.EqOp eqop -> {
209 CodeType type = eqop.operands().get(0).type();
210 List<Value> operands = mapOperands(eqop);
211 SpirvOp seqop;
212 if (isIntegerType(type)) seqop = new SpirvOps.IEqualOp(type, operands);
213 else if (isFloatType(type)) seqop = new SpirvOps.FEqualOp(type, operands);
214 else throw unsupported("type", type);
215 spirvBlock.op(seqop);
216 valueMap.put(eqop.result(), seqop.result());
217 }
218 case JavaOp.NeqOp neqop -> {
219 CodeType type = neqop.operands().get(0).type();
220 List<Value> operands = mapOperands(neqop);
221 SpirvOp sneqop;
222 if (isIntegerType(type)) sneqop = new SpirvOps.INotEqualOp(type, operands);
223 else if (isFloatType(type)) sneqop = new SpirvOps.FNotEqualOp(type, operands);
224 else throw unsupported("type", type);
225 spirvBlock.op(sneqop);
226 valueMap.put(neqop.result(), sneqop.result());
227 }
228 case JavaOp.LtOp ltop -> {
229 CodeType type = ltop.operands().get(0).type();
230 List<Value> operands = mapOperands(ltop);
231 SpirvOp sltop = new SpirvOps.LtOp(type, operands);
232 spirvBlock.op(sltop);
233 valueMap.put(ltop.result(), sltop.result());
234 }
235 case JavaOp.InvokeOp inv -> {
236 List<Value> operands = mapOperands(inv);
237 SpirvOp spirvCall = new SpirvOps.CallOp(inv.invokeReference(), operands);
238 spirvBlock.op(spirvCall);
239 valueMap.put(inv.result(), spirvCall.result());
240 }
241 case CoreOp.ConstantOp cop -> {
242 SpirvOp scop = new SpirvOps.ConstantOp(cop.resultType(), cop.value());
243 spirvBlock.op(scop);
244 valueMap.put(cop.result(), scop.result());
245 }
246 case JavaOp.ConvOp cop -> {
247 List<Value> operands = mapOperands(cop);
248 SpirvOp scop = new SpirvOps.ConvertOp(cop.resultType(), operands);
249 spirvBlock.op(scop);
250 valueMap.put(cop.result(), scop.result());
251 }
252 case JavaOp.FieldAccessOp.FieldLoadOp flo -> {
253 SpirvOp load = new SpirvOps.FieldLoadOp(flo.resultType(), flo.fieldReference(), mapOperands(flo));
254 spirvBlock.op(load);
255 valueMap.put(flo.result(), load.result());
256 }
257 case CoreOp.BranchOp bop -> {
258 Block.Reference successor = blockMap.get(bop.branch().targetBlock()).reference();
259 spirvBlock.op(new SpirvOps.BranchOp(successor));
260 }
261 case CoreOp.ConditionalBranchOp cbop -> {
262 Block trueBlock = cbop.trueBranch().targetBlock();
263 Block falseBlock = cbop.falseBranch().targetBlock();
264 Block.Reference spvTrueBlock = blockMap.get(trueBlock).reference();
265 Block.Reference spvFalseBlock = blockMap.get(falseBlock).reference();
266 spirvBlock.op(new SpirvOps.ConditionalBranchOp(spvTrueBlock, spvFalseBlock, mapOperands(cbop)));
267 }
268 default -> unsupported("op", op.getClass());
269 }
270 } //ops
271 } // blocks
272 return bodyBuilder;
273 }
274
275 private RuntimeException unsupported(String message, Object value) {
276 return new RuntimeException("Unsupported " + message + ": " + value);
277 }
278
279 private static CoreOp.FuncOp lowerMethod(CoreOp.FuncOp fop) {
280 CoreOp.FuncOp lfop = fop.transform(CodeTransformer.LOWERING_TRANSFORMER);
281 return lfop;
282 }
283
284 private List<Value> mapOperands(Op op) {
285 List<Value> operands = new ArrayList<>();
286 for (Value javaValue : op.operands()) {
287 Value spirvValue = valueMap.get(javaValue);
288 assert spirvValue != null : "no value mapping from %s" + javaValue;
289 operands.add(spirvValue);
290 }
291 return operands;
292 }
293
294 private boolean isIntegerType(CodeType type) {
295 return type.equals(JavaType.INT) || type.equals(JavaType.LONG);
296 }
297
298 private boolean isFloatType(CodeType type) {
299 return type.equals(JavaType.FLOAT) || type.equals(JavaType.DOUBLE);
300 }
301 }