1 /*
  2  * Copyright (c) 2024 Intel Corporation. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package intel.code.spirv;
 27 
 28 import java.util.List;
 29 import java.util.ArrayList;
 30 import java.util.Map;
 31 import java.util.HashMap;
 32 import jdk.incubator.code.Block;
 33 import jdk.incubator.code.Body;
 34 import jdk.incubator.code.CodeTransformer;
 35 import jdk.incubator.code.Op;
 36 import jdk.incubator.code.Value;
 37 import jdk.incubator.code.CodeType;
 38 import jdk.incubator.code.dialect.core.CoreOp;
 39 import jdk.incubator.code.dialect.java.JavaOp;
 40 import jdk.incubator.code.dialect.java.JavaType;
 41 
 42 public class TranslateToSpirvModel  {
 43     private Map<Block, Block.Builder> blockMap;    // Java block to spirv block builder
 44     private Map<Value, Value> valueMap;            // Java model Value to Spirv model Value
 45 
 46     public static SpirvOps.FuncOp translateFunction(CoreOp.FuncOp func) {
 47         CoreOp.FuncOp lowFunc = lowerMethod(func);
 48         TranslateToSpirvModel instance = new TranslateToSpirvModel();
 49         Body.Builder bodyBuilder = instance.translateBody(lowFunc.body(), lowFunc, null);
 50         return new SpirvOps.FuncOp(lowFunc.funcName(), lowFunc.invokableSignature(), bodyBuilder);
 51     }
 52 
 53     public TranslateToSpirvModel() {
 54         blockMap = new HashMap<>();
 55         valueMap = new HashMap<>();
 56     }
 57 
 58     private Body.Builder translateBody(Body body, Op parentOp, Body.Builder parentBody) {
 59         Body.Builder bodyBuilder = Body.Builder.of(parentBody, body.bodySignature());
 60         Block.Builder spirvBlock = bodyBuilder.entryBlock();
 61         blockMap.put(body.entryBlock(), spirvBlock);
 62         List<Block> blocks = body.blocks();
 63         // map Java blocks to spirv blocks
 64         for (Block b : blocks.subList(1, blocks.size()))  {
 65             Block.Builder loweredBlock = spirvBlock.block(b.parameterTypes());
 66             blockMap.put(b, loweredBlock);
 67             spirvBlock = loweredBlock;
 68         }
 69         // map entry block parameters to spirv function parameter
 70         spirvBlock = bodyBuilder.entryBlock();
 71         List<SpirvOp> paramOps = new ArrayList<>();
 72         List<SpirvOps.VariableOp> varOps = new ArrayList<>();
 73         Block entryBlock = body.entryBlock();
 74         int paramCount = entryBlock.parameters().size();
 75         for (int i = 0; i < paramCount; i++) {
 76             Block.Parameter bp = entryBlock.parameters().get(i);
 77             assert entryBlock.ops().get(i) instanceof CoreOp.VarOp;
 78             SpirvOp funcParam = new SpirvOps.FunctionParameterOp(bp.type(), List.of());
 79             spirvBlock.op(funcParam);
 80             valueMap.put(bp, funcParam.result());
 81             paramOps.add(funcParam);
 82         }
 83         // SPIR-V Variable ops must be the first ops in a function's entry block and do not include initialization.
 84         // Emit all SPIR-V Variable ops first and emit initializing stores afterward, at the CR model VarOp position.
 85         for (int i = 0; i < paramCount; i++) {
 86             CoreOp.VarOp jvop = (CoreOp.VarOp)entryBlock.ops().get(i);
 87             CodeType resultType = new PointerType(jvop.varValueType(), StorageType.CROSSWORKGROUP);
 88             SpirvOps.VariableOp svop = new SpirvOps.VariableOp((String)jvop.externalize().get(""), resultType, jvop.varValueType());
 89             spirvBlock.op(svop);
 90             valueMap.put(jvop.result(), svop.result());
 91             varOps.add(svop);
 92         }
 93         // add non-function-parameter variables
 94         for (int bi = 0; bi < body.blocks().size(); bi++)  {
 95             Block block = body.blocks().get(bi);
 96             spirvBlock = blockMap.get(block);
 97             List<Op> ops = block.ops();
 98             for (int i = (bi == 0 ? paramCount : 0); i < ops.size(); i++) {
 99                 if (bi > 0) spirvBlock = blockMap.get(block);
100                 Op op = ops.get(i);
101                 if (op instanceof CoreOp.VarOp jvop) {
102                     CodeType resultType = new PointerType(jvop.varValueType(), StorageType.CROSSWORKGROUP);
103                     SpirvOps.VariableOp svop = new SpirvOps.VariableOp((String)jvop.externalize().get(""), resultType, jvop.varValueType());
104                     bodyBuilder.entryBlock().op(svop);
105                     valueMap.put(jvop.result(), svop.result());
106                     varOps.add(svop);
107                 }
108             }
109         }
110         for (int bi = 0; bi < body.blocks().size(); bi++)  {
111             Block block = body.blocks().get(bi);
112             spirvBlock = blockMap.get(block);
113             for (Op op : block.ops()) {
114                 switch (op) {
115                     case CoreOp.ReturnOp rop -> {
116                         spirvBlock.op(new SpirvOps.ReturnOp(rop.resultType(), mapOperands(rop)));
117                     }
118                     case CoreOp.VarOp vop -> {
119                         Value dest = valueMap.get(vop.result());
120                         Value value = valueMap.get(vop.operands().get(0));
121                         // init variable here; declaration has been moved to top of function
122                         spirvBlock.op(new SpirvOps.StoreOp(dest, value));
123                     }
124                     case CoreOp.VarAccessOp.VarLoadOp vlo -> {
125                         List<Value> operands = mapOperands(vlo);
126                         SpirvOps.LoadOp load = new SpirvOps.LoadOp(vlo.resultType(), operands);
127                         spirvBlock.op(load);
128                         valueMap.put(vlo.result(), load.result());
129                     }
130                     case CoreOp.VarAccessOp.VarStoreOp vso -> {
131                         Value dest = valueMap.get(vso.varOp().result());
132                         Value value = valueMap.get(vso.operands().get(1));
133                         spirvBlock.op(new SpirvOps.StoreOp(dest, value));
134                     }
135                     case JavaOp.ArrayAccessOp.ArrayLoadOp alo -> {
136                         Value array = valueMap.get(alo.operands().get(0));
137                         Value index = valueMap.get(alo.operands().get(1));
138                         CodeType arrayType = array.type();
139                         SpirvOps.ConvertOp convert = new SpirvOps.ConvertOp(JavaType.type(long.class), List.of(index));
140                         spirvBlock.op(new SpirvOps.LoadOp(arrayType, List.of(array)));
141                         spirvBlock.op(convert);
142                         SpirvOp ibac = new SpirvOps.InBoundAccessChainOp(arrayType, List.of(array, convert.result()));
143                         spirvBlock.op(ibac);
144                         SpirvOp load = new SpirvOps.LoadOp(alo.resultType(), List.of(ibac.result()));
145                         spirvBlock.op(load);
146                         valueMap.put(alo.result(), load.result());
147                     }
148                     case JavaOp.ArrayAccessOp.ArrayStoreOp aso -> {
149                         Value array = valueMap.get(aso.operands().get(0));
150                         Value index = valueMap.get(aso.operands().get(1));
151                         CodeType arrayType = array.type();
152                         SpirvOp ibac = new SpirvOps.InBoundAccessChainOp(arrayType, List.of(array, index));
153                         spirvBlock.op(ibac);
154                         spirvBlock.op(new SpirvOps.StoreOp(ibac.result(), valueMap.get(aso.operands().get(2))));
155                     }
156                     case JavaOp.ArrayLengthOp alo -> {
157                         Op len = new SpirvOps.ArrayLengthOp(JavaType.INT, List.of(valueMap.get(alo.operands().get(0))));
158                         spirvBlock.op(len);
159                         valueMap.put(alo.result(), len.result());
160                     }
161                     case JavaOp.AddOp aop -> {
162                         CodeType type = aop.operands().get(0).type();
163                         List<Value> operands = mapOperands(aop);
164                         SpirvOp addOp;
165                         if (isIntegerType(type)) addOp = new SpirvOps.IAddOp(type, operands);
166                         else if (isFloatType(type)) addOp = new SpirvOps.FAddOp(type, operands);
167                         else throw unsupported("type", type);
168                         spirvBlock.op(addOp);
169                         valueMap.put(aop.result(), addOp.result());
170                      }
171                     case JavaOp.SubOp sop -> {
172                         CodeType  type = sop.operands().get(0).type();
173                         List<Value> operands = mapOperands(sop);
174                         SpirvOp subOp;
175                         if (isIntegerType(type)) subOp = new SpirvOps.ISubOp(type, operands);
176                         else if (isFloatType(type)) subOp = new SpirvOps.FSubOp(type, operands);
177                         else throw unsupported("type", type);
178                         spirvBlock.op(subOp);
179                         valueMap.put(sop.result(), subOp.result());
180                      }
181                     case JavaOp.MulOp mop -> {
182                         CodeType type = mop.operands().get(0).type();
183                         List<Value> operands = mapOperands(mop);
184                         SpirvOp mulOp;
185                         if (isIntegerType(type)) mulOp = new SpirvOps.IMulOp(type, operands);
186                         else if (isFloatType(type)) mulOp = new SpirvOps.FMulOp(type, operands);
187                         else throw unsupported("type", type);
188                         spirvBlock.op(mulOp);
189                         valueMap.put(mop.result(), mulOp.result());
190                     }
191                     case JavaOp.DivOp dop -> {
192                         CodeType type = dop.operands().get(0).type();
193                         List<Value> operands = mapOperands(dop);
194                         SpirvOp divOp;
195                         if (isIntegerType(type)) divOp = new SpirvOps.IDivOp(type, operands);
196                         else if (isFloatType(type)) divOp = new SpirvOps.FDivOp(type, operands);
197                         else throw unsupported("type", type);
198                         spirvBlock.op(divOp);
199                         valueMap.put(dop.result(), divOp.result());
200                     }
201                     case JavaOp.ModOp mop -> {
202                         CodeType type = mop.operands().get(0).type();
203                         List<Value> operands = mapOperands(mop);
204                         SpirvOp modOp = new SpirvOps.ModOp(type, operands);
205                         spirvBlock.op(modOp);
206                         valueMap.put(mop.result(), modOp.result());
207                     }
208                     case JavaOp.EqOp eqop -> {
209                         CodeType type = eqop.operands().get(0).type();
210                         List<Value> operands = mapOperands(eqop);
211                         SpirvOp seqop;
212                         if (isIntegerType(type)) seqop = new SpirvOps.IEqualOp(type, operands);
213                         else if (isFloatType(type)) seqop = new SpirvOps.FEqualOp(type, operands);
214                         else throw unsupported("type", type);
215                         spirvBlock.op(seqop);
216                         valueMap.put(eqop.result(), seqop.result());
217                     }
218                     case JavaOp.NeqOp neqop -> {
219                         CodeType type = neqop.operands().get(0).type();
220                         List<Value> operands = mapOperands(neqop);
221                         SpirvOp sneqop;
222                         if (isIntegerType(type)) sneqop = new SpirvOps.INotEqualOp(type, operands);
223                         else if (isFloatType(type)) sneqop = new SpirvOps.FNotEqualOp(type, operands);
224                         else throw unsupported("type", type);
225                         spirvBlock.op(sneqop);
226                         valueMap.put(neqop.result(), sneqop.result());
227                     }
228                     case JavaOp.LtOp ltop -> {
229                         CodeType type = ltop.operands().get(0).type();
230                         List<Value> operands = mapOperands(ltop);
231                         SpirvOp sltop = new SpirvOps.LtOp(type, operands);
232                         spirvBlock.op(sltop);
233                         valueMap.put(ltop.result(), sltop.result());
234                     }
235                     case JavaOp.InvokeOp inv -> {
236                         List<Value> operands = mapOperands(inv);
237                         SpirvOp spirvCall = new SpirvOps.CallOp(inv.invokeReference(), operands);
238                         spirvBlock.op(spirvCall);
239                         valueMap.put(inv.result(), spirvCall.result());
240                     }
241                     case CoreOp.ConstantOp cop -> {
242                         SpirvOp scop = new SpirvOps.ConstantOp(cop.resultType(), cop.value());
243                         spirvBlock.op(scop);
244                         valueMap.put(cop.result(), scop.result());
245                     }
246                     case JavaOp.ConvOp cop -> {
247                         List<Value> operands = mapOperands(cop);
248                         SpirvOp scop = new SpirvOps.ConvertOp(cop.resultType(), operands);
249                         spirvBlock.op(scop);
250                         valueMap.put(cop.result(), scop.result());
251                     }
252                     case JavaOp.FieldAccessOp.FieldLoadOp flo -> {
253                         SpirvOp load = new SpirvOps.FieldLoadOp(flo.resultType(), flo.fieldReference(), mapOperands(flo));
254                         spirvBlock.op(load);
255                         valueMap.put(flo.result(), load.result());
256                     }
257                     case CoreOp.BranchOp bop -> {
258                         Block.Reference successor = blockMap.get(bop.branch().targetBlock()).reference();
259                         spirvBlock.op(new SpirvOps.BranchOp(successor));
260                     }
261                     case CoreOp.ConditionalBranchOp cbop -> {
262                         Block trueBlock = cbop.trueBranch().targetBlock();
263                         Block falseBlock = cbop.falseBranch().targetBlock();
264                         Block.Reference spvTrueBlock = blockMap.get(trueBlock).reference();
265                         Block.Reference spvFalseBlock = blockMap.get(falseBlock).reference();
266                         spirvBlock.op(new SpirvOps.ConditionalBranchOp(spvTrueBlock, spvFalseBlock, mapOperands(cbop)));
267                     }
268                     default -> unsupported("op", op.getClass());
269                 }
270             } //ops
271         } // blocks
272         return bodyBuilder;
273     }
274 
275     private RuntimeException unsupported(String message, Object value) {
276         return new RuntimeException("Unsupported " + message + ": " + value);
277     }
278 
279     private static CoreOp.FuncOp lowerMethod(CoreOp.FuncOp fop) {
280         CoreOp.FuncOp lfop = fop.transform(CodeTransformer.LOWERING_TRANSFORMER);
281         return lfop;
282     }
283 
284     private List<Value> mapOperands(Op op) {
285         List<Value> operands = new ArrayList<>();
286         for (Value javaValue : op.operands()) {
287             Value spirvValue = valueMap.get(javaValue);
288             assert spirvValue != null : "no value mapping from %s" + javaValue;
289             operands.add(spirvValue);
290         }
291         return operands;
292     }
293 
294     private boolean isIntegerType(CodeType type) {
295         return type.equals(JavaType.INT) || type.equals(JavaType.LONG);
296     }
297 
298     private boolean isFloatType(CodeType type) {
299         return type.equals(JavaType.FLOAT) || type.equals(JavaType.DOUBLE);
300     }
301 }