1 /* 2 * Copyright (c) 2024 Intel Corporation. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package intel.code.spirv; 27 28 import java.util.List; 29 import java.util.ArrayList; 30 import java.util.Map; 31 import java.util.HashMap; 32 import java.lang.reflect.code.Block; 33 import java.lang.reflect.code.Body; 34 import java.lang.reflect.code.op.CoreOp; 35 import java.lang.reflect.code.Op; 36 import java.lang.reflect.code.Value; 37 import java.lang.reflect.code.TypeElement; 38 import java.lang.reflect.code.type.JavaType; 39 40 public class TranslateToSpirvModel { 41 private Map<Block, Block.Builder> blockMap; // Java block to spirv block builder 42 private Map<Value, Value> valueMap; // Java model Value to Spirv model Value 43 44 public static SpirvOps.FuncOp translateFunction(CoreOp.FuncOp func) { 45 CoreOp.FuncOp lowFunc = lowerMethod(func); 46 TranslateToSpirvModel instance = new TranslateToSpirvModel(); 47 Body.Builder bodyBuilder = instance.translateBody(lowFunc.body(), lowFunc, null); 48 return new SpirvOps.FuncOp(lowFunc.funcName(), lowFunc.invokableType(), bodyBuilder); 49 } 50 51 private TranslateToSpirvModel() { 52 blockMap = new HashMap<>(); 53 valueMap = new HashMap<>(); 54 } 55 56 private Body.Builder translateBody(Body body, Op parentOp, Body.Builder parentBody) { 57 Body.Builder bodyBuilder = Body.Builder.of(parentBody, body.bodyType()); 58 Block.Builder spirvBlock = bodyBuilder.entryBlock(); 59 blockMap.put(body.entryBlock(), spirvBlock); 60 List<Block> blocks = body.blocks(); 61 // map Java blocks to spirv blocks 62 for (Block b : blocks.subList(1, blocks.size())) { 63 Block.Builder loweredBlock = spirvBlock.block(b.parameterTypes()); 64 blockMap.put(b, loweredBlock); 65 spirvBlock = loweredBlock; 66 } 67 // map entry block parameters to spirv function parameter 68 spirvBlock = bodyBuilder.entryBlock(); 69 List<SpirvOp> paramOps = new ArrayList<>(); 70 List<SpirvOps.VariableOp> varOps = new ArrayList<>(); 71 Block entryBlock = body.entryBlock(); 72 int paramCount = entryBlock.parameters().size(); 73 for (int i = 0; i < paramCount; i++) { 74 Block.Parameter bp = entryBlock.parameters().get(i); 75 assert entryBlock.ops().get(i) instanceof CoreOp.VarOp; 76 SpirvOp funcParam = new SpirvOps.FunctionParameterOp(bp.type(), List.of()); 77 spirvBlock.op(funcParam); 78 valueMap.put(bp, funcParam.result()); 79 paramOps.add(funcParam); 80 } 81 // SPIR-V Variable ops must be the first ops in a function's entry block and do not include initialization. 82 // Emit all SPIR-V Variable ops first and emit initializing stores afterward, at the CR model VarOp position. 83 for (int i = 0; i < paramCount; i++) { 84 CoreOp.VarOp jvop = (CoreOp.VarOp)entryBlock.ops().get(i); 85 TypeElement resultType = new PointerType(jvop.varType(), StorageType.CROSSWORKGROUP); 86 SpirvOps.VariableOp svop = new SpirvOps.VariableOp((String)jvop.attributes().get(""), resultType, jvop.varType()); 87 spirvBlock.op(svop); 88 valueMap.put(jvop.result(), svop.result()); 89 varOps.add(svop); 90 } 91 // add non-function-parameter variables 92 for (int bi = 0; bi < body.blocks().size(); bi++) { 93 Block block = body.blocks().get(bi); 94 spirvBlock = blockMap.get(block); 95 List<Op> ops = block.ops(); 96 for (int i = (bi == 0 ? paramCount : 0); i < ops.size(); i++) { 97 if (bi > 0) spirvBlock = blockMap.get(block); 98 Op op = ops.get(i); 99 if (op instanceof CoreOp.VarOp jvop) { 100 TypeElement resultType = new PointerType(jvop.varType(), StorageType.CROSSWORKGROUP); 101 SpirvOps.VariableOp svop = new SpirvOps.VariableOp((String)jvop.attributes().get(""), resultType, jvop.varType()); 102 bodyBuilder.entryBlock().op(svop); 103 valueMap.put(jvop.result(), svop.result()); 104 varOps.add(svop); 105 } 106 } 107 } 108 for (int bi = 0; bi < body.blocks().size(); bi++) { 109 Block block = body.blocks().get(bi); 110 spirvBlock = blockMap.get(block); 111 for (Op op : block.ops()) { 112 switch (op) { 113 case CoreOp.ReturnOp rop -> { 114 spirvBlock.op(new SpirvOps.ReturnOp(rop.resultType(), mapOperands(rop))); 115 } 116 case CoreOp.VarOp vop -> { 117 Value dest = valueMap.get(vop.result()); 118 Value value = valueMap.get(vop.operands().get(0)); 119 // init variable here; declaration has been moved to top of function 120 spirvBlock.op(new SpirvOps.StoreOp(dest, value)); 121 } 122 case CoreOp.VarAccessOp.VarLoadOp vlo -> { 123 List<Value> operands = mapOperands(vlo); 124 SpirvOps.LoadOp load = new SpirvOps.LoadOp(vlo.resultType(), operands); 125 spirvBlock.op(load); 126 valueMap.put(vlo.result(), load.result()); 127 } 128 case CoreOp.VarAccessOp.VarStoreOp vso -> { 129 Value dest = valueMap.get(vso.varOp().result()); 130 Value value = valueMap.get(vso.operands().get(1)); 131 spirvBlock.op(new SpirvOps.StoreOp(dest, value)); 132 } 133 case CoreOp.ArrayAccessOp.ArrayLoadOp alo -> { 134 Value array = valueMap.get(alo.operands().get(0)); 135 Value index = valueMap.get(alo.operands().get(1)); 136 TypeElement arrayType = array.type(); 137 SpirvOps.ConvertOp convert = new SpirvOps.ConvertOp(JavaType.type(long.class), List.of(index)); 138 spirvBlock.op(new SpirvOps.LoadOp(arrayType, List.of(array))); 139 spirvBlock.op(convert); 140 SpirvOp ibac = new SpirvOps.InBoundAccessChainOp(arrayType, List.of(array, convert.result())); 141 spirvBlock.op(ibac); 142 SpirvOp load = new SpirvOps.LoadOp(alo.resultType(), List.of(ibac.result())); 143 spirvBlock.op(load); 144 valueMap.put(alo.result(), load.result()); 145 } 146 case CoreOp.ArrayAccessOp.ArrayStoreOp aso -> { 147 Value array = valueMap.get(aso.operands().get(0)); 148 Value index = valueMap.get(aso.operands().get(1)); 149 TypeElement arrayType = array.type(); 150 SpirvOp ibac = new SpirvOps.InBoundAccessChainOp(arrayType, List.of(array, index)); 151 spirvBlock.op(ibac); 152 spirvBlock.op(new SpirvOps.StoreOp(ibac.result(), valueMap.get(aso.operands().get(2)))); 153 } 154 case CoreOp.ArrayLengthOp alo -> { 155 Op len = new SpirvOps.ArrayLengthOp(JavaType.INT, List.of(valueMap.get(alo.operands().get(0)))); 156 spirvBlock.op(len); 157 valueMap.put(alo.result(), len.result()); 158 } 159 case CoreOp.AddOp aop -> { 160 TypeElement type = aop.operands().get(0).type(); 161 List<Value> operands = mapOperands(aop); 162 SpirvOp addOp; 163 if (isIntegerType(type)) addOp = new SpirvOps.IAddOp(type, operands); 164 else if (isFloatType(type)) addOp = new SpirvOps.FAddOp(type, operands); 165 else throw unsupported("type", type); 166 spirvBlock.op(addOp); 167 valueMap.put(aop.result(), addOp.result()); 168 } 169 case CoreOp.SubOp sop -> { 170 TypeElement type = sop.operands().get(0).type(); 171 List<Value> operands = mapOperands(sop); 172 SpirvOp subOp; 173 if (isIntegerType(type)) subOp = new SpirvOps.ISubOp(type, operands); 174 else if (isFloatType(type)) subOp = new SpirvOps.FSubOp(type, operands); 175 else throw unsupported("type", type); 176 spirvBlock.op(subOp); 177 valueMap.put(sop.result(), subOp.result()); 178 } 179 case CoreOp.MulOp mop -> { 180 TypeElement type = mop.operands().get(0).type(); 181 List<Value> operands = mapOperands(mop); 182 SpirvOp mulOp; 183 if (isIntegerType(type)) mulOp = new SpirvOps.IMulOp(type, operands); 184 else if (isFloatType(type)) mulOp = new SpirvOps.FMulOp(type, operands); 185 else throw unsupported("type", type); 186 spirvBlock.op(mulOp); 187 valueMap.put(mop.result(), mulOp.result()); 188 } 189 case CoreOp.DivOp dop -> { 190 TypeElement type = dop.operands().get(0).type(); 191 List<Value> operands = mapOperands(dop); 192 SpirvOp divOp; 193 if (isIntegerType(type)) divOp = new SpirvOps.IDivOp(type, operands); 194 else if (isFloatType(type)) divOp = new SpirvOps.FDivOp(type, operands); 195 else throw unsupported("type", type); 196 spirvBlock.op(divOp); 197 valueMap.put(dop.result(), divOp.result()); 198 } 199 case CoreOp.ModOp mop -> { 200 TypeElement type = mop.operands().get(0).type(); 201 List<Value> operands = mapOperands(mop); 202 SpirvOp modOp = new SpirvOps.ModOp(type, operands); 203 spirvBlock.op(modOp); 204 valueMap.put(mop.result(), modOp.result()); 205 } 206 case CoreOp.EqOp eqop -> { 207 TypeElement type = eqop.operands().get(0).type(); 208 List<Value> operands = mapOperands(eqop); 209 SpirvOp seqop; 210 if (isIntegerType(type)) seqop = new SpirvOps.IEqualOp(type, operands); 211 else if (isFloatType(type)) seqop = new SpirvOps.FEqualOp(type, operands); 212 else throw unsupported("type", type); 213 spirvBlock.op(seqop); 214 valueMap.put(eqop.result(), seqop.result()); 215 } 216 case CoreOp.NeqOp neqop -> { 217 TypeElement type = neqop.operands().get(0).type(); 218 List<Value> operands = mapOperands(neqop); 219 SpirvOp sneqop; 220 if (isIntegerType(type)) sneqop = new SpirvOps.INotEqualOp(type, operands); 221 else if (isFloatType(type)) sneqop = new SpirvOps.FNotEqualOp(type, operands); 222 else throw unsupported("type", type); 223 spirvBlock.op(sneqop); 224 valueMap.put(neqop.result(), sneqop.result()); 225 } 226 case CoreOp.LtOp ltop -> { 227 TypeElement type = ltop.operands().get(0).type(); 228 List<Value> operands = mapOperands(ltop); 229 SpirvOp sltop = new SpirvOps.LtOp(type, operands); 230 spirvBlock.op(sltop); 231 valueMap.put(ltop.result(), sltop.result()); 232 } 233 case CoreOp.InvokeOp inv -> { 234 List<Value> operands = mapOperands(inv); 235 SpirvOp spirvCall = new SpirvOps.CallOp(inv.invokeDescriptor(), operands); 236 spirvBlock.op(spirvCall); 237 valueMap.put(inv.result(), spirvCall.result()); 238 } 239 case CoreOp.ConstantOp cop -> { 240 SpirvOp scop = new SpirvOps.ConstantOp(cop.resultType(), cop.value()); 241 spirvBlock.op(scop); 242 valueMap.put(cop.result(), scop.result()); 243 } 244 case CoreOp.ConvOp cop -> { 245 List<Value> operands = mapOperands(cop); 246 SpirvOp scop = new SpirvOps.ConvertOp(cop.resultType(), operands); 247 spirvBlock.op(scop); 248 valueMap.put(cop.result(), scop.result()); 249 } 250 case CoreOp.FieldAccessOp.FieldLoadOp flo -> { 251 SpirvOp load = new SpirvOps.FieldLoadOp(flo.resultType(), flo.fieldDescriptor(), mapOperands(flo)); 252 spirvBlock.op(load); 253 valueMap.put(flo.result(), load.result()); 254 } 255 case CoreOp.BranchOp bop -> { 256 Block.Reference successor = blockMap.get(bop.branch().targetBlock()).successor(); 257 spirvBlock.op(new SpirvOps.BranchOp(successor)); 258 } 259 case CoreOp.ConditionalBranchOp cbop -> { 260 Block trueBlock = cbop.trueBranch().targetBlock(); 261 Block falseBlock = cbop.falseBranch().targetBlock(); 262 Block.Reference spvTrueBlock = blockMap.get(trueBlock).successor(); 263 Block.Reference spvFalseBlock = blockMap.get(falseBlock).successor(); 264 spirvBlock.op(new SpirvOps.ConditionalBranchOp(spvTrueBlock, spvFalseBlock, mapOperands(cbop))); 265 } 266 default -> unsupported("op", op.getClass()); 267 } 268 } //ops 269 } // blocks 270 return bodyBuilder; 271 } 272 273 private RuntimeException unsupported(String message, Object value) { 274 return new RuntimeException("Unsupported " + message + ": " + value); 275 } 276 277 private static CoreOp.FuncOp lowerMethod(CoreOp.FuncOp fop) { 278 CoreOp.FuncOp lfop = fop.transform((block, op) -> { 279 if (op instanceof Op.Lowerable lop) { 280 return lop.lower(block); 281 } 282 else { 283 block.op(op); 284 return block; 285 } 286 }); 287 return lfop; 288 } 289 290 private List<Value> mapOperands(Op op) { 291 List<Value> operands = new ArrayList<>(); 292 for (Value javaValue : op.operands()) { 293 Value spirvValue = valueMap.get(javaValue); 294 assert spirvValue != null : "no value mapping from %s" + javaValue; 295 operands.add(spirvValue); 296 } 297 return operands; 298 } 299 300 private boolean isIntegerType(TypeElement type) { 301 return type.equals(JavaType.INT) || type.equals(JavaType.LONG); 302 } 303 304 private boolean isFloatType(TypeElement type) { 305 return type.equals(JavaType.FLOAT) || type.equals(JavaType.DOUBLE); 306 } 307 }