1 /* 2 * Copyright (c) 2024 Intel Corporation. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package intel.code.spirv; 27 28 import java.util.List; 29 import java.util.Arrays; 30 import java.util.HashMap; 31 import java.util.Set; 32 import java.util.HashSet; 33 import java.util.function.Function; 34 import java.io.IOException; 35 import java.io.File; 36 import java.io.FileOutputStream; 37 import java.io.ByteArrayInputStream; 38 import java.io.ByteArrayOutputStream; 39 import java.io.PrintStream; 40 import java.nio.ByteBuffer; 41 import java.nio.ByteOrder; 42 import java.nio.channels.FileChannel; 43 import java.math.BigInteger; 44 import jdk.incubator.vector.VectorSpecies; 45 import jdk.incubator.vector.VectorOperators; 46 import jdk.incubator.vector.Vector; 47 import jdk.incubator.vector.IntVector; 48 import jdk.incubator.vector.FloatVector; 49 import java.lang.foreign.MemorySegment; 50 import java.lang.foreign.ValueLayout; 51 import java.lang.reflect.code.Block; 52 import java.lang.reflect.code.Body; 53 import java.lang.reflect.code.Op; 54 import java.lang.reflect.code.Value; 55 import java.lang.reflect.code.op.CoreOp; 56 import java.lang.reflect.code.TypeElement; 57 import java.lang.reflect.code.type.MethodRef; 58 import java.lang.reflect.code.type.ClassType; 59 import java.lang.reflect.code.type.JavaType; 60 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVHeader; 61 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVModule; 62 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVFunction; 63 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVBlock; 64 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.*; 65 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.*; 66 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler; 67 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions; 68 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVByteStreamReader; 69 70 public class SpirvModuleGenerator { 71 public static MemorySegment generateModule(String moduleName, CoreOp.FuncOp func) { 72 SpirvOps.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(func); 73 MemorySegment module = SpirvModuleGenerator.generateModule(moduleName, spirvFunc); 74 return module; 75 } 76 77 public static MemorySegment generateModule(String moduleName, SpirvOps.FuncOp func) { 78 return new SpirvModuleGenerator().generateModuleInternal(moduleName, func); 79 } 80 81 public static void writeModuleToFile(MemorySegment module, String filepath) { 82 ByteBuffer buffer = module.asByteBuffer(); 83 File out = new File(filepath); 84 try (FileChannel channel = new FileOutputStream(out, false).getChannel()) { 85 channel.write(buffer); 86 } 87 catch (IOException e) { 88 throw new RuntimeException(e); 89 } 90 } 91 92 public static String disassembleModule(MemorySegment module) { 93 SPVByteStreamReader input = new SPVByteStreamReader(new ByteArrayInputStream(module.toArray(ValueLayout.JAVA_BYTE))); 94 ByteArrayOutputStream out = new ByteArrayOutputStream(); 95 try (PrintStream ps = new PrintStream(out)) { 96 SPIRVDisassemblerOptions options = new SPIRVDisassemblerOptions(false, false, false, false, true); 97 Disassembler dis = new Disassembler(input, ps, options); 98 dis.run(); 99 } 100 catch (Exception e) { 101 throw new RuntimeException(e); 102 } 103 return new String(out.toByteArray()); 104 } 105 106 private record SpirvResult(SPIRVId type, SPIRVId address, SPIRVId value) {} 107 108 private final SPIRVModule module; 109 private final Symbols symbols; 110 111 private SpirvModuleGenerator() { 112 this.module = new SPIRVModule(new SPIRVHeader(1, 2, 32, 0, 0)); 113 this.symbols = new Symbols(); 114 } 115 116 private MemorySegment generateModuleInternal(String moduleName, SpirvOps.FuncOp func) { 117 initModule(); 118 generateFunction(moduleName, moduleName, func); 119 ByteBuffer buffer = ByteBuffer.allocateDirect(module.getByteCount()); 120 buffer.order(ByteOrder.LITTLE_ENDIAN); 121 module.close().write(buffer); 122 buffer.flip(); 123 return MemorySegment.ofBuffer(buffer); 124 } 125 126 private void generateFunction(String moduleName, String fnName, SpirvOps.FuncOp func) { 127 TypeElement returnType = func.invokableType().returnType(); 128 SPIRVId functionID = nextId(fnName); 129 String signature = func.invokableType().returnType().toString(); 130 List<TypeElement> paramTypes = func.invokableType().parameterTypes(); 131 // build signature string 132 for (int i = 0; i < paramTypes.size(); i++) { 133 signature += "_" + paramTypes.get(i).toString(); 134 } 135 // declare function type if not already present 136 SPIRVId functionSig = getIdOrNull(signature); 137 if (functionSig == null) { 138 SPIRVId[] typeIdsArray = new SPIRVId[paramTypes.size()]; 139 for (int i = 0; i < paramTypes.size(); i++) { 140 typeIdsArray[i] = spirvType(paramTypes.get(i).toString()); 141 } 142 functionSig = nextId(fnName + "Signature"); 143 module.add(new SPIRVOpTypeFunction(functionSig, spirvType(returnType.toString()), new SPIRVMultipleOperands<>(typeIdsArray))); 144 addId(signature, functionSig); 145 } 146 // declare function as modeule entry point 147 SPIRVId spirvReturnType = spirvType(returnType.toString()); 148 SPIRVFunction function = (SPIRVFunction)module.add(new SPIRVOpFunction(spirvReturnType, functionID, SPIRVFunctionControl.DontInline(), functionSig)); 149 SPIRVOpLabel entryPoint = new SPIRVOpLabel(nextId()); 150 SPIRVBlock entryBlock = (SPIRVBlock)function.add(entryPoint); 151 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(getId("globalInvocationId"), getId("globalSize"), getId("subgroupSize"), getId("subgroupId")); 152 module.add(new SPIRVOpEntryPoint(SPIRVExecutionModel.Kernel(), functionID, new SPIRVLiteralString(fnName), operands)); 153 154 translateBody(func.body(), function, entryBlock); 155 function.add(new SPIRVOpFunctionEnd()); 156 } 157 158 private void translateBody(Body body, SPIRVFunction function, SPIRVBlock entryBlock) { 159 int labelNumber = 0; 160 SPIRVBlock spirvBlock = entryBlock; 161 for (int bi = 1; bi < body.blocks().size(); bi++) { 162 Block block = body.blocks().get(bi); 163 String blockName = String.valueOf(block.hashCode()); 164 SPIRVOpLabel blockLabel = new SPIRVOpLabel(nextId()); 165 SPIRVBlock newBlock = (SPIRVBlock)function.add(blockLabel); 166 symbols.putBlock(block, newBlock); 167 symbols.putLabel(block, blockLabel); 168 } 169 for (Value param : body.entryBlock().parameters()) { 170 SPIRVId paramId = nextId(); 171 addResult(param, new SpirvResult(spirvType(param.type().toString()), null, paramId)); 172 } 173 for (int bi = 0; bi < body.blocks().size(); bi++) { 174 Block block = body.blocks().get(bi); 175 if (bi > 0) { 176 spirvBlock = symbols.getBlock(block); 177 } 178 List<Op> ops = block.ops(); 179 for (Op op : block.ops()) { 180 // debug("---------- spirv op = %s", op.toText()); 181 switch (op) { 182 case SpirvOps.VariableOp vop -> { 183 String typeName = vop.varType().toString(); 184 SPIRVId type = spirvType(typeName); 185 SPIRVId varType = spirvVariableType(type); 186 SPIRVId var = nextId(vop.varName()); 187 spirvBlock.add(new SPIRVOpVariable(varType, var, SPIRVStorageClass.Function(), new SPIRVOptionalOperand<>())); 188 addResult(vop.result(), new SpirvResult(varType, var, null)); 189 } 190 case SpirvOps.FunctionParameterOp fpo -> { 191 SPIRVId result = nextId(); 192 SPIRVId type = spirvType(fpo.resultType().toString()); 193 function.add(new SPIRVOpFunctionParameter(type, result)); 194 addResult(fpo.result(), new SpirvResult(type, null, result)); 195 } 196 case SpirvOps.LoadOp lo -> { 197 if (((JavaType)lo.resultType()).equals(JavaType.type(VectorSpecies.class))) { 198 addResult(lo.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT"))); 199 } 200 else { 201 SPIRVId type = spirvType(lo.resultType().toString()); 202 SpirvResult toLoad = getResult(lo.operands().get(0)); 203 SPIRVId varAddr = toLoad.address(); 204 SPIRVId result = nextId(); 205 spirvBlock.add(new SPIRVOpLoad(type, result, varAddr, align(type.getName()))); 206 addResult(lo.result(), new SpirvResult(type, varAddr, result)); 207 } 208 } 209 case SpirvOps.StoreOp so -> { 210 SpirvResult var = getResult(so.operands().get(0)); 211 SPIRVId varAddr = var.address(); 212 SPIRVId value = getResult(so.operands().get(1)).value(); 213 spirvBlock.add(new SPIRVOpStore(varAddr, value, align(var.type().getName()))); 214 } 215 case SpirvOps.IAddOp _, SpirvOps.FAddOp _ -> { 216 SPIRVId intType = getType("int"); 217 SPIRVId longType = getType("long"); 218 SPIRVId floatType = getType("float"); 219 SPIRVId lhs = getResult(op.operands().get(0)).value(); 220 SPIRVId rhs = getResult(op.operands().get(1)).value(); 221 SPIRVId lhsType = spirvType(op.resultType().toString()); 222 SPIRVId ans = nextId(); 223 if (lhsType == intType) spirvBlock.add(new SPIRVOpIAdd(intType, ans, lhs, rhs)); 224 else if (lhsType == longType) spirvBlock.add(new SPIRVOpIAdd(longType, ans, lhs, rhs)); 225 else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFAdd(floatType, ans, lhs, rhs)); 226 else unsupported("type", lhsType.getName()); 227 addResult(op.result(), new SpirvResult(lhsType, null, ans)); 228 } 229 case SpirvOps.IMulOp _, SpirvOps.FMulOp _, SpirvOps.IDivOp _, SpirvOps.FDivOp _ -> { 230 SPIRVId intType = getType("int"); 231 SPIRVId longType = getType("long"); 232 SPIRVId floatType = getType("float"); 233 SPIRVId lhs = getResult(op.operands().get(0)).value(); 234 SPIRVId rhs = getResult(op.operands().get(1)).value(); 235 SPIRVId lhsType = spirvType(op.resultType().toString()); 236 SPIRVId rhsType = getResult(op.operands().get(1)).type(); 237 SPIRVId ans = nextId(); 238 if (lhsType == intType) { 239 if (op instanceof SpirvOps.IMulOp) spirvBlock.add(new SPIRVOpIMul(intType, ans, lhs, rhs)); 240 else if (op instanceof SpirvOps.IDivOp) spirvBlock.add(new SPIRVOpSDiv(intType, ans, lhs, rhs)); 241 } 242 else if (lhsType == longType) { 243 SPIRVId rhsId = rhsType == intType ? nextId() : rhs; 244 if (rhsType == intType) spirvBlock.add(new SPIRVOpSConvert(longType, rhsId, rhs)); 245 if (op instanceof SpirvOps.IMulOp) spirvBlock.add(new SPIRVOpIMul(longType, ans, lhs, rhsId)); 246 else if (op instanceof SpirvOps.IDivOp) spirvBlock.add(new SPIRVOpSDiv(longType, ans, lhs, rhs)); 247 } 248 else if (lhsType == floatType) { 249 if (op instanceof SpirvOps.FMulOp) spirvBlock.add(new SPIRVOpFMul(floatType, ans, lhs, rhs)); 250 else if (op instanceof SpirvOps.FDivOp) spirvBlock.add(new SPIRVOpFDiv(floatType, ans, lhs, rhs)); 251 } 252 else unsupported("type", lhsType); 253 addResult(op.result(), new SpirvResult(lhsType, null, ans)); 254 } 255 case SpirvOps.ModOp mop -> { 256 SPIRVId type = getType(mop.operands().get(0).type().toString()); 257 SPIRVId lhs = getResult(mop.operands().get(0)).value(); 258 SPIRVId rhs = getResult(mop.operands().get(1)).value(); 259 SPIRVId result = nextId(); 260 spirvBlock.add(new SPIRVOpUMod(type, result, lhs, rhs)); 261 addResult(mop.result(), new SpirvResult(type, null, result)); 262 } 263 case SpirvOps.IEqualOp eqop -> { 264 SPIRVId boolType = getType("bool"); 265 SPIRVId intType = getType("int"); 266 SPIRVId longType = getType("long"); 267 SPIRVId floatType = getType("float"); 268 SPIRVId lhs = getResult(op.operands().get(0)).value(); 269 SPIRVId rhs = getResult(op.operands().get(1)).value(); 270 SPIRVId lhsType = spirvType(op.resultType().toString()); 271 SPIRVId ans = nextId(); 272 if (lhsType == intType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs)); 273 else if (lhsType == longType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs)); 274 else unsupported("type", lhsType.getName()); 275 addResult(op.result(), new SpirvResult(lhsType, null, ans)); 276 } 277 case SpirvOps.CallOp call -> { 278 if (call.callDescriptor().equals(MethodRef.ofString("spirvdemo.IntArray::get(long)int")) || 279 call.callDescriptor().equals(MethodRef.ofString("spirvdemo.FloatArray::get(long)float"))) { 280 SPIRVId longType = getType("long"); 281 String arrayTypeName = call.operands().get(0).type().toString(); 282 SpirvResult arrayResult = getResult(call.operands().get(0)); 283 SPIRVId arrayAddr = arrayResult.address(); 284 SPIRVId arrayType = spirvType(arrayTypeName); 285 SPIRVId elementType = spirvElementType(arrayTypeName); 286 int nIndexes = call.operands().size() - 1; 287 SPIRVId index = getResult(call.operands().get(1)).value(); 288 SPIRVId array = nextId(); 289 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName()))); 290 SPIRVId resultAddr = nextId(); 291 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, index, new SPIRVMultipleOperands<>())); 292 SPIRVId result = nextId(); 293 spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName()))); 294 addResult(call.result(), new SpirvResult(elementType, resultAddr, result)); 295 } 296 else if (call.callDescriptor().equals(MethodRef.ofString("spirvdemo.IntArray::set(long, int)void")) || 297 call.callDescriptor().equals(MethodRef.ofString("spirvdemo.FloatArray::set(long, float)void"))) { 298 SPIRVId longType = getType("long"); 299 String arrayTypeName = call.operands().get(0).type().toString(); 300 SpirvResult arrayResult = getResult(call.operands().get(0)); 301 SPIRVId arrayAddr = arrayResult.address(); 302 SPIRVId arrayType = spirvType(arrayTypeName); 303 SPIRVId elementType = spirvElementType(arrayTypeName); 304 int nIndexes = call.operands().size() - 2; 305 int valueIndex = nIndexes + 1; 306 SPIRVId index = getResult(call.operands().get(1)).value(); 307 SPIRVId array = nextId(); 308 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName()))); 309 SPIRVId dest = nextId(); 310 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, array, index, new SPIRVMultipleOperands<>())); 311 SPIRVId value = getResult(call.operands().get(valueIndex)).value(); 312 spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName()))); 313 } 314 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "fromArray", IntVector.class, VectorSpecies.class, int[].class, int.class)) 315 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fromArray", FloatVector.class, VectorSpecies.class, float[].class, int.class))) { 316 SPIRVId oclExtension = getId("oclExtension"); 317 SpirvResult speciesResult = getResult(call.operands().get(0)); 318 SpirvResult arrayResult = getResult(call.operands().get(1)); 319 String arrayType = arrayResult.type().getName(); 320 int laneCount = 8; //TODO: remove hard code, instruction below needs a literal 321 String vTypeName = ((ClassType)call.callDescriptor().refType()).toClassName(); 322 SPIRVId vType = spirvVectorType(vTypeName, laneCount); 323 SPIRVId array = arrayResult.value(); 324 SPIRVId index = getResult(call.operands().get(2)).value(); 325 SPIRVId vectorIndex = nextId(); 326 spirvBlock.add(new SPIRVOpSDiv(getType("int"), vectorIndex, index, speciesResult.value())); 327 SPIRVId longIndex = nextId(); 328 spirvBlock.add(new SPIRVOpSConvert(getType("long"), longIndex, vectorIndex)); 329 SPIRVId vector = nextId(); 330 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(longIndex, array, new SPIRVId(laneCount)); // TODO: lanes must be a literal 331 spirvBlock.add(new SPIRVOpExtInst(vType, vector, oclExtension, new SPIRVLiteralExtInstInteger(171, "vloadn"), operands)); 332 addResult(call.result(), new SpirvResult(vType, null, vector)); 333 } 334 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "fromMemorySegment", IntVector.class, VectorSpecies.class, MemorySegment.class, long.class, ByteOrder.class)) 335 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fromMemorySegment", FloatVector.class, VectorSpecies.class, MemorySegment.class, long.class, ByteOrder.class))) { 336 SPIRVId oclExtension = getId("oclExtension"); 337 SPIRVId species = getResult(call.operands().get(0)).value(); 338 SPIRVId lanesLong = nextId(); 339 spirvBlock.add(new SPIRVOpSConvert(getType("long"), lanesLong, species)); 340 int laneCount = 8; //TODO: remove hard code, vloadn instruction below needs a literal lane count, get value from env 341 SPIRVId segment = getResult(call.operands().get(1)).value(); 342 String vTypeName = ((ClassType)call.callDescriptor().refType()).toClassName(); 343 SPIRVId vType = spirvVectorType(vTypeName, laneCount); 344 SPIRVId temp = nextId(); 345 spirvBlock.add(new SPIRVOpConvertPtrToU(getType("long"), temp, segment)); 346 SPIRVId typedSegment = nextId(); 347 SPIRVId pointerType = (SPIRVId)map(x -> x.equals(vTypeName), "jdk.incubator.vector.IntVector", "jdk.incubator.vector.FloatVector", getType("ptrInt"), getType("ptrFloat")); 348 spirvBlock.add(new SPIRVOpConvertUToPtr(pointerType, typedSegment, temp)); 349 SPIRVId offset = getResult(call.operands().get(2)).value(); 350 SPIRVId vectorIndex = nextId(); 351 spirvBlock.add(new SPIRVOpSDiv(getType("long"), vectorIndex, offset, lanesLong)); // divide by lane count 352 SPIRVId finalIndex = nextId(); 353 SPIRVId vector = nextId(); 354 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(vectorIndex, typedSegment, new SPIRVId(laneCount)); // TODO: lanes must be a literal 355 spirvBlock.add(new SPIRVOpExtInst(vType, vector, oclExtension, new SPIRVLiteralExtInstInteger(171, "vloadn"), operands)); 356 addResult(call.result(), new SpirvResult(vType, null, vector)); 357 } 358 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "intoArray", void.class, int[].class, int.class)) 359 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "intoArray", void.class, float[].class, int.class))) { 360 SPIRVId oclExtension = getId("oclExtension"); 361 SpirvResult vectorResult = getResult(call.operands().get(0)); 362 SPIRVId vector = vectorResult.value(); 363 SPIRVId vectorType = vectorResult.type(); 364 SpirvResult arrayResult = getResult(call.operands().get(1)); 365 SPIRVId array = arrayResult.value(); 366 SPIRVId index = getResult(call.operands().get(2)).value(); 367 SPIRVId vectorIndex = nextId(); 368 spirvBlock.add(new SPIRVOpShiftRightArithmetic(getType("int"), vectorIndex, index, vectorExponent(vectorType.getName()))); 369 SPIRVId longIndex = nextId(); 370 spirvBlock.add(new SPIRVOpSConvert(getType("long"), longIndex, vectorIndex)); 371 SPIRVMultipleOperands<SPIRVId> operandsR = new SPIRVMultipleOperands<>(vector, longIndex, array); 372 spirvBlock.add(new SPIRVOpExtInst(getType("void"), nextId(), oclExtension, new SPIRVLiteralExtInstInteger(172, "vstoren"), operandsR)); 373 } 374 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "intoMemorySegment", void.class, MemorySegment.class, long.class, ByteOrder.class)) 375 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "intoMemorySegment", void.class, MemorySegment.class, long.class, ByteOrder.class))) { 376 SPIRVId oclExtension = getId("oclExtension"); 377 SpirvResult vectorResult = getResult(call.operands().get(0)); 378 SPIRVId vector = vectorResult.value(); 379 SPIRVId vectorType = vectorResult.type(); 380 SpirvResult segmentResult = getResult(call.operands().get(1));; 381 SPIRVId segment = segmentResult.value(); 382 SPIRVId temp = nextId(); 383 spirvBlock.add(new SPIRVOpConvertPtrToU(getType("long"), temp, segment)); 384 SPIRVId typedSegment = nextId(); 385 String vectorElementType = vectorElementType(vectorType).getName(); 386 SPIRVId pointerType = (SPIRVId)map(x -> x.equals(vectorElementType), "int", "float", getType("ptrInt"), getType("ptrFloat")); 387 spirvBlock.add(new SPIRVOpConvertUToPtr(pointerType, typedSegment, temp)); 388 SPIRVId offset = getResult(call.operands().get(2)).value(); 389 SPIRVId vectorIndex = nextId(); 390 int laneCount = laneCount(vectorType.getName()); 391 spirvBlock.add(new SPIRVOpShiftRightArithmetic(getType("long"), vectorIndex, offset, vectorExponent(vectorType.getName()))); 392 SPIRVMultipleOperands<SPIRVId> operandsR = new SPIRVMultipleOperands<>(vector, vectorIndex, typedSegment); 393 spirvBlock.add(new SPIRVOpExtInst(getId("void"), nextId(), oclExtension, new SPIRVLiteralExtInstInteger(172, "vstoren"), operandsR)); 394 } 395 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "reduceLanes", int.class, VectorOperators.Associative.class)) 396 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "reduceLanes", float.class, VectorOperators.Associative.class))) { 397 SpirvResult vectorResult = getResult(call.operands().get(0)); 398 SPIRVId vectorType = vectorResult.type(); 399 SPIRVId vector = vectorResult.value(); 400 String vTypeName = vectorType.getName(); 401 SPIRVId elementType = vectorElementType(vectorType); 402 Op reduceOp = ((Op.Result)call.operands().get(1)).op(); 403 if (reduceOp instanceof SpirvOps.FieldLoadOp flo) { 404 assert flo.fieldDescriptor().refType().equals(JavaType.type(VectorOperators.class)); 405 assert flo.fieldDescriptor().name().equals("ADD"); 406 String operation = flo.fieldDescriptor().name(); 407 } 408 else unsupported("operation expression", reduceOp.toText()); 409 String tempTag = nextTempTag(); 410 SPIRVId temp_0 = nextId(tempTag + 0); 411 spirvBlock.add(new SPIRVOpCompositeExtract(elementType, temp_0, vector, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(0)))); 412 for (int lane = 1; lane < laneCount(vectorType.getName()); lane++) { 413 SPIRVId temp = nextId(tempTag + lane); 414 SPIRVId element = nextId(); 415 spirvBlock.add(new SPIRVOpCompositeExtract(elementType, element, vector, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(lane)))); 416 if (elementType == getType("int")) { 417 spirvBlock.add(new SPIRVOpIAdd(elementType, temp, getId(tempTag + (lane - 1)), element)); 418 } 419 else if (elementType == getType("float")) { 420 spirvBlock.add(new SPIRVOpFAdd(elementType, temp, getId(tempTag + (lane - 1)), element)); 421 } 422 else unsupported("type", elementType.getName()); 423 } 424 addResult(call.result(), new SpirvResult(elementType, null, getId(tempTag + (laneCount(vectorType.getName()) - 1)))); 425 } 426 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "add", IntVector.class, Vector.class)) 427 || call.callDescriptor().equals(MethodRef.method(IntVector.class, "mul", IntVector.class, Vector.class)) 428 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "add", FloatVector.class, Vector.class)) 429 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "mul", FloatVector.class, Vector.class))) { 430 SPIRVId oclExtension = getId("oclExtension"); 431 SpirvResult lhsResult = getResult(call.operands().get(0)); 432 SPIRVId lhsType = lhsResult.type(); 433 SPIRVId lhs = lhsResult.value(); 434 SPIRVId rhs = getResult(call.operands().get(1)).value(); 435 SPIRVId add = nextId(); 436 if (call.callDescriptor().name().equals("add")) { 437 spirvBlock.add(lhsType.getName().endsWith("int") ? new SPIRVOpIAdd(lhsType, add, lhs, rhs) : new SPIRVOpFAdd(lhsType, add, lhs, rhs)); 438 } 439 else if (call.callDescriptor().name().equals("mul")) { 440 spirvBlock.add(lhsType.getName().endsWith("int") ? new SPIRVOpIMul(lhsType, add, lhs, rhs) : new SPIRVOpFMul(lhsType, add, lhs, rhs)); 441 } 442 addResult(call.result(), new SpirvResult(lhsType, null, add)); 443 } 444 else if (call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fma", FloatVector.class, Vector.class, Vector.class))) { 445 SPIRVId oclExtension = getId("oclExtension"); 446 SpirvResult aResult = getResult(call.operands().get(0)); 447 SPIRVId vType = aResult.type(); 448 SPIRVId a = aResult.value(); 449 SPIRVId b = getResult(call.operands().get(1)).value(); 450 SPIRVId c = getResult(call.operands().get(2)).value(); 451 String vTypeStr = vType.getName(); 452 assert vTypeStr.endsWith("float"); 453 SPIRVId result = nextId(); 454 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(a, b, c); 455 spirvBlock.add(new SPIRVOpExtInst(vType, result, oclExtension, new SPIRVLiteralExtInstInteger(26, "fma"), operands)); 456 addResult(call.result(), new SpirvResult(vType, null, result)); 457 } 458 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "zero", IntVector.class, VectorSpecies.class)) 459 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "zero", FloatVector.class, VectorSpecies.class))) { 460 SpirvResult speciesResult = getResult(call.operands().get(0)); 461 SPIRVId vType = spirvType(((ClassType)call.callDescriptor().refType()).toClassName()); 462 String elementType = vectorElementType(vType).getName(); 463 SPIRVId value = getId(elementType + "_ZERO"); 464 int laneCount = laneCount(vType.getName()); 465 assert laneCount == 8 || laneCount == 16; 466 SPIRVId vector = nextId(); 467 SPIRVMultipleOperands<SPIRVId> operands = spirvOperands(value, laneCount); 468 spirvBlock.add(new SPIRVOpCompositeConstruct(vType, vector, operands)); 469 addResult(call.result(), new SpirvResult(vType, null, vector)); 470 } 471 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "lane", int.class, int.class)) 472 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "lane", float.class, int.class))) { 473 SpirvResult lhsResult = getResult(call.operands().get(0)); 474 SPIRVId lhsType = lhsResult.type(); 475 SPIRVId lhs = lhsResult.value(); 476 String vTypeStr = lhsType.getName(); 477 SPIRVId vType = lhsResult.type(); 478 SPIRVId elementType = vectorElementType(vType); 479 SPIRVId result = nextId(); 480 Op laneOp = ((Op.Result)call.operands().get(1)).op(); 481 assert laneOp instanceof SpirvOps.ConstantOp; 482 int lane = (int)((SpirvOps.ConstantOp)laneOp).value(); 483 spirvBlock.add(new SPIRVOpCompositeExtract(elementType, result, lhsResult.value(), new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(lane)))); 484 addResult(call.result(), new SpirvResult(elementType, null, result)); 485 } 486 else if (call.callDescriptor().equals(MethodRef.method(VectorSpecies.class, "length", int.class))) { 487 addResult(call.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT"))); // TODO: remove hardcode 488 } 489 else unsupported("method", call.callDescriptor()); 490 } 491 case SpirvOps.ConstantOp cop -> { 492 SPIRVId type = spirvType(cop.resultType().toString()); 493 SPIRVId result = nextId(); 494 Object value = cop.value(); 495 if (type == getType("int")) { 496 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentInt(new BigInteger(String.valueOf(value))))); 497 } 498 else if (type == getType("long")) { 499 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentLong(new BigInteger(String.valueOf(value))))); 500 } 501 else if (type == getType("float")) { 502 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentFloat((float)value))); 503 } 504 else unsupported("type", cop.resultType()); 505 addResult(cop.result(), new SpirvResult(type, null, result)); 506 } 507 case SpirvOps.ConvertOp scop -> { 508 SPIRVId toType = spirvType(scop.resultType().toString()); 509 SPIRVId to = nextId(); 510 SpirvResult valueResult = getResult(scop.operands().get(0)); 511 SPIRVId from = valueResult.value(); 512 SPIRVId fromType = valueResult.type(); 513 if (isIntegerType(fromType)) { 514 if (isIntegerType(toType)) { 515 spirvBlock.add(new SPIRVOpSConvert(toType, to, from)); 516 } 517 else if (isFloatType(toType)) { 518 spirvBlock.add(new SPIRVOpConvertSToF(toType, to, from)); 519 } 520 else unsupported("conversion type", scop.resultType()); 521 } 522 else unsupported("conversion type", scop.operands().get(0)); 523 addResult(scop.result(), new SpirvResult(toType, null, to)); 524 } 525 case SpirvOps.InBoundAccessChainOp iacop -> { 526 SPIRVId type = spirvType(iacop.resultType().toString()); 527 SPIRVId result = nextId(); 528 SPIRVId object = getResult(iacop.operands().get(0)).value(); 529 SPIRVId index = getResult(iacop.operands().get(1)).value(); 530 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(type, result, object, index, new SPIRVMultipleOperands<>())); 531 addResult(iacop.result(), new SpirvResult(type, result, null)); 532 } 533 case SpirvOps.FieldLoadOp flo -> { 534 if (flo.operands().size() > 0 && flo.operands().get(0).type().equals(JavaType.ofString("spirvdemo.GPU$Index"))) { 535 SpirvResult result; 536 int group = -1; 537 int index = -1; 538 String fieldName = flo.fieldDescriptor().name(); 539 switch(fieldName) { 540 case "x": group = 0; index = 0; break; 541 case "y": group = 0; index = 1; break; 542 case "z": group = 0; index = 2; break; 543 case "w": group = 1; index = 0; break; 544 case "h": group = 1; index = 1; break; 545 case "d": group = 1; index = 2; break; 546 } 547 switch (group) { 548 case 0: result = globalId(index, spirvBlock); break; 549 case 1: result = globalSize(index, spirvBlock); break; 550 default: throw new RuntimeException("Unknown Index field: " + fieldName); 551 } 552 addResult(flo.result(), result); 553 } 554 else if (((JavaType)flo.resultType()).equals(JavaType.type(VectorSpecies.class))) { 555 addResult(flo.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT"))); 556 } 557 else if (flo.fieldDescriptor().refType().equals(JavaType.type(VectorOperators.class))) { 558 // currently ignored 559 } 560 else if (flo.fieldDescriptor().refType().equals(JavaType.type(ByteOrder.class))) { 561 // currently ignored 562 } 563 else unsupported("field load", ((ClassType)flo.fieldDescriptor().refType()).toClassName() + "." + flo.fieldDescriptor().name()); 564 } 565 case SpirvOps.BranchOp bop -> { 566 SPIRVId trueLabel = symbols.getLabel(bop.branch()).getResultId(); 567 spirvBlock.add(new SPIRVOpBranch(trueLabel)); 568 } 569 case SpirvOps.ConditionalBranchOp cbop -> { 570 SPIRVId test = getResult(cbop.operands().get(0)).value(); 571 SPIRVId trueLabel = symbols.getLabel(cbop.trueBranch()).getResultId(); 572 SPIRVId falseLabel = symbols.getLabel(cbop.falseBranch()).getResultId(); 573 spirvBlock.add(new SPIRVOpBranchConditional(test, trueLabel, falseLabel, new SPIRVMultipleOperands<SPIRVLiteralInteger>())); 574 } 575 case SpirvOps.LtOp ltop -> { 576 SPIRVId lhs = getResult(ltop.operands().get(0)).value(); 577 SPIRVId rhs = getResult(ltop.operands().get(1)).value(); 578 SPIRVId boolType = getType("bool"); 579 SPIRVId result = nextId(); 580 spirvBlock.add(new SPIRVOpSLessThan(boolType, result, lhs, rhs)); 581 addResult(ltop.result(), new SpirvResult(boolType, null, result)); 582 } 583 case SpirvOps.ReturnOp rop -> { 584 if (rop.operands().size() == 0) { 585 spirvBlock.add(new SPIRVOpReturn()); 586 } 587 else { 588 SPIRVId returnValue = getResult(rop.operands().get(0)).value(); 589 spirvBlock.add(new SPIRVOpReturnValue(returnValue)); 590 } 591 } 592 default -> unsupported("op", op.getClass()); 593 } 594 } 595 } 596 } 597 598 private void initModule() { 599 module.add(new SPIRVOpCapability(SPIRVCapability.Addresses())); 600 module.add(new SPIRVOpCapability(SPIRVCapability.Linkage())); 601 module.add(new SPIRVOpCapability(SPIRVCapability.Kernel())); 602 module.add(new SPIRVOpCapability(SPIRVCapability.Int8())); 603 module.add(new SPIRVOpCapability(SPIRVCapability.Int16())); 604 module.add(new SPIRVOpCapability(SPIRVCapability.Int64())); 605 module.add(new SPIRVOpCapability(SPIRVCapability.Vector16())); 606 module.add(new SPIRVOpCapability(SPIRVCapability.Float16())); 607 module.add(new SPIRVOpMemoryModel(SPIRVAddressingModel.Physical64(), SPIRVMemoryModel.OpenCL())); 608 609 // OpenCL extension provides built-in variables suitable for kernel programming 610 // Import extention and declare fourn variables 611 SPIRVId oclExtension = nextId("oclExtension"); 612 module.add(new SPIRVOpExtInstImport(oclExtension, new SPIRVLiteralString("OpenCL.std"))); 613 614 SPIRVId globalInvocationId = nextId("globalInvocationId"); 615 SPIRVId globalSize = nextId("globalSize"); 616 SPIRVId subgroupSize = nextId("subgroupSize"); 617 SPIRVId subgroupId = nextId("subgroupId"); 618 619 module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalInvocationId()))); 620 module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.Constant())); 621 module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalInvocationId"), SPIRVLinkageType.Import()))); 622 module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalSize()))); 623 module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.Constant())); 624 module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalSize"), SPIRVLinkageType.Import()))); 625 module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupSize()))); 626 module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.Constant())); 627 module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupSize"), SPIRVLinkageType.Import()))); 628 module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupId()))); 629 module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.Constant())); 630 module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupId"), SPIRVLinkageType.Import()))); 631 632 module.add(new SPIRVOpVariable(getType("ptrV3long"), globalInvocationId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>())); 633 module.add(new SPIRVOpVariable(getType("ptrV3long"), globalSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>())); 634 module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>())); 635 module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>())); 636 } 637 638 private SPIRVId spirvType(String javaType) { 639 SPIRVId ans = switch(javaType) { 640 case "byte" -> getType("byte"); 641 case "short" -> getType("short"); 642 case "int" -> getType("int"); 643 case "long" -> getType("long"); 644 case "float" -> getType("float"); 645 case "double" -> getType("double"); 646 case "int[]" -> getType("int[]"); 647 case "float[]" -> getType("float[]"); 648 case "double[]" -> getType("double[]"); 649 case "long[]" -> getType("long[]"); 650 case "bool" -> getType("bool"); 651 case "spirvdemo.IntArray" -> getType("int[]"); 652 case "spirvdemo.FloatArray" -> getType("float[]"); 653 case "jdk.incubator.vector.IntVector" -> spirvVectorType("IntVector", 8); 654 case "jdk.incubator.vector.FloatVector" -> spirvVectorType("FloatVector", 8); 655 case "jdk.incubator.vector.VectorSpecies<java.lang.Integer>" -> getType("int"); 656 case "jdk.incubator.vector.VectorSpecies<java.lang.Long>" -> getType("long"); 657 case "jdk.incubator.vector.VectorSpecies<java.lang.Float>" -> getType("int"); 658 case "VectorSpecies" -> getType("int"); 659 case "void" -> getType("void"); 660 case "spirvdemo.GPU$Index" -> getType("ptrGPUIndex"); 661 case "java.lang.foreign.MemorySegment" -> getType("ptrByte"); 662 default -> null; 663 }; 664 if (ans == null) unsupported("type", javaType); 665 return ans; 666 } 667 668 private SPIRVId spirvElementType(String javaType) { 669 SPIRVId ans = switch(javaType) { 670 case "byte[]" -> getType("byte"); 671 case "short[]" -> getType("short"); 672 case "int[]" -> getType("int"); 673 case "long[]" -> getType("long"); 674 case "float[]" -> getType("float"); 675 case "double[]" -> getType("double"); 676 case "boolean[]" -> getType("bool"); 677 case "spirvdemo.IntArray" -> getType("int"); 678 case "spirvdemo.FloatArray" -> getType("float"); 679 case "jdk.incubator.vector.LongVector" -> getType("long"); 680 case "jdk.incubator.vector.FloatVector" -> getType("float"); 681 case "IntVector" -> getType("int"); 682 case "LongVector" -> getType("long"); 683 case "FloatVector" -> getType("float"); 684 case "java.lang.foreign.MemorySegment" -> getType("byte"); 685 default -> null; 686 }; 687 if (ans == null) unsupported("type", javaType); 688 return ans; 689 } 690 691 private SPIRVId vectorElementType(SPIRVId type) { 692 SPIRVId ans = switch(type.getName()) { 693 case "v8int" -> getType("int"); 694 case "v16int" -> getType("int"); 695 case "v8long" -> getType("long"); 696 case "v8float" -> getType("float"); 697 case "v16float" -> getType("float"); 698 default -> null; 699 }; 700 if (ans == null) unsupported("type", type.getName()); 701 return ans; 702 } 703 704 private SPIRVId spirvVariableType(SPIRVId spirvType) { 705 SPIRVId ans = switch(spirvType.getName()) { 706 case "byte" -> getType("ptrByte"); 707 case "short" -> getType("ptrShort"); 708 case "int" -> getType("ptrInt"); 709 case "long" -> getType("ptrLong"); 710 case "float" -> getType("ptrFloat"); 711 case "double" -> getType("ptrDouble"); 712 case "boolean" -> getType("ptrBool"); 713 case "int[]" -> getType("ptrInt[]"); 714 case "long[]" -> getType("ptrLong[]"); 715 case "float[]" -> getType("ptrFloat[]"); 716 case "double[]" -> getType("ptrDouble[]"); 717 case "v8int" -> getType("ptrV8int"); 718 case "v16int" -> getType("ptrV16int"); 719 case "v8long" -> getType("ptrV8long"); 720 case "v8float" -> getType("ptrV8float"); 721 case "v16float" -> getType("ptrV16float"); 722 case "ptrGPUIndex" -> getType("ptrPtrGPUIndex"); 723 case "ptrByte" -> getType("ptrPtrByte"); 724 default -> null; 725 }; 726 if (ans == null) unsupported("type", spirvType.getName()); 727 return ans; 728 } 729 730 private SPIRVId spirvVectorType(String javaVectorType, int vectorLength) { 731 String prefix = "v" + vectorLength; 732 String elementType = spirvElementType(javaVectorType).getName(); 733 return getType(prefix + elementType); 734 } 735 736 private int alignment(String spirvType) { 737 int ans = switch(spirvType) { 738 case "byte" -> 1; 739 case "short" -> 2; 740 case "int" -> 4; 741 case "long" -> 8; 742 case "float" -> 4; 743 case "double" -> 8; 744 case "boolean" -> 1; 745 case "v8int" -> 32; 746 case "v16int" -> 64; 747 case "v8long" -> 64; 748 case "v8float" -> 32; 749 case "v16float" -> 64; 750 case "ptrGPUIndex" -> 32; 751 case "int[]" -> 8; 752 case "long[]" -> 8; 753 case "float[]" -> 8; 754 case "double[]" -> 8; 755 case "ptrByte" -> 8; 756 case "ptrInt" -> 8; 757 case "ptrInt[]" -> 8; 758 case "ptrLong" -> 8; 759 case "ptrLong[]" -> 8; 760 case "ptrFloat" -> 8; 761 case "ptrFloat[]" -> 8; 762 case "ptrV8int" -> 8; 763 case "ptrV8float" -> 8; 764 case "ptrPtrGPUIndex" -> 8; 765 default -> 0; 766 }; 767 if (ans == 0) unsupported("type", spirvType); 768 return ans; 769 } 770 771 private int laneCount(String vectorType) { 772 int ans = switch(vectorType) { 773 case "v8int" -> 8; 774 case "v8long" -> 8; 775 case "v8float" -> 8; 776 case "v16int" -> 16; 777 case "v16float" -> 16; 778 default -> 0; 779 }; 780 if (ans == 0) unsupported("type", vectorType); 781 return ans; 782 } 783 784 private SPIRVId vectorExponent(String vectorType) { 785 SPIRVId ans = null; 786 switch(vectorType) { 787 case "v8int" -> ans = getId("int_THREE"); 788 case "v8long" -> ans = getId("int_THREE"); 789 case "v8float" -> ans = getId("int_THREE"); 790 case "v16int" -> ans = getId("int_FOUR"); 791 case "v16float" -> ans = getId("int_FOUR"); 792 default -> unsupported("type", vectorType); 793 }; 794 return ans; 795 } 796 797 private Set<String> moduleTypes = new HashSet<>(); 798 799 private SPIRVId getType(String name) { 800 if (!moduleTypes.contains(name)) { 801 switch (name) { 802 case "void" -> module.add(new SPIRVOpTypeVoid(nextId(name))); 803 case "bool" -> module.add(new SPIRVOpTypeBool(nextId(name))); 804 case "byte" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(8), new SPIRVLiteralInteger(0))); 805 case "short" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(16), new SPIRVLiteralInteger(0))); 806 case "int" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(32), new SPIRVLiteralInteger(0))); 807 case "long" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(64), new SPIRVLiteralInteger(0))); 808 case "float" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(32))); 809 case "double" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(64))); 810 case "ptrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte"))); 811 case "ptrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int"))); 812 case "ptrLong" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long"))); 813 case "ptrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float"))); 814 case "short[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("short"))); 815 case "int[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("int"))); 816 case "long[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("long"))); 817 case "float[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("float"))); 818 case "double[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("double"))); 819 case "boolean[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("boolean"))); 820 case "ptrInt[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int[]"))); 821 case "ptrLong[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long[]"))); 822 case "ptrFloat[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float[]"))); 823 case "spirvdemo.GPUIndex" -> module.add(new SPIRVOpTypeStruct(nextId(name), new SPIRVMultipleOperands<>(getType("long"), getType("long"), getType("long")))); 824 case "ptrGPUIndex" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("spirvdemo.GPUIndex"))); 825 case "ptrCrossGroupByte"-> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte"))); 826 case "ptrPtrGPUIndex" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrGPUIndex"))); 827 case "ptrPtrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrByte"))); 828 case "v3long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(3))); 829 case "v8int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(8))); 830 case "v8long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(8))); 831 case "v16int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(16))); 832 case "v8float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(8))); 833 case "v16float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(16))); 834 case "ptrV3long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Input(), getType("v3long"))); 835 case "ptrV8long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8long"))); 836 case "ptrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8int"))); 837 case "ptrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16int"))); 838 case "ptrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8float"))); 839 case "ptrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16float"))); 840 case "ptrPtrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8int"))); 841 case "ptrPtrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16int"))); 842 case "ptrPtrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8float"))); 843 case "ptrPtrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16float"))); 844 default -> unsupported("type", name); 845 } 846 moduleTypes.add(name); 847 } 848 return getId(name); 849 } 850 851 private Set<String> moduleConstants = new HashSet<>(); 852 853 private SPIRVId getConst(String name) { 854 if (!moduleConstants.contains(name)) { 855 switch (name) { 856 case "int_ZERO" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_ZERO"), new SPIRVContextDependentInt(new BigInteger("0")))); 857 case "int_ONE" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_ONE"), new SPIRVContextDependentInt(new BigInteger("1")))); 858 case "int_TWO" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_TWO"), new SPIRVContextDependentInt(new BigInteger("2")))); 859 case "int_EIGHT" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_EIGHT"), new SPIRVContextDependentInt(new BigInteger("8")))); 860 default -> unsupported("constant", name); 861 } 862 moduleConstants.add(name); 863 } 864 return getId(name); 865 } 866 867 private SPIRVOptionalOperand<SPIRVMemoryAccess> align(int align) { 868 return new SPIRVOptionalOperand<>(SPIRVMemoryAccess.Aligned(new SPIRVLiteralInteger(align))); 869 } 870 871 private SPIRVOptionalOperand<SPIRVMemoryAccess> align(String type) { 872 return align(alignment(type)); 873 } 874 875 private SPIRVMultipleOperands<SPIRVId> spirvOperands(SPIRVId value, int count) { 876 SPIRVId[] values = new SPIRVId[count]; 877 Arrays.fill(values, value); 878 return new SPIRVMultipleOperands<>(values); 879 } 880 881 private SPIRVOptionalOperand<SPIRVMemoryAccess> none() { 882 return new SPIRVOptionalOperand<>(); 883 } 884 885 private SpirvResult globalSize(int index, SPIRVBlock spirvBlock) { 886 SPIRVId longType = getType("long"); 887 SPIRVId v3long = getId("v3long"); 888 SPIRVId globalSizeId = getId("globalSize"); 889 SPIRVId globalSizes = nextId(); 890 spirvBlock.add(new SPIRVOpLoad(v3long, globalSizes, globalSizeId, align(32))); 891 SPIRVId globalSize = nextId(); 892 spirvBlock.add(new SPIRVOpCompositeExtract(longType, globalSize, globalSizes, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index)))); 893 return new SpirvResult(longType, null, globalSize); 894 } 895 896 private SpirvResult globalId(int index, SPIRVBlock spirvBlock) { 897 SPIRVId longType = getType("long"); 898 SPIRVId v3long = getId("v3long"); 899 SPIRVId globalInvocationId = getId("globalInvocationId"); 900 SPIRVId globalIds = nextId(); 901 spirvBlock.add(new SPIRVOpLoad(v3long, globalIds, globalInvocationId, align(32))); 902 SPIRVId globalIndex = nextId(); 903 spirvBlock.add(new SPIRVOpCompositeExtract(longType, globalIndex, globalIds, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index)))); 904 return new SpirvResult(longType, null, globalIndex); 905 } 906 907 private SPIRVId nextId() { 908 return module.getNextId(); 909 } 910 911 private SPIRVId nextId(String name) { 912 SPIRVId ans = nextId(); 913 ans.setName(name); 914 symbols.putId(name, ans); 915 module.add(new SPIRVOpName(ans, new SPIRVLiteralString(name))); 916 return ans; 917 } 918 919 private static int counter = 0; 920 921 private String nextTempTag() { 922 counter++; 923 return "temp_" + counter + "_"; 924 } 925 926 private boolean isIntegerType(SPIRVId type) { 927 String name = type.getName(); 928 return name.equals("short") || name.equals("int") || name.equals("long"); 929 } 930 931 private boolean isFloatType(SPIRVId type) { 932 String name = type.getName(); 933 return name.equals("float") || name.equals("double"); 934 } 935 936 private boolean isVectorSpecies(String javaType) { 937 return javaType.equals("VectorSpecies"); 938 } 939 940 private boolean isVectorType(String javaType) { 941 return javaType.equals("IntVector") || javaType.equals("FloatVector"); 942 } 943 944 private void addId(String name, SPIRVId id) { 945 symbols.putId(name, id); 946 } 947 948 private SPIRVId getId(String name) { 949 SPIRVId ans = symbols.getId(name); 950 assert ans != null : name + " not found"; 951 return ans; 952 } 953 954 private SPIRVId getIdOrNull(String name) { 955 return symbols.getId(name); 956 } 957 958 private static Object map(Function<Object, Boolean> test, Object... args) { 959 int len = args.length; 960 assert len >= 2 && len % 2 == 0; 961 int pairs = len / 2; 962 for (int i = 0; i < pairs; i++) { 963 if (test.apply(args[i])) return args[i + pairs]; 964 } 965 throw new RuntimeException("No match: " + args[0]); 966 } 967 968 private void unsupported(String message, Object value) { 969 throw new RuntimeException("Unsupported " + message + ": " + value); 970 } 971 972 private void addResult(Value value, SpirvResult result) { 973 assert symbols.getResult(value) == null : "result already present"; 974 symbols.putResult(value, result); 975 } 976 977 private SpirvResult getResult(Value value) { 978 return symbols.getResult(value); 979 } 980 981 private static class Symbols { 982 private final HashMap<Value, SpirvResult> results; 983 private final HashMap<String, SPIRVId> ids; 984 private final HashMap<Block, SPIRVBlock> blocks; 985 private final HashMap<Block, SPIRVOpLabel> labels; 986 987 public Symbols() { 988 this.results = new HashMap<>(); 989 this.ids = new HashMap<>(); 990 this.blocks = new HashMap<>(); 991 this.labels = new HashMap<>(); 992 } 993 994 public void putId(String name, SPIRVId id) { 995 ids.put(name, id); 996 } 997 998 public SPIRVId getId(String name) { 999 return ids.get(name); 1000 } 1001 1002 public void putBlock(Block block, SPIRVBlock spirvBlock) { 1003 blocks.put(block, spirvBlock); 1004 } 1005 1006 public SPIRVBlock getBlock(Block block) { 1007 return blocks.get(block); 1008 } 1009 1010 public void putLabel(Block block, SPIRVOpLabel spirvLabel) { 1011 labels.put(block, spirvLabel); 1012 } 1013 1014 public SPIRVOpLabel getLabel(Block block) { 1015 return labels.get(block); 1016 } 1017 1018 public void putResult(Value value, SpirvResult result) { 1019 results.put(value, result); 1020 } 1021 1022 public SpirvResult getResult(Value value) { 1023 return results.get(value); 1024 } 1025 1026 public String toString() { 1027 return String.format("results %s\n\nids %s\n\nblocks %s\nlabels %s\n", results.keySet(), ids.keySet(), blocks.keySet(), labels.keySet()); 1028 } 1029 } 1030 }