1 /* 2 * Copyright (c) 2024 Intel Corporation. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package intel.code.spirv; 27 28 import java.util.List; 29 import java.util.ArrayList; 30 import java.util.Arrays; 31 import java.util.Map; 32 import java.util.HashMap; 33 import java.util.Set; 34 import java.util.HashSet; 35 import java.util.Optional; 36 import java.util.function.Function; 37 import java.io.IOException; 38 import java.io.File; 39 import java.io.FileOutputStream; 40 import java.io.ByteArrayInputStream; 41 import java.io.ByteArrayOutputStream; 42 import java.io.PrintStream; 43 import java.nio.ByteBuffer; 44 import java.nio.ByteOrder; 45 import java.nio.channels.FileChannel; 46 import java.math.BigInteger; 47 import java.lang.invoke.MethodHandles; 48 import java.lang.foreign.MemorySegment; 49 import java.lang.foreign.ValueLayout; 50 import jdk.incubator.code.Block; 51 import jdk.incubator.code.Body; 52 import jdk.incubator.code.Op; 53 import jdk.incubator.code.Value; 54 import jdk.incubator.code.op.CoreOp; 55 import jdk.incubator.code.TypeElement; 56 import jdk.incubator.code.type.MethodRef; 57 import jdk.incubator.code.type.ClassType; 58 import jdk.incubator.code.type.JavaType; 59 import jdk.incubator.code.type.FunctionType; 60 import hat.callgraph.CallGraph; 61 import hat.callgraph.KernelCallGraph; 62 import hat.callgraph.KernelEntrypoint; 63 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVHeader; 64 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVModule; 65 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVFunction; 66 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVBlock; 67 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.*; 68 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.*; 69 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler; 70 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions; 71 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVByteStreamReader; 72 import intel.code.spirv.SpirvOp.PhiOp; 73 74 public class SpirvModuleGenerator { 75 private final String moduleName; 76 private final SPIRVModule module; 77 private final Symbols symbols; 78 79 public static SpirvModuleGenerator create(String moduleName) { 80 return new SpirvModuleGenerator(moduleName); 81 } 82 83 public static MemorySegment generateModule(String moduleName, KernelCallGraph callGraph) { 84 SpirvModuleGenerator generator = SpirvModuleGenerator.create(moduleName); 85 for (CallGraph.MethodCall call : callGraph.calls) { 86 if (call.targetMethodRef != null) { 87 try { 88 Optional<CoreOp.FuncOp> ofo = call.targetMethodRef.codeModel(MethodHandles.lookup()); 89 if (ofo.isPresent()) { 90 CoreOp.FuncOp fo = ofo.get(); 91 SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(fo); 92 } 93 } catch (Exception e) { 94 throw new RuntimeException(e); 95 } 96 } 97 } 98 KernelEntrypoint kernelEntrypoint = callGraph.entrypoint; 99 CoreOp.FuncOp funcOp = kernelEntrypoint.funcOpWrapper().op(); 100 String kernelName = funcOp.funcName(); 101 SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(funcOp); 102 generator.generateFunction(funcOp.funcName(), spirvFunc, true); 103 return generator.finalizeModule(); 104 } 105 106 public static MemorySegment generateModule(String moduleName, CoreOp.FuncOp func) { 107 SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(func); 108 MemorySegment moduleSegment = SpirvModuleGenerator.generateModule(moduleName, spirvFunc); 109 return moduleSegment; 110 } 111 112 public static MemorySegment generateModule(String moduleName, SpirvOp.FuncOp func) { 113 SpirvModuleGenerator generator = new SpirvModuleGenerator(moduleName); 114 MemorySegment moduleSegment = generator.generateModuleInternal(func); 115 return moduleSegment; 116 } 117 118 public static void writeModuleToFile(MemorySegment module, String filepath) { 119 ByteBuffer buffer = module.asByteBuffer(); 120 File out = new File(filepath); 121 try (FileChannel channel = new FileOutputStream(out, false).getChannel()) { 122 channel.write(buffer); 123 channel.close(); 124 } 125 catch (IOException e) { 126 throw new RuntimeException(e); 127 } 128 } 129 130 private static void writeModuleToFile(SPIRVModule module, String filepath) 131 { 132 ByteBuffer buffer = ByteBuffer.allocate(module.getByteCount()); 133 buffer.order(ByteOrder.LITTLE_ENDIAN); 134 module.close().write(buffer); 135 buffer.flip(); 136 File out = new File(filepath); 137 try (FileChannel channel = new FileOutputStream(out, false).getChannel()) 138 { 139 channel.write(buffer); 140 } 141 catch (IOException e) 142 { 143 throw new RuntimeException(e); 144 } 145 } 146 147 public static String disassembleModule(MemorySegment module) { 148 SPVByteStreamReader input = new SPVByteStreamReader(new ByteArrayInputStream(module.toArray(ValueLayout.JAVA_BYTE))); 149 ByteArrayOutputStream out = new ByteArrayOutputStream(); 150 try (PrintStream ps = new PrintStream(out)) { 151 SPIRVDisassemblerOptions options = new SPIRVDisassemblerOptions(false, false, false, false, true); 152 Disassembler dis = new Disassembler(input, ps, options); 153 dis.run(); 154 } 155 catch (Exception e) { 156 throw new RuntimeException(e); 157 } 158 return new String(out.toByteArray()); 159 } 160 161 public MemorySegment finalizeModule() { 162 ByteBuffer buffer = ByteBuffer.allocateDirect(module.getByteCount()); 163 buffer.order(ByteOrder.LITTLE_ENDIAN); 164 module.close().write(buffer); 165 buffer.flip(); 166 return MemorySegment.ofBuffer(buffer); 167 } 168 169 private record SpirvResult(SPIRVId type, SPIRVId address, SPIRVId value) {} 170 171 private SpirvModuleGenerator(String moduleName) { 172 this.moduleName = moduleName; 173 this.module = new SPIRVModule(new SPIRVHeader(1, 2, 32, 0, 0)); 174 this.symbols = new Symbols(); 175 initModule(); 176 } 177 178 private MemorySegment generateModuleInternal(SpirvOp.FuncOp func) { 179 generateFunction(moduleName, func, true); 180 return finalizeModule(); 181 } 182 183 private SPIRVId generateFunction(String fnName, SpirvOp.FuncOp func, boolean isEntryPoint) { 184 TypeElement returnType = func.invokableType().returnType(); 185 SPIRVId functionId = nextId(fnName); 186 String signature = func.invokableType().returnType().toString(); 187 List<TypeElement> paramTypes = func.invokableType().parameterTypes(); 188 // build signature string 189 for (int i = 0; i < paramTypes.size(); i++) { 190 signature += "_" + paramTypes.get(i).toString(); 191 } 192 // declare function type if not already present 193 SPIRVId functionSig = getIdOrNull(signature); 194 if (functionSig == null) { 195 SPIRVId[] typeIdsArray = new SPIRVId[paramTypes.size()]; 196 for (int i = 0; i < paramTypes.size(); i++) { 197 typeIdsArray[i] = spirvType(paramTypes.get(i).toString()); 198 } 199 functionSig = nextId(fnName + "Signature"); 200 module.add(new SPIRVOpTypeFunction(functionSig, spirvType(returnType.toString()), new SPIRVMultipleOperands<>(typeIdsArray))); 201 addId(signature, functionSig); 202 } 203 SPIRVId spirvReturnType = spirvType(returnType.toString()); 204 SPIRVFunction function = (SPIRVFunction)module.add(new SPIRVOpFunction(spirvReturnType, functionId, SPIRVFunctionControl.DontInline(), functionSig)); 205 SPIRVOpLabel entryLabel = new SPIRVOpLabel(nextId()); 206 SPIRVBlock entryBlock = (SPIRVBlock)function.add(entryLabel); 207 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(getId("globalInvocationId"), getId("globalSize"), getId("subgroupSize"), getId("subgroupId")); 208 if (isEntryPoint) { 209 module.add(new SPIRVOpEntryPoint(SPIRVExecutionModel.Kernel(), functionId, new SPIRVLiteralString(fnName), operands)); 210 } 211 translateBody(func.body(), function, entryBlock); 212 function.add(new SPIRVOpFunctionEnd()); 213 return functionId; 214 } 215 216 private void translateBody(Body body, SPIRVFunction function, SPIRVBlock entryBlock) { 217 int labelNumber = 0; 218 SPIRVBlock spirvBlock = entryBlock; 219 for (int bi = 1; bi < body.blocks().size(); bi++) { 220 Block block = body.blocks().get(bi); 221 SPIRVOpLabel blockLabel = new SPIRVOpLabel(nextId()); 222 SPIRVBlock newBlock = (SPIRVBlock)function.add(blockLabel); 223 symbols.putBlock(block, newBlock); 224 symbols.putLabel(block, blockLabel); 225 for (int i = 0; i < block.parameters().size(); i++) { 226 Value param = block.parameters().get(i); 227 SPIRVId paramId = nextId(); 228 addResult(param, new SpirvResult(spirvType(param.type().toString()), null, paramId)); 229 } 230 } 231 for (int bi = 0; bi < body.blocks().size(); bi++) { 232 Block block = body.blocks().get(bi); 233 if (bi > 0) { 234 spirvBlock = symbols.getBlock(block); 235 } 236 for (Op op : block.ops()) { 237 switch (op) { 238 case SpirvOp.PhiOp phop -> { 239 List<PhiOp.Predecessor> inPredecessors = phop.predecessors(); 240 SPIRVPairIdRefIdRef[] outPredecessors = new SPIRVPairIdRefIdRef[inPredecessors.size()]; 241 for (int i = 0; i < inPredecessors.size(); i++) { 242 PhiOp.Predecessor predecessor = inPredecessors.get(i); 243 SPIRVId label = symbols.getLabel(predecessor.block().targetBlock()).getResultId(); 244 SPIRVId value = getResult(predecessor.value()).value(); 245 outPredecessors[i] = new SPIRVPairIdRefIdRef(value, label); 246 } 247 SPIRVId result = nextId(); 248 SPIRVId type = spirvType(phop.resultType().toString()); 249 SPIRVOpPhi phiOp = new SPIRVOpPhi(spirvType(phop.resultType().toString()), result, new SPIRVMultipleOperands<>(outPredecessors)); 250 spirvBlock.add(phiOp); 251 addResult(phop.result(), new SpirvResult(type, null, result)); 252 } 253 case SpirvOp.VariableOp vop -> { 254 String typeName = vop.varType().toString(); 255 SPIRVId type = spirvType(typeName); 256 SPIRVId varType = spirvVariableType(type); 257 SPIRVId var = nextId(vop.varName()); 258 spirvBlock.add(new SPIRVOpVariable(varType, var, SPIRVStorageClass.Function(), new SPIRVOptionalOperand<>())); 259 addResult(vop.result(), new SpirvResult(varType, var, null)); 260 } 261 case SpirvOp.FunctionParameterOp fpo -> { 262 SPIRVId result = nextId(); 263 SPIRVId type = spirvType(fpo.resultType().toString()); 264 function.add(new SPIRVOpFunctionParameter(type, result)); 265 module.add(new SPIRVOpDecorate(result, SPIRVDecoration.Alignment(new SPIRVLiteralInteger(8)))); 266 addResult(fpo.result(), new SpirvResult(type, null, result)); 267 } 268 case SpirvOp.LoadOp lo -> { 269 SPIRVId type = spirvType(lo.resultType().toString()); 270 SpirvResult toLoad = getResult(lo.operands().get(0)); 271 SPIRVId varAddr = toLoad.address(); 272 SPIRVId result = nextId(); 273 spirvBlock.add(new SPIRVOpLoad(type, result, varAddr, align(type.getName()))); 274 addResult(lo.result(), new SpirvResult(type, varAddr, result)); 275 } 276 case SpirvOp.StoreOp so -> { 277 SpirvResult var = getResult(so.operands().get(0)); 278 SPIRVId varAddr = var.address(); 279 SPIRVId value = getResult(so.operands().get(1)).value(); 280 spirvBlock.add(new SPIRVOpStore(varAddr, value, align(var.type().getName()))); 281 } 282 case SpirvOp.IAddOp _, SpirvOp.FAddOp _ -> { 283 SPIRVId intType = getType("int"); 284 SPIRVId longType = getType("long"); 285 SPIRVId floatType = getType("float"); 286 SPIRVId lhs = getResult(op.operands().get(0)).value(); 287 SPIRVId rhs = getResult(op.operands().get(1)).value(); 288 SPIRVId lhsType = spirvType(op.resultType().toString()); 289 SPIRVId ans = nextId(); 290 if (lhsType == intType) spirvBlock.add(new SPIRVOpIAdd(intType, ans, lhs, rhs)); 291 else if (lhsType == longType) spirvBlock.add(new SPIRVOpIAdd(longType, ans, lhs, rhs)); 292 else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFAdd(floatType, ans, lhs, rhs)); 293 else unsupported("type", lhsType.getName()); 294 addResult(op.result(), new SpirvResult(lhsType, null, ans)); 295 } 296 case SpirvOp.ISubOp _, SpirvOp.FSubOp _ -> { 297 SPIRVId intType = getType("int"); 298 SPIRVId longType = getType("long"); 299 SPIRVId floatType = getType("float"); 300 SPIRVId lhs = getResult(op.operands().get(0)).value(); 301 SPIRVId rhs = getResult(op.operands().get(1)).value(); 302 SPIRVId lhsType = spirvType(op.resultType().toString()); 303 SPIRVId ans = nextId(); 304 if (lhsType == intType) spirvBlock.add(new SPIRVOpISub(intType, ans, lhs, rhs)); 305 else if (lhsType == longType) spirvBlock.add(new SPIRVOpISub(longType, ans, lhs, rhs)); 306 else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFSub(floatType, ans, lhs, rhs)); 307 else unsupported("type", lhsType.getName()); 308 addResult(op.result(), new SpirvResult(lhsType, null, ans)); 309 } 310 case SpirvOp.IMulOp _, SpirvOp.FMulOp _, SpirvOp.IDivOp _, SpirvOp.FDivOp _ -> { 311 SPIRVId intType = getType("int"); 312 SPIRVId longType = getType("long"); 313 SPIRVId floatType = getType("float"); 314 SPIRVId lhs = getResult(op.operands().get(0)).value(); 315 SPIRVId rhs = getResult(op.operands().get(1)).value(); 316 SPIRVId lhsType = spirvType(op.resultType().toString()); 317 SPIRVId rhsType = getResult(op.operands().get(1)).type(); 318 SPIRVId ans = nextId(); 319 if (lhsType == intType) { 320 if (op instanceof SpirvOp.IMulOp) spirvBlock.add(new SPIRVOpIMul(intType, ans, lhs, rhs)); 321 else if (op instanceof SpirvOp.IDivOp) spirvBlock.add(new SPIRVOpSDiv(intType, ans, lhs, rhs)); 322 } 323 else if (lhsType == longType) { 324 SPIRVId rhsId = rhsType == intType ? nextId() : rhs; 325 if (rhsType == intType) spirvBlock.add(new SPIRVOpSConvert(longType, rhsId, rhs)); 326 if (op instanceof SpirvOp.IMulOp) spirvBlock.add(new SPIRVOpIMul(longType, ans, lhs, rhsId)); 327 else if (op instanceof SpirvOp.IDivOp) spirvBlock.add(new SPIRVOpSDiv(longType, ans, lhs, rhs)); 328 } 329 else if (lhsType == floatType) { 330 if (op instanceof SpirvOp.FMulOp) spirvBlock.add(new SPIRVOpFMul(floatType, ans, lhs, rhs)); 331 else if (op instanceof SpirvOp.FDivOp) spirvBlock.add(new SPIRVOpFDiv(floatType, ans, lhs, rhs)); 332 } 333 else unsupported("type", lhsType); 334 addResult(op.result(), new SpirvResult(lhsType, null, ans)); 335 } 336 case SpirvOp.ModOp mop -> { 337 SPIRVId type = getType(mop.operands().get(0).type().toString()); 338 SPIRVId lhs = getResult(mop.operands().get(0)).value(); 339 SPIRVId rhs = getResult(mop.operands().get(1)).value(); 340 SPIRVId result = nextId(); 341 spirvBlock.add(new SPIRVOpUMod(type, result, lhs, rhs)); 342 addResult(mop.result(), new SpirvResult(type, null, result)); 343 } 344 case SpirvOp.IEqualOp eqop -> { 345 SPIRVId boolType = getType("bool"); 346 SPIRVId intType = getType("int"); 347 SPIRVId longType = getType("long"); 348 SPIRVId floatType = getType("float"); 349 SPIRVId lhs = getResult(op.operands().get(0)).value(); 350 SPIRVId rhs = getResult(op.operands().get(1)).value(); 351 SPIRVId lhsType = spirvType(op.resultType().toString()); 352 SPIRVId ans = nextId(); 353 if (lhsType == intType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs)); 354 else if (lhsType == longType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs)); 355 else unsupported("type", lhsType.getName()); 356 addResult(op.result(), new SpirvResult(lhsType, null, ans)); 357 } 358 case SpirvOp.GeOp eqop -> { 359 SPIRVId boolType = getType("bool"); 360 SPIRVId lhs = getResult(op.operands().get(0)).value(); 361 SPIRVId rhs = getResult(op.operands().get(1)).value(); 362 SPIRVId lhsType = spirvType(op.resultType().toString()); 363 SPIRVId ans = nextId(); 364 spirvBlock.add(new SPIRVOpSGreaterThanEqual(boolType, ans, lhs, rhs)); 365 addResult(op.result(), new SpirvResult(lhsType, null, ans)); 366 } 367 case SpirvOp.PtrNotEqualOp neqop -> { 368 SPIRVId boolType = getType("bool"); 369 SPIRVId longType = getType("long"); 370 SPIRVId lhs = getResult(neqop.operands().get(0)).value(); 371 SPIRVId rhs = getResult(neqop.operands().get(1)).value(); 372 SPIRVId ans = nextId(); 373 SPIRVId lhsLong = nextId(); 374 SPIRVId rhsLong = nextId(); 375 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, lhsLong, lhs)); 376 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, rhsLong, rhs)); 377 spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhsLong, rhsLong)); 378 addResult(op.result(), new SpirvResult(boolType, null, ans)); 379 } 380 case SpirvOp.CallOp call -> { 381 MethodRef methodRef = call.callDescriptor(); 382 if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array::array(long)int"))) 383 { 384 SPIRVId longType = getType("long"); 385 String arrayTypeName = call.operands().get(0).type().toString(); 386 SpirvResult arrayResult = getResult(call.operands().get(0)); 387 SPIRVId arrayAddr = arrayResult.address(); 388 SPIRVId arrayType = spirvType(arrayTypeName); 389 SPIRVId elementType = spirvElementType(arrayTypeName); 390 int nIndexes = call.operands().size() - 1; 391 SPIRVId indexX = getResult(call.operands().get(1)).value(); 392 SPIRVId array = nextId(); 393 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName()))); 394 SPIRVId temp1 = nextId(); 395 SPIRVId temp2 = nextId(); 396 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array)); 397 spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4))); 398 SPIRVId elementBase = nextId(); 399 spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2)); 400 SPIRVId resultAddr = nextId(); 401 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, elementBase, indexX, new SPIRVMultipleOperands<>())); 402 SPIRVId result = nextId(); 403 spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName()))); 404 addResult(call.result(), new SpirvResult(elementType, resultAddr, result)); 405 } 406 else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array2D::array(long)int")) || 407 methodRef.equals(MethodRef.ofString("hat.buffer.F32Array2D::array(long)float"))) 408 { 409 SPIRVId longType = getType("long"); 410 String arrayTypeName = call.operands().get(0).type().toString(); 411 SpirvResult arrayResult = getResult(call.operands().get(0)); 412 SPIRVId arrayAddr = arrayResult.address(); 413 SPIRVId arrayType = spirvType(arrayTypeName); 414 SPIRVId elementType = spirvElementType(arrayTypeName); 415 int nIndexes = call.operands().size() - 1; 416 SPIRVId indexX = getResult(call.operands().get(1)).value(); 417 SPIRVId array = nextId(); 418 SPIRVId temp1 = nextId(); 419 SPIRVId temp2 = nextId(); 420 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array)); 421 spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4))); 422 SPIRVId elementBase = nextId(); 423 spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2)); 424 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName()))); 425 SPIRVId resultAddr = nextId(); 426 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, elementBase, indexX, new SPIRVMultipleOperands<>())); 427 SPIRVId result = nextId(); 428 spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName()))); 429 addResult(call.result(), new SpirvResult(elementType, resultAddr, result)); 430 } 431 else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array::array(long, int)void"))) { 432 SPIRVId longType = getType("long"); 433 String arrayTypeName = call.operands().get(0).type().toString(); 434 SpirvResult arrayResult = getResult(call.operands().get(0)); 435 SPIRVId arrayAddr = arrayResult.address(); 436 SPIRVId arrayType = spirvType(arrayTypeName); 437 SPIRVId elementType = spirvElementType(arrayTypeName); 438 int nIndexes = call.operands().size() - 2; 439 int valueIndex = nIndexes + 1; 440 SPIRVId indexX = getResult(call.operands().get(1)).value(); 441 SPIRVId array = nextId(); 442 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName()))); 443 SPIRVId temp1 = nextId(); 444 SPIRVId temp2 = nextId(); 445 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array)); 446 spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4))); 447 SPIRVId elementBase = nextId(); 448 spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2)); 449 SPIRVId dest = nextId(); 450 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, elementBase, indexX, new SPIRVMultipleOperands<>())); 451 SPIRVId value = getResult(call.operands().get(valueIndex)).value(); 452 spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName()))); 453 } 454 else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array2D::array(long, int)void")) || 455 methodRef.equals(MethodRef.ofString("hat.buffer.F32Array2D::array(long, float)void"))) { 456 SPIRVId longType = getType("long"); 457 String arrayTypeName = call.operands().get(0).type().toString(); 458 SpirvResult arrayResult = getResult(call.operands().get(0)); 459 SPIRVId arrayAddr = arrayResult.address(); 460 SPIRVId arrayType = spirvType(arrayTypeName); 461 SPIRVId elementType = spirvElementType(arrayTypeName); 462 int nIndexes = call.operands().size() - 2; 463 int valueIndex = nIndexes + 1; 464 SPIRVId indexX = getResult(call.operands().get(1)).value(); 465 SPIRVId array = nextId(); 466 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName()))); 467 SPIRVId temp1 = nextId(); 468 SPIRVId temp2 = nextId(); 469 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array)); 470 spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 8))); 471 SPIRVId elementBase = nextId(); 472 spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2)); 473 SPIRVId dest = nextId(); 474 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, elementBase, indexX, new SPIRVMultipleOperands<>())); 475 SPIRVId value = getResult(call.operands().get(valueIndex)).value(); 476 spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName()))); 477 } 478 else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array::length()int")) || 479 methodRef.equals(MethodRef.ofString("hat.buffer.S32Array2D::width()int"))|| 480 methodRef.equals(MethodRef.ofString("hat.buffer.F32Array2D::width()int"))) { 481 SPIRVId intType = getType("int"); 482 String arrayTypeName = call.operands().get(0).type().toString(); 483 SpirvResult arrayResult = getResult(call.operands().get(0)); 484 SPIRVId arrayAddr = arrayResult.address(); 485 SPIRVId arrayType = spirvType(arrayTypeName); 486 SPIRVId array = nextId(); 487 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName()))); 488 SPIRVId resultAddr = nextId(); 489 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, getConst("int", 0), new SPIRVMultipleOperands<>())); 490 SPIRVId result = nextId(); 491 spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align(arrayType.getName()))); 492 addResult(call.result(), new SpirvResult(intType, resultAddr, result)); 493 } 494 else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array2D::height()int")) || 495 methodRef.equals(MethodRef.ofString("hat.buffer.F32Array2D::height()int"))) { 496 SPIRVId intType = getType("int"); 497 String arrayTypeName = call.operands().get(0).type().toString(); 498 SpirvResult arrayResult = getResult(call.operands().get(0)); 499 SPIRVId arrayAddr = arrayResult.address(); 500 SPIRVId arrayType = spirvType(arrayTypeName); 501 SPIRVId array = nextId(); 502 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName()))); 503 SPIRVId resultAddr = nextId(); 504 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, getConst("int", 1), new SPIRVMultipleOperands<>())); 505 SPIRVId result = nextId(); 506 spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align(arrayType.getName()))); 507 addResult(call.result(), new SpirvResult(intType, resultAddr, result)); 508 } 509 else if (methodRef.equals(MethodRef.ofString("java.lang.Math::sqrt(double)double"))) { 510 SPIRVId floatType = getType("float"); 511 SPIRVId result = nextId(); 512 SPIRVId operand = getResult(call.operands().get(0)).value(); 513 spirvBlock.add(new SPIRVOpExtInst(floatType, result, getId("oclExtension"), new SPIRVLiteralExtInstInteger(61, "sqrt"), new SPIRVMultipleOperands<>(operand))); 514 addResult(call.result(), new SpirvResult(floatType, null, result)); 515 } 516 else { 517 SPIRVId fnId = getFunctionId(methodRef); 518 if (fnId == null) { 519 unsupported("method", methodRef); 520 } 521 else { 522 FunctionType fnType = methodRef.type(); 523 SPIRVId[] args = new SPIRVId[call.operands().size()]; 524 for (int i = 0; i < args.length; i++) { 525 SPIRVId argId = getResult(call.operands().get(i)).value(); 526 args[i] = argId; 527 } 528 SPIRVId returnType = spirvType(fnType.returnType().toString()); 529 SPIRVId callResult = nextId(); 530 spirvBlock.add(new SPIRVOpFunctionCall(returnType, callResult, fnId, new SPIRVMultipleOperands<>(args))); 531 addResult(call.result(), new SpirvResult(returnType, null, callResult)); 532 } 533 } 534 } 535 case SpirvOp.ConstantOp cop -> { 536 SPIRVId type = spirvType(cop.resultType().toString()); 537 SPIRVId result = nextId(); 538 Object value = cop.value(); 539 if (type == getType("int")) { 540 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentInt(new BigInteger(String.valueOf(value))))); 541 } 542 else if (type == getType("long")) { 543 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentLong(new BigInteger(String.valueOf(value))))); 544 } 545 else if (type == getType("float")) { 546 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentFloat((float)value))); 547 } 548 else if (type == getType("bool")) { 549 module.add(((boolean)value) ? new SPIRVOpConstantTrue(type, result) : new SPIRVOpConstantFalse(type, result)); 550 } 551 else if (type == getType("java.lang.Object")) { 552 module.add(new SPIRVOpConstantNull(type, result)); 553 } 554 else if (type == getType("int[]")) { 555 module.add(new SPIRVOpConstantNull(type, result)); 556 } 557 else unsupported("type", cop.resultType()); 558 addResult(cop.result(), new SpirvResult(type, null, result)); 559 } 560 case SpirvOp.ConvertOp scop -> { 561 SPIRVId toType = spirvType(scop.resultType().toString()); 562 SPIRVId to = nextId(); 563 SpirvResult valueResult = getResult(scop.operands().get(0)); 564 SPIRVId from = valueResult.value(); 565 SPIRVId fromType = valueResult.type(); 566 if (isIntegerType(fromType)) { 567 if (isIntegerType(toType)) { 568 spirvBlock.add(new SPIRVOpSConvert(toType, to, from)); 569 } 570 else if (isFloatType(toType)) { 571 spirvBlock.add(new SPIRVOpConvertSToF(toType, to, from)); 572 } 573 else unsupported("conversion type", scop.resultType()); 574 } 575 else if (isFloatType(fromType)) { 576 if (isIntegerType(toType)) { 577 spirvBlock.add(new SPIRVOpConvertFToS(toType, to, from)); 578 } 579 else if (isFloatType(toType)) { 580 spirvBlock.add(new SPIRVOpFConvert(toType, to, from)); 581 } 582 else unsupported("conversion type", scop.resultType()); 583 } 584 else unsupported("conversion type", scop.operands().get(0)); 585 addResult(scop.result(), new SpirvResult(toType, null, to)); 586 } 587 case SpirvOp.InBoundsAccessChainOp iacop -> { 588 SPIRVId type = spirvType(iacop.resultType().toString()); 589 SPIRVId result = nextId(); 590 SPIRVId object = getResult(iacop.operands().get(0)).value(); 591 SPIRVId index = getResult(iacop.operands().get(1)).value(); 592 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(type, result, object, index, new SPIRVMultipleOperands<>())); 593 addResult(iacop.result(), new SpirvResult(type, result, null)); 594 } 595 case SpirvOp.FieldLoadOp flo -> { 596 if (flo.operands().size() > 0 && (flo.operands().get(0).type().equals(JavaType.ofString("hat.KernelContext")))) { 597 SpirvResult result; 598 int group = -1; 599 int index = -1; 600 String fieldName = flo.fieldDescriptor().name(); 601 switch (fieldName) { 602 case "x": group = 0; index = 0; break; 603 case "y": group = 0; index = 1; break; 604 case "z": group = 0; index = 2; break; 605 case "maxX": group = 1; index = 0; break; 606 case "maxY": group = 1; index = 1; break; 607 case "maxZ": group = 1; index = 2; break; 608 } 609 switch (group) { 610 case 0: result = globalId(index, spirvBlock); break; 611 case 1: result = globalSize(index, spirvBlock); break; 612 default: throw new RuntimeException("Unknown Index field: " + fieldName); 613 } 614 addResult(flo.result(), result); 615 } 616 else if (flo.operands().get(0).type().equals(JavaType.ofString("hat.KernelContext"))) { 617 String fieldName = flo.fieldDescriptor().name(); 618 SPIRVId fieldIndex = switch (fieldName) { 619 case "x" -> getConst("long", 0); 620 case "maxX" -> getConst("long", 1); 621 default -> throw new RuntimeException("Unknown field: " + fieldName); 622 }; 623 SPIRVId intType = getType("int"); 624 String contextTypeName = flo.operands().get(0).type().toString(); 625 SpirvResult kernalContext = getResult(flo.operands().get(0)); 626 SPIRVId contextAddr = kernalContext.address(); 627 SPIRVId contextType = spirvType(contextTypeName); 628 SPIRVId context = nextId(); 629 spirvBlock.add(new SPIRVOpLoad(contextType, context, contextAddr, align(contextType.getName()))); 630 SPIRVId fieldType = intType; 631 SPIRVId resultAddr = nextId(); 632 spirvBlock.add(new SPIRVOpInBoundsAccessChain(getType("ptrInt"), resultAddr, context, new SPIRVMultipleOperands<>(fieldIndex))); 633 SPIRVId result = nextId(); 634 spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align("int"))); 635 addResult(flo.result(), new SpirvResult(intType, resultAddr, result)); 636 } 637 else if (flo.fieldDescriptor().refType().equals(JavaType.type(ByteOrder.class))) { 638 // currently ignored 639 } 640 else unsupported("field load", ((ClassType)flo.fieldDescriptor().refType()).toClassName() + "." + flo.fieldDescriptor().name()); 641 } 642 case SpirvOp.BranchOp bop -> { 643 SPIRVId label = symbols.getLabel(bop.branch().targetBlock()).getResultId(); 644 Block.Reference target = bop.branch(); 645 spirvBlock.add(new SPIRVOpBranch(label)); 646 } 647 case SpirvOp.ConditionalBranchOp cbop -> { 648 SPIRVId test = getResult(cbop.operands().get(0)).value(); 649 SPIRVId trueLabel = symbols.getLabel(cbop.trueBranch().targetBlock()).getResultId(); 650 SPIRVId falseLabel = symbols.getLabel(cbop.falseBranch().targetBlock()).getResultId(); 651 spirvBlock.add(new SPIRVOpBranchConditional(test, trueLabel, falseLabel, new SPIRVMultipleOperands<SPIRVLiteralInteger>())); 652 } 653 case SpirvOp.LtOp ltop -> { 654 SpirvResult lhs = getResult(ltop.operands().get(0)); 655 SpirvResult rhs = getResult(ltop.operands().get(1)); 656 SPIRVId boolType = getType("bool"); 657 SPIRVId result = nextId(); 658 String operandType = lhs.type().getName(); 659 SPIRVInstruction sop = switch (operandType) { 660 case "float" -> new SPIRVOpFUnordLessThan(boolType, result, lhs.value(), rhs.value()); 661 case "int" -> new SPIRVOpSLessThan(boolType, result, lhs.value(), rhs.value()); 662 case "long" -> new SPIRVOpSLessThan(boolType, result, lhs.value(), rhs.value()); 663 default -> throw new RuntimeException("Unsupported type: " + lhs.type().getName()); 664 }; 665 spirvBlock.add(sop); 666 addResult(ltop.result(), new SpirvResult(boolType, null, result)); 667 } 668 case SpirvOp.ReturnOp rop -> { 669 spirvBlock.add(new SPIRVOpReturn()); 670 } 671 case SpirvOp.ReturnValueOp rop -> { 672 SPIRVId returnValue = getResult(rop.operands().get(0)).value(); 673 spirvBlock.add(new SPIRVOpReturnValue(returnValue)); 674 } 675 default -> unsupported("op", op.getClass()); 676 } 677 } 678 } // end bi 679 } 680 681 private SPIRVId getFunctionId(MethodRef methodRef) { 682 SPIRVId fnId = symbols.getId(methodRef.toString()); 683 if (fnId == null) { 684 try { 685 Optional<CoreOp.FuncOp> optJFuncOp = methodRef.codeModel(MethodHandles.lookup()); 686 if (optJFuncOp.isPresent()) { 687 CoreOp.FuncOp jFuncOp = optJFuncOp.get(); 688 SpirvOp.FuncOp sFuncOp = TranslateToSpirvModel.translateFunction(jFuncOp); 689 fnId = generateFunction(jFuncOp.funcName(), sFuncOp, false); 690 symbols.putId(methodRef.toString(), fnId); 691 } 692 } 693 catch (ReflectiveOperationException e) { 694 throw new RuntimeException(e); 695 } 696 } 697 return fnId; 698 } 699 700 private void initModule() { 701 module.add(new SPIRVOpCapability(SPIRVCapability.Addresses())); 702 module.add(new SPIRVOpCapability(SPIRVCapability.Linkage())); 703 module.add(new SPIRVOpCapability(SPIRVCapability.Kernel())); 704 module.add(new SPIRVOpCapability(SPIRVCapability.Int8())); 705 module.add(new SPIRVOpCapability(SPIRVCapability.Int64())); 706 module.add(new SPIRVOpMemoryModel(SPIRVAddressingModel.Physical64(), SPIRVMemoryModel.OpenCL())); 707 708 // OpenCL extension provides built-in variables suitable for kernel programming 709 // Import extension and declare four variables 710 SPIRVId oclExtension = nextId("oclExtension"); 711 module.add(new SPIRVOpExtInstImport(oclExtension, new SPIRVLiteralString("OpenCL.std"))); 712 symbols.putId("oclExtension", oclExtension); 713 714 SPIRVId globalInvocationId = nextId("globalInvocationId"); 715 SPIRVId globalSize = nextId("globalSize"); 716 SPIRVId subgroupSize = nextId("subgroupSize"); 717 SPIRVId subgroupId = nextId("subgroupId"); 718 719 module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalInvocationId()))); 720 module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.Constant())); 721 module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalInvocationId"), SPIRVLinkageType.Import()))); 722 module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalSize()))); 723 module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.Constant())); 724 module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalSize"), SPIRVLinkageType.Import()))); 725 module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupSize()))); 726 module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.Constant())); 727 module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupSize"), SPIRVLinkageType.Import()))); 728 module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupId()))); 729 module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.Constant())); 730 module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupId"), SPIRVLinkageType.Import()))); 731 732 module.add(new SPIRVOpVariable(getType("ptrV3long"), globalInvocationId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>())); 733 module.add(new SPIRVOpVariable(getType("ptrV3long"), globalSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>())); 734 module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>())); 735 module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>())); 736 } 737 738 private SPIRVId spirvType(String inputType) { 739 String javaType = inputType.replaceAll("\\$", "."); 740 SPIRVId ans = switch(javaType) { 741 case "byte" -> getType("byte"); 742 case "short" -> getType("short"); 743 case "int" -> getType("int"); 744 case "long" -> getType("long"); 745 case "float" -> getType("float"); 746 case "double" -> getType("double"); 747 case "byte[]" -> getType("byte[]"); 748 case "int[]" -> getType("int[]"); 749 case "float[]" -> getType("float[]"); 750 case "double[]" -> getType("double[]"); 751 case "long[]" -> getType("long[]"); 752 case "bool" -> getType("bool"); 753 case "boolean" -> getType("bool"); 754 case "java.lang.Object" -> getType("java.lang.Object"); 755 case "hat.buffer.S32Array" -> getType("int[]"); 756 case "hat.buffer.S32Array2D" -> getType("int[]"); 757 case "hat.buffer.F32Array2D" -> getType("float[]"); 758 case "void" -> getType("void"); 759 case "hat.KernelContext" -> getType("ptrKernelContext"); 760 case "java.lang.foreign.MemorySegment" -> getType("ptrByte"); 761 default -> null; 762 }; 763 if (ans == null) unsupported("type", javaType); 764 return ans; 765 } 766 767 private SPIRVId spirvElementType(String inputType) { 768 String javaType = inputType.replaceAll("\\$", "."); 769 SPIRVId ans = switch(javaType) { 770 case "byte[]" -> getType("byte"); 771 case "short[]" -> getType("short"); 772 case "int[]" -> getType("int"); 773 case "long[]" -> getType("long"); 774 case "float[]" -> getType("float"); 775 case "double[]" -> getType("double"); 776 case "boolean[]" -> getType("bool"); 777 case "hat.buffer.S32Array" -> getType("int"); 778 case "hat.buffer.S32Array2D" -> getType("int"); 779 case "hat.buffer.F32Array2D" -> getType("float"); 780 case "java.lang.foreign.MemorySegment" -> getType("byte"); 781 default -> null; 782 }; 783 if (ans == null) unsupported("type", javaType); 784 return ans; 785 } 786 787 private SPIRVId vectorElementType(SPIRVId type) { 788 SPIRVId ans = switch(type.getName()) { 789 case "v8int" -> getType("int"); 790 case "v16int" -> getType("int"); 791 case "v8long" -> getType("long"); 792 case "v8float" -> getType("float"); 793 case "v16float" -> getType("float"); 794 default -> null; 795 }; 796 if (ans == null) unsupported("type", type.getName()); 797 return ans; 798 } 799 800 private SPIRVId spirvVariableType(SPIRVId spirvType) { 801 SPIRVId ans = switch(spirvType.getName()) { 802 case "bool" -> getType("ptrBool"); 803 case "byte" -> getType("ptrByte"); 804 case "short" -> getType("ptrShort"); 805 case "int" -> getType("ptrInt"); 806 case "long" -> getType("ptrLong"); 807 case "float" -> getType("ptrFloat"); 808 case "double" -> getType("ptrDouble"); 809 case "boolean" -> getType("ptrBool"); 810 case "byte[]" -> getType("ptrByte[]"); 811 case "int[]" -> getType("ptrInt[]"); 812 case "long[]" -> getType("ptrLong[]"); 813 case "float[]" -> getType("ptrFloat[]"); 814 case "double[]" -> getType("ptrDouble[]"); 815 case "v8int" -> getType("ptrV8int"); 816 case "v16int" -> getType("ptrV16int"); 817 case "v8long" -> getType("ptrV8long"); 818 case "v8float" -> getType("ptrV8float"); 819 case "v16float" -> getType("ptrV16float"); 820 case "ptrKernelContext" -> getType("ptrPtrKernelContext"); 821 case "hat.KernelContext" -> getType("ptrKernelContext"); 822 case "ptrByte" -> getType("ptrPtrByte"); 823 default -> null; 824 }; 825 if (ans == null) unsupported("type", spirvType.getName()); 826 return ans; 827 } 828 829 private SPIRVId spirvVectorType(String javaVectorType, int vectorLength) { 830 String prefix = "v" + vectorLength; 831 String elementType = spirvElementType(javaVectorType).getName(); 832 return getType(prefix + elementType); 833 } 834 835 private int alignment(String inputType) { 836 String spirvType = inputType.replaceAll("\\$", "."); 837 int ans = switch(spirvType) { 838 case "bool" -> 1; 839 case "byte" -> 1; 840 case "short" -> 2; 841 case "int" -> 4; 842 case "long" -> 8; 843 case "float" -> 4; 844 case "double" -> 8; 845 case "boolean" -> 1; 846 case "v8int" -> 32; 847 case "v16int" -> 64; 848 case "v8long" -> 64; 849 case "v8float" -> 32; 850 case "v16float" -> 64; 851 case "hat.KernelContext" -> 32; 852 case "ptrKernelContext" -> 32; 853 case "byte[]" -> 8; 854 case "int[]" -> 8; 855 case "long[]" -> 8; 856 case "float[]" -> 8; 857 case "double[]" -> 8; 858 case "ptrBool" -> 8; 859 case "ptrByte" -> 8; 860 case "ptrInt" -> 8; 861 case "ptrByte[]" -> 8; 862 case "ptrInt[]" -> 8; 863 case "ptrLong" -> 8; 864 case "ptrLong[]" -> 8; 865 case "ptrFloat" -> 8; 866 case "ptrFloat[]" -> 8; 867 case "ptrV8int" -> 8; 868 case "ptrV8float" -> 8; 869 case "ptrPtrKernelContext" -> 8; 870 default -> 0; 871 }; 872 if (ans == 0) unsupported("type", spirvType); 873 return ans; 874 } 875 876 private Set<String> moduleTypes = new HashSet<>(); 877 878 private SPIRVId getType(String inputName) { 879 String name = inputName.replaceAll("\\$", "."); 880 if (!moduleTypes.contains(name)) { 881 switch (name) { 882 case "void" -> module.add(new SPIRVOpTypeVoid(nextId(name))); 883 case "bool" -> module.add(new SPIRVOpTypeBool(nextId(name))); 884 case "boolean" -> module.add(new SPIRVOpTypeBool(nextId(name))); 885 case "byte" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(8), new SPIRVLiteralInteger(0))); 886 case "short" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(16), new SPIRVLiteralInteger(0))); 887 case "int" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(32), new SPIRVLiteralInteger(0))); 888 case "long" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(64), new SPIRVLiteralInteger(0))); 889 case "float" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(32))); 890 case "double" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(64))); 891 case "ptrBool" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("bool"))); 892 case "ptrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte"))); 893 case "ptrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int"))); 894 case "ptrLong" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long"))); 895 case "ptrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float"))); 896 case "byte[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte"))); 897 case "short[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("short"))); 898 case "int[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("int"))); 899 case "long[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("long"))); 900 case "float[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("float"))); 901 case "double[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("double"))); 902 case "boolean[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("boolean"))); 903 case "ptrByte[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("byte[]"))); 904 case "ptrInt[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int[]"))); 905 case "ptrLong[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long[]"))); 906 case "ptrFloat[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float[]"))); 907 case "java.lang.Object" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("void"))); 908 case "hat.KernelContext" -> module.add(new SPIRVOpTypeStruct(nextId(name), new SPIRVMultipleOperands<>(getType("int"), getType("int")))); 909 case "ptrKernelContext" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("hat.KernelContext"))); 910 case "ptrCrossGroupByte"-> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte"))); 911 case "ptrPtrKernelContext" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrKernelContext"))); 912 case "ptrPtrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrByte"))); 913 case "ptrPtrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrInt"))); 914 case "ptrPtrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrFloat"))); 915 case "v3long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(3))); 916 case "v8int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(8))); 917 case "v8long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(8))); 918 case "v16int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(16))); 919 case "v8float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(8))); 920 case "v16float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(16))); 921 case "ptrV3long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Input(), getType("v3long"))); 922 case "ptrV8long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8long"))); 923 case "ptrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8int"))); 924 case "ptrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16int"))); 925 case "ptrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8float"))); 926 case "ptrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16float"))); 927 case "ptrPtrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8int"))); 928 case "ptrPtrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16int"))); 929 case "ptrPtrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8float"))); 930 case "ptrPtrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16float"))); 931 default -> unsupported("type", name); 932 } 933 moduleTypes.add(name); 934 } 935 return getId(name); 936 } 937 938 private Set<String> moduleConstants = new HashSet<>(); 939 940 private SPIRVId getConst(String typeName, long value) { 941 String name = typeName + "_" + value; 942 if (!moduleConstants.contains(name)) { 943 String valueStr = String.valueOf(value); 944 switch (typeName) { 945 case "int" -> module.add(new SPIRVOpConstant(getType(typeName), nextId(name), new SPIRVContextDependentInt(new BigInteger(valueStr)))); 946 case "long" -> module.add(new SPIRVOpConstant(getType(typeName), nextId(name), new SPIRVContextDependentLong(new BigInteger(valueStr)))); 947 case "boolean" -> module.add(value == 0 ? new SPIRVOpConstantFalse(getType(typeName), nextId(name)) : new SPIRVOpConstantTrue(getType(typeName), nextId(name))); 948 default -> unsupported("constant", name); 949 }; 950 moduleConstants.add(name); 951 } 952 return getId(name); 953 } 954 955 private SPIRVOptionalOperand<SPIRVMemoryAccess> align(int align) { 956 return new SPIRVOptionalOperand<>(SPIRVMemoryAccess.Aligned(new SPIRVLiteralInteger(align))); 957 } 958 959 private SPIRVOptionalOperand<SPIRVMemoryAccess> align(String type) { 960 return align(alignment(type)); 961 } 962 963 private SPIRVMultipleOperands<SPIRVId> spirvOperands(SPIRVId value, int count) { 964 SPIRVId[] values = new SPIRVId[count]; 965 Arrays.fill(values, value); 966 return new SPIRVMultipleOperands<>(values); 967 } 968 969 private SPIRVOptionalOperand<SPIRVMemoryAccess> none() { 970 return new SPIRVOptionalOperand<>(); 971 } 972 973 private SpirvResult globalSize(int index, SPIRVBlock spirvBlock) { 974 SPIRVId intType = getType("int"); 975 SPIRVId longType = getType("long"); 976 SPIRVId v3long = getId("v3long"); 977 SPIRVId globalSizeId = getId("globalSize"); 978 SPIRVId globalSizes = nextId(); 979 spirvBlock.add(new SPIRVOpLoad(v3long, globalSizes, globalSizeId, align(32))); 980 SPIRVId longSize = nextId(); 981 SPIRVId globalSize = nextId(); 982 spirvBlock.add(new SPIRVOpCompositeExtract(longType, longSize, globalSizes, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index)))); 983 spirvBlock.add(new SPIRVOpSConvert(intType, globalSize, longSize)); 984 return new SpirvResult(intType, null, globalSize); 985 } 986 987 private SpirvResult globalId(int index, SPIRVBlock spirvBlock) { 988 SPIRVId intType = getType("int"); 989 SPIRVId longType = getType("long"); 990 SPIRVId v3long = getId("v3long"); 991 SPIRVId globalInvocationId = getId("globalInvocationId"); 992 SPIRVId globalIds = nextId(); 993 spirvBlock.add(new SPIRVOpLoad(v3long, globalIds, globalInvocationId, align(32))); 994 SPIRVId longIndex = nextId(); 995 SPIRVId globalIndex = nextId(); 996 spirvBlock.add(new SPIRVOpCompositeExtract(longType, longIndex, globalIds, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index)))); 997 spirvBlock.add(new SPIRVOpSConvert(intType, globalIndex, longIndex)); 998 return new SpirvResult(intType, null, globalIndex); 999 } 1000 1001 private SpirvResult flatIndex(SPIRVId sizeX, SPIRVId sizeY, SPIRVId sizeZ, SPIRVId indexX, SPIRVId indexY, SPIRVId indexZ, SPIRVBlock spirvBlock) 1002 { 1003 SPIRVId longType = getType("long"); 1004 SPIRVId xTerm0 = nextId(); 1005 SPIRVId xTerm1 = nextId(); 1006 SPIRVId yTerm = nextId(); 1007 SPIRVId flat0 = nextId(); 1008 SPIRVId flat1 = nextId(); 1009 spirvBlock.add(new SPIRVOpIMul(longType, xTerm0, sizeY, sizeZ)); 1010 spirvBlock.add(new SPIRVOpIMul(longType, xTerm1, xTerm0, indexX)); 1011 spirvBlock.add(new SPIRVOpIMul(longType, yTerm, sizeZ, indexY)); 1012 spirvBlock.add(new SPIRVOpIAdd(longType, flat0, xTerm1, yTerm)); 1013 spirvBlock.add(new SPIRVOpIAdd(longType, flat1, flat0, indexZ)); 1014 return new SpirvResult(longType, null, flat1); 1015 } 1016 1017 private SPIRVId nextId() { 1018 return module.getNextId(); 1019 } 1020 1021 private SPIRVId nextId(String name) { 1022 SPIRVId ans = nextId(); 1023 ans.setName(name); 1024 symbols.putId(name, ans); 1025 module.add(new SPIRVOpName(ans, new SPIRVLiteralString(name))); 1026 return ans; 1027 } 1028 1029 private static int counter = 0; 1030 1031 private String nextTempTag() { 1032 counter++; 1033 return "temp_" + counter + "_"; 1034 } 1035 1036 private boolean isIntegerType(SPIRVId type) { 1037 String name = type.getName(); 1038 return name.equals("byte") || name.equals("short") || name.equals("int") || name.equals("long"); 1039 } 1040 1041 private boolean isFloatType(SPIRVId type) { 1042 String name = type.getName(); 1043 return name.equals("float") || name.equals("double"); 1044 } 1045 1046 private boolean isVectorSpecies(String javaType) { 1047 return javaType.equals("VectorSpecies"); 1048 } 1049 1050 private boolean isVectorType(String javaType) { 1051 return javaType.equals("IntVector") || javaType.equals("FloatVector"); 1052 } 1053 1054 private void addId(String name, SPIRVId id) { 1055 symbols.putId(name, id); 1056 } 1057 1058 private SPIRVId getId(String name) { 1059 SPIRVId ans = symbols.getId(name); 1060 assert ans != null : name + " not found"; 1061 return ans; 1062 } 1063 1064 private SPIRVId getIdOrNull(String name) { 1065 return symbols.getId(name); 1066 } 1067 1068 private static Object map(Function<Object, Boolean> test, Object... args) { 1069 int len = args.length; 1070 assert len >= 2 && len % 2 == 0; 1071 int pairs = len / 2; 1072 for (int i = 0; i < pairs; i++) { 1073 if (test.apply(args[i])) return args[i + pairs]; 1074 } 1075 throw new RuntimeException("No match: " + args[0]); 1076 } 1077 1078 private void unsupported(String message, Object value) { 1079 throw new RuntimeException("Unsupported " + message + ": " + value); 1080 } 1081 1082 private void addResult(Value value, SpirvResult result) { 1083 assert symbols.getResult(value) == null : "result already present"; 1084 symbols.putResult(value, result); 1085 } 1086 1087 private SpirvResult getResult(Value value) { 1088 return symbols.getResult(value); 1089 } 1090 1091 private static class Symbols { 1092 private final HashMap<Value, SpirvResult> results; 1093 private final HashMap<String, SPIRVId> ids; 1094 private final HashMap<Block, SPIRVBlock> blocks; 1095 private final HashMap<Block, SPIRVOpLabel> labels; 1096 1097 public Symbols() { 1098 this.results = new HashMap<>(); 1099 this.ids = new HashMap<>(); 1100 this.blocks = new HashMap<>(); 1101 this.labels = new HashMap<>(); 1102 } 1103 1104 public void putId(String name, SPIRVId id) { 1105 ids.put(name, id); 1106 } 1107 1108 public SPIRVId getId(String name) { 1109 return ids.get(name); 1110 } 1111 1112 public void putBlock(Block block, SPIRVBlock spirvBlock) { 1113 blocks.put(block, spirvBlock); 1114 } 1115 1116 public SPIRVBlock getBlock(Block block) { 1117 return blocks.get(block); 1118 } 1119 1120 public void putLabel(Block block, SPIRVOpLabel spirvLabel) { 1121 labels.put(block, spirvLabel); 1122 } 1123 1124 public SPIRVOpLabel getLabel(Block block) { 1125 return labels.get(block); 1126 } 1127 1128 public void putResult(Value value, SpirvResult result) { 1129 results.put(value, result); 1130 } 1131 1132 public SpirvResult getResult(Value value) { 1133 return results.get(value); 1134 } 1135 1136 public String toString() { 1137 return String.format("results %s\n\nids %s\n\nblocks %s\nlabels %s\n", results.keySet(), ids.keySet(), blocks.keySet(), labels.keySet()); 1138 } 1139 } 1140 1141 public static void debug(String message, Object... args) { 1142 System.out.println(String.format(message, args)); 1143 } 1144 }