1 /*
  2  * Copyright (c) 2024 Intel Corporation. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package intel.code.spirv;
 27 
 28 import java.util.List;
 29 import java.util.ArrayList;
 30 import java.util.Map;
 31 import java.util.HashMap;
 32 import java.lang.reflect.code.Block;
 33 import java.lang.reflect.code.Body;
 34 import java.lang.reflect.code.op.CoreOp;
 35 import java.lang.reflect.code.Op;
 36 import java.lang.reflect.code.Value;
 37 import java.lang.reflect.code.TypeElement;
 38 import java.lang.reflect.code.type.JavaType;
 39 
 40 public class TranslateToSpirvModel  {
 41     private Map<Block, Block.Builder> blockMap;    // Java block to spirv block builder
 42     private Map<Value, Value> valueMap;            // Java model Value to Spirv model Value
 43 
 44     public static SpirvOps.FuncOp translateFunction(CoreOp.FuncOp func) {
 45         CoreOp.FuncOp lowFunc = lowerMethod(func);
 46         TranslateToSpirvModel instance = new TranslateToSpirvModel();
 47         Body.Builder bodyBuilder = instance.translateBody(lowFunc.body(), lowFunc, null);
 48         return new SpirvOps.FuncOp(lowFunc.funcName(), lowFunc.invokableType(), bodyBuilder);
 49     }
 50 
 51     private TranslateToSpirvModel() {
 52         blockMap = new HashMap<>();
 53         valueMap = new HashMap<>();
 54     }
 55 
 56     private Body.Builder translateBody(Body body, Op parentOp, Body.Builder parentBody) {
 57         Body.Builder bodyBuilder = Body.Builder.of(parentBody, body.bodyType());
 58         Block.Builder spirvBlock = bodyBuilder.entryBlock();
 59         blockMap.put(body.entryBlock(), spirvBlock);
 60         List<Block> blocks = body.blocks();
 61         // map Java blocks to spirv blocks
 62         for (Block b : blocks.subList(1, blocks.size()))  {
 63             Block.Builder loweredBlock = spirvBlock.block(b.parameterTypes());
 64             blockMap.put(b, loweredBlock);
 65             spirvBlock = loweredBlock;
 66         }
 67         // map entry block parameters to spirv function parameter
 68         spirvBlock = bodyBuilder.entryBlock();
 69         List<SpirvOp> paramOps = new ArrayList<>();
 70         List<SpirvOps.VariableOp> varOps = new ArrayList<>();
 71         Block entryBlock = body.entryBlock();
 72         int paramCount = entryBlock.parameters().size();
 73         for (int i = 0; i < paramCount; i++) {
 74             Block.Parameter bp = entryBlock.parameters().get(i);
 75             assert entryBlock.ops().get(i) instanceof CoreOp.VarOp;
 76             SpirvOp funcParam = new SpirvOps.FunctionParameterOp(bp.type(), List.of());
 77             spirvBlock.op(funcParam);
 78             valueMap.put(bp, funcParam.result());
 79             paramOps.add(funcParam);
 80         }
 81         // SPIR-V Variable ops must be the first ops in a function's entry block and do not include initialization.
 82         // Emit all SPIR-V Variable ops first and emit initializing stores afterward, at the CR model VarOp position.
 83         for (int i = 0; i < paramCount; i++) {
 84             CoreOp.VarOp jvop = (CoreOp.VarOp)entryBlock.ops().get(i);
 85             TypeElement resultType = new PointerType(jvop.varType(), StorageType.CROSSWORKGROUP);
 86             SpirvOps.VariableOp svop = new SpirvOps.VariableOp((String)jvop.attributes().get(""), resultType, jvop.varType());
 87             spirvBlock.op(svop);
 88             valueMap.put(jvop.result(), svop.result());
 89             varOps.add(svop);
 90         }
 91         // add non-function-parameter variables
 92         for (int bi = 0; bi < body.blocks().size(); bi++)  {
 93             Block block = body.blocks().get(bi);
 94             spirvBlock = blockMap.get(block);
 95             List<Op> ops = block.ops();
 96             for (int i = (bi == 0 ? paramCount : 0); i < ops.size(); i++) {
 97                 if (bi > 0) spirvBlock = blockMap.get(block);
 98                 Op op = ops.get(i);
 99                 if (op instanceof CoreOp.VarOp jvop) {
100                     TypeElement resultType = new PointerType(jvop.varType(), StorageType.CROSSWORKGROUP);
101                     SpirvOps.VariableOp svop = new SpirvOps.VariableOp((String)jvop.attributes().get(""), resultType, jvop.varType());
102                     bodyBuilder.entryBlock().op(svop);
103                     valueMap.put(jvop.result(), svop.result());
104                     varOps.add(svop);
105                 }
106             }
107         }
108         for (int bi = 0; bi < body.blocks().size(); bi++)  {
109             Block block = body.blocks().get(bi);
110             spirvBlock = blockMap.get(block);
111             for (Op op : block.ops()) {
112                 switch (op) {
113                     case CoreOp.ReturnOp rop -> {
114                         spirvBlock.op(new SpirvOps.ReturnOp(rop.resultType(), mapOperands(rop)));
115                     }
116                     case CoreOp.VarOp vop -> {
117                         Value dest = valueMap.get(vop.result());
118                         Value value = valueMap.get(vop.operands().get(0));
119                         // init variable here; declaration has been moved to top of function
120                         spirvBlock.op(new SpirvOps.StoreOp(dest, value));
121                     }
122                     case CoreOp.VarAccessOp.VarLoadOp vlo -> {
123                         List<Value> operands = mapOperands(vlo);
124                         SpirvOps.LoadOp load = new SpirvOps.LoadOp(vlo.resultType(), operands);
125                         spirvBlock.op(load);
126                         valueMap.put(vlo.result(), load.result());
127                     }
128                     case CoreOp.VarAccessOp.VarStoreOp vso -> {
129                         Value dest = valueMap.get(vso.varOp().result());
130                         Value value = valueMap.get(vso.operands().get(1));
131                         spirvBlock.op(new SpirvOps.StoreOp(dest, value));
132                     }
133                     case CoreOp.ArrayAccessOp.ArrayLoadOp alo -> {
134                         Value array = valueMap.get(alo.operands().get(0));
135                         Value index = valueMap.get(alo.operands().get(1));
136                         TypeElement arrayType = array.type();
137                         SpirvOps.ConvertOp convert = new SpirvOps.ConvertOp(JavaType.type(long.class), List.of(index));
138                         spirvBlock.op(new SpirvOps.LoadOp(arrayType, List.of(array)));
139                         spirvBlock.op(convert);
140                         SpirvOp ibac = new SpirvOps.InBoundAccessChainOp(arrayType, List.of(array, convert.result()));
141                         spirvBlock.op(ibac);
142                         SpirvOp load = new SpirvOps.LoadOp(alo.resultType(), List.of(ibac.result()));
143                         spirvBlock.op(load);
144                         valueMap.put(alo.result(), load.result());
145                     }
146                     case CoreOp.ArrayAccessOp.ArrayStoreOp aso -> {
147                         Value array = valueMap.get(aso.operands().get(0));
148                         Value index = valueMap.get(aso.operands().get(1));
149                         TypeElement arrayType = array.type();
150                         SpirvOp ibac = new SpirvOps.InBoundAccessChainOp(arrayType, List.of(array, index));
151                         spirvBlock.op(ibac);
152                         spirvBlock.op(new SpirvOps.StoreOp(ibac.result(), valueMap.get(aso.operands().get(2))));
153                     }
154                     case CoreOp.ArrayLengthOp alo -> {
155                         Op len = new SpirvOps.ArrayLengthOp(JavaType.INT, List.of(valueMap.get(alo.operands().get(0))));
156                         spirvBlock.op(len);
157                         valueMap.put(alo.result(), len.result());
158                     }
159                     case CoreOp.AddOp aop -> {
160                         TypeElement type = aop.operands().get(0).type();
161                         List<Value> operands = mapOperands(aop);
162                         SpirvOp addOp;
163                         if (isIntegerType(type)) addOp = new SpirvOps.IAddOp(type, operands);
164                         else if (isFloatType(type)) addOp = new SpirvOps.FAddOp(type, operands);
165                         else throw unsupported("type", type);
166                         spirvBlock.op(addOp);
167                         valueMap.put(aop.result(), addOp.result());
168                      }
169                     case CoreOp.SubOp sop -> {
170                         TypeElement  type = sop.operands().get(0).type();
171                         List<Value> operands = mapOperands(sop);
172                         SpirvOp subOp;
173                         if (isIntegerType(type)) subOp = new SpirvOps.ISubOp(type, operands);
174                         else if (isFloatType(type)) subOp = new SpirvOps.FSubOp(type, operands);
175                         else throw unsupported("type", type);
176                         spirvBlock.op(subOp);
177                         valueMap.put(sop.result(), subOp.result());
178                      }
179                     case CoreOp.MulOp mop -> {
180                         TypeElement type = mop.operands().get(0).type();
181                         List<Value> operands = mapOperands(mop);
182                         SpirvOp mulOp;
183                         if (isIntegerType(type)) mulOp = new SpirvOps.IMulOp(type, operands);
184                         else if (isFloatType(type)) mulOp = new SpirvOps.FMulOp(type, operands);
185                         else throw unsupported("type", type);
186                         spirvBlock.op(mulOp);
187                         valueMap.put(mop.result(), mulOp.result());
188                     }
189                     case CoreOp.DivOp dop -> {
190                         TypeElement type = dop.operands().get(0).type();
191                         List<Value> operands = mapOperands(dop);
192                         SpirvOp divOp;
193                         if (isIntegerType(type)) divOp = new SpirvOps.IDivOp(type, operands);
194                         else if (isFloatType(type)) divOp = new SpirvOps.FDivOp(type, operands);
195                         else throw unsupported("type", type);
196                         spirvBlock.op(divOp);
197                         valueMap.put(dop.result(), divOp.result());
198                     }
199                     case CoreOp.ModOp mop -> {
200                         TypeElement type = mop.operands().get(0).type();
201                         List<Value> operands = mapOperands(mop);
202                         SpirvOp modOp = new SpirvOps.ModOp(type, operands);
203                         spirvBlock.op(modOp);
204                         valueMap.put(mop.result(), modOp.result());
205                     }
206                     case CoreOp.EqOp eqop -> {
207                         TypeElement type = eqop.operands().get(0).type();
208                         List<Value> operands = mapOperands(eqop);
209                         SpirvOp seqop;
210                         if (isIntegerType(type)) seqop = new SpirvOps.IEqualOp(type, operands);
211                         else if (isFloatType(type)) seqop = new SpirvOps.FEqualOp(type, operands);
212                         else throw unsupported("type", type);
213                         spirvBlock.op(seqop);
214                         valueMap.put(eqop.result(), seqop.result());
215                     }
216                     case CoreOp.NeqOp neqop -> {
217                         TypeElement type = neqop.operands().get(0).type();
218                         List<Value> operands = mapOperands(neqop);
219                         SpirvOp sneqop;
220                         if (isIntegerType(type)) sneqop = new SpirvOps.INotEqualOp(type, operands);
221                         else if (isFloatType(type)) sneqop = new SpirvOps.FNotEqualOp(type, operands);
222                         else throw unsupported("type", type);
223                         spirvBlock.op(sneqop);
224                         valueMap.put(neqop.result(), sneqop.result());
225                     }
226                     case CoreOp.LtOp ltop -> {
227                         TypeElement type = ltop.operands().get(0).type();
228                         List<Value> operands = mapOperands(ltop);
229                         SpirvOp sltop = new SpirvOps.LtOp(type, operands);
230                         spirvBlock.op(sltop);
231                         valueMap.put(ltop.result(), sltop.result());
232                     }
233                     case CoreOp.InvokeOp inv -> {
234                         List<Value> operands = mapOperands(inv);
235                         SpirvOp spirvCall = new SpirvOps.CallOp(inv.invokeDescriptor(), operands);
236                         spirvBlock.op(spirvCall);
237                         valueMap.put(inv.result(), spirvCall.result());
238                     }
239                     case CoreOp.ConstantOp cop -> {
240                         SpirvOp scop = new SpirvOps.ConstantOp(cop.resultType(), cop.value());
241                         spirvBlock.op(scop);
242                         valueMap.put(cop.result(), scop.result());
243                     }
244                     case CoreOp.ConvOp cop -> {
245                         List<Value> operands = mapOperands(cop);
246                         SpirvOp scop = new SpirvOps.ConvertOp(cop.resultType(), operands);
247                         spirvBlock.op(scop);
248                         valueMap.put(cop.result(), scop.result());
249                     }
250                     case CoreOp.FieldAccessOp.FieldLoadOp flo -> {
251                         SpirvOp load = new SpirvOps.FieldLoadOp(flo.resultType(), flo.fieldDescriptor(), mapOperands(flo));
252                         spirvBlock.op(load);
253                         valueMap.put(flo.result(), load.result());
254                     }
255                     case CoreOp.BranchOp bop -> {
256                         Block.Reference successor = blockMap.get(bop.branch().targetBlock()).successor();
257                         spirvBlock.op(new SpirvOps.BranchOp(successor));
258                     }
259                     case CoreOp.ConditionalBranchOp cbop -> {
260                         Block trueBlock = cbop.trueBranch().targetBlock();
261                         Block falseBlock = cbop.falseBranch().targetBlock();
262                         Block.Reference spvTrueBlock = blockMap.get(trueBlock).successor();
263                         Block.Reference spvFalseBlock = blockMap.get(falseBlock).successor();
264                         spirvBlock.op(new SpirvOps.ConditionalBranchOp(spvTrueBlock, spvFalseBlock, mapOperands(cbop)));
265                     }
266                     default -> unsupported("op", op.getClass());
267                 }
268             } //ops
269         } // blocks
270         return bodyBuilder;
271     }
272 
273     private RuntimeException unsupported(String message, Object value) {
274         return new RuntimeException("Unsupported " + message + ": " + value);
275     }
276 
277     private static CoreOp.FuncOp lowerMethod(CoreOp.FuncOp fop) {
278         CoreOp.FuncOp lfop = fop.transform((block, op) -> {
279             if (op instanceof Op.Lowerable lop)  {
280                 return lop.lower(block);
281             }
282             else {
283                 block.op(op);
284                 return block;
285             }
286         });
287         return lfop;
288     }
289 
290     private List<Value> mapOperands(Op op) {
291         List<Value> operands = new ArrayList<>();
292         for (Value javaValue : op.operands()) {
293             Value spirvValue = valueMap.get(javaValue);
294             assert spirvValue != null : "no value mapping from %s" + javaValue;
295             operands.add(spirvValue);
296         }
297         return operands;
298     }
299 
300     private boolean isIntegerType(TypeElement type) {
301         return type.equals(JavaType.INT) || type.equals(JavaType.LONG);
302     }
303 
304     private boolean isFloatType(TypeElement type) {
305         return type.equals(JavaType.FLOAT) || type.equals(JavaType.DOUBLE);
306     }
307 }