1 /*
   2  * Copyright (c) 2024 Intel Corporation. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package intel.code.spirv;
  27 
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.UncheckedIOException;
import java.lang.constant.ClassDesc;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;

import jdk.incubator.code.Block;
import jdk.incubator.code.Body;
import jdk.incubator.code.Op;
import jdk.incubator.code.TypeElement;
import jdk.incubator.code.Value;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.ClassType;
import jdk.incubator.code.dialect.java.JavaType;
import jdk.incubator.code.dialect.java.MethodRef;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVBlock;
import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVFunction;
import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVHeader;
import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVModule;
import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler;
import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions;
import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVByteStreamReader;
import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.*;
import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.*;
  71 
  72 public class SpirvModuleGenerator {
  73     public static MemorySegment generateModule(String moduleName, CoreOp.FuncOp func) {
  74         SpirvOps.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(func);
  75         MemorySegment module = SpirvModuleGenerator.generateModule(moduleName, spirvFunc);
  76         return module;
  77     }
  78 
  79     public static MemorySegment generateModule(String moduleName, SpirvOps.FuncOp func) {
  80         return new SpirvModuleGenerator().generateModuleInternal(moduleName, func);
  81     }
  82 
  83     public static void writeModuleToFile(MemorySegment module, String filepath)  {
  84         ByteBuffer buffer = module.asByteBuffer();
  85         File out = new File(filepath);
  86         try (FileChannel channel = new FileOutputStream(out, false).getChannel()) {
  87             channel.write(buffer);
  88         }
  89         catch (IOException e)  {
  90             throw new RuntimeException(e);
  91         }
  92     }
  93 
  94     public static String disassembleModule(MemorySegment module) {
  95         SPVByteStreamReader input = new SPVByteStreamReader(new ByteArrayInputStream(module.toArray(ValueLayout.JAVA_BYTE)));
  96         ByteArrayOutputStream out = new ByteArrayOutputStream();
  97         try (PrintStream ps = new PrintStream(out))  {
  98             SPIRVDisassemblerOptions options = new SPIRVDisassemblerOptions(false, false, false, false, true);
  99             Disassembler dis = new Disassembler(input, ps, options);
 100             dis.run();
 101         }
 102         catch (Exception e) {
 103             throw new RuntimeException(e);
 104         }
 105         return new String(out.toByteArray());
 106     }
 107 
 108     private record SpirvResult(SPIRVId type, SPIRVId address, SPIRVId value) {}
 109 
 110     private final SPIRVModule module;
 111     private final Symbols symbols;
 112 
 113     private SpirvModuleGenerator() {
 114         this.module = new SPIRVModule(new SPIRVHeader(1, 2, 32, 0, 0));
 115         this.symbols = new Symbols();
 116     }
 117 
 118     private MemorySegment generateModuleInternal(String moduleName, SpirvOps.FuncOp func) {
 119         initModule();
 120         generateFunction(moduleName, moduleName, func);
 121         ByteBuffer buffer = ByteBuffer.allocateDirect(module.getByteCount());
 122         buffer.order(ByteOrder.LITTLE_ENDIAN);
 123         module.close().write(buffer);
 124         buffer.flip();
 125         return MemorySegment.ofBuffer(buffer);
 126     }
 127 
 128     private void generateFunction(String moduleName, String fnName, SpirvOps.FuncOp func) {
 129         TypeElement returnType = func.invokableType().returnType();
 130         SPIRVId functionID = nextId(fnName);
 131         String signature = func.invokableType().returnType().toString();
 132         List<TypeElement> paramTypes = func.invokableType().parameterTypes();
 133         // build signature string
 134         for (int i = 0; i < paramTypes.size(); i++) {
 135             signature += "_" + paramTypes.get(i).toString();
 136         }
 137         // declare function type if not already present
 138         SPIRVId functionSig = getIdOrNull(signature);
 139         if (functionSig == null) {
 140             SPIRVId[] typeIdsArray = new SPIRVId[paramTypes.size()];
 141             for (int i = 0; i < paramTypes.size(); i++) {
 142                 typeIdsArray[i] = spirvType(paramTypes.get(i).toString());
 143             }
 144             functionSig = nextId(fnName + "Signature");
 145             module.add(new SPIRVOpTypeFunction(functionSig, spirvType(returnType.toString()), new SPIRVMultipleOperands<>(typeIdsArray)));
 146             addId(signature, functionSig);
 147         }
 148         // declare function as modeule entry point
 149         SPIRVId spirvReturnType = spirvType(returnType.toString());
 150         SPIRVFunction function = (SPIRVFunction)module.add(new SPIRVOpFunction(spirvReturnType, functionID, SPIRVFunctionControl.DontInline(), functionSig));
 151         SPIRVOpLabel entryPoint = new SPIRVOpLabel(nextId());
 152         SPIRVBlock entryBlock = (SPIRVBlock)function.add(entryPoint);
 153         SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(getId("globalInvocationId"), getId("globalSize"), getId("subgroupSize"), getId("subgroupId"));
 154         module.add(new SPIRVOpEntryPoint(SPIRVExecutionModel.Kernel(), functionID, new SPIRVLiteralString(fnName), operands));
 155 
 156         translateBody(func.body(), function, entryBlock);
 157         function.add(new SPIRVOpFunctionEnd());
 158     }
 159 
 160     private void translateBody(Body body, SPIRVFunction function, SPIRVBlock entryBlock) {
 161         int labelNumber = 0;
 162         SPIRVBlock spirvBlock = entryBlock;
 163         for (int bi = 1; bi < body.blocks().size(); bi++)  {
 164             Block block = body.blocks().get(bi);
 165             String blockName = String.valueOf(block.hashCode());
 166             SPIRVOpLabel blockLabel = new SPIRVOpLabel(nextId());
 167             SPIRVBlock newBlock = (SPIRVBlock)function.add(blockLabel);
 168             symbols.putBlock(block, newBlock);
 169             symbols.putLabel(block, blockLabel);
 170         }
 171         for (Value param : body.entryBlock().parameters()) {
 172             SPIRVId paramId = nextId();
 173             addResult(param, new SpirvResult(spirvType(param.type().toString()), null, paramId));
 174         }
 175         for (int bi = 0; bi < body.blocks().size(); bi++)  {
 176             Block block = body.blocks().get(bi);
 177             if (bi > 0) {
 178                 spirvBlock = symbols.getBlock(block);
 179             }
 180             List<Op> ops = block.ops();
 181             for (Op op : block.ops()) {
 182                 // debug("---------- spirv op = %s", op.toText());
 183                 switch (op)  {
 184                     case SpirvOps.VariableOp vop -> {
 185                         String typeName = vop.varType().toString();
 186                         SPIRVId type = spirvType(typeName);
 187                         SPIRVId varType = spirvVariableType(type);
 188                         SPIRVId var = nextId(vop.varName());
 189                         spirvBlock.add(new SPIRVOpVariable(varType, var, SPIRVStorageClass.Function(), new SPIRVOptionalOperand<>()));
 190                         addResult(vop.result(), new SpirvResult(varType, var, null));
 191                     }
 192                     case SpirvOps.FunctionParameterOp fpo -> {
 193                         SPIRVId result = nextId();
 194                         SPIRVId type = spirvType(fpo.resultType().toString());
 195                         function.add(new SPIRVOpFunctionParameter(type, result));
 196                         addResult(fpo.result(), new SpirvResult(type, null, result));
 197                     }
 198                     case SpirvOps.LoadOp lo -> {
 199                         if (((JavaType)lo.resultType()).equals(JavaType.type(VectorSpecies.class))) {
 200                             addResult(lo.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT")));
 201                         }
 202                         else {
 203                             SPIRVId type = spirvType(lo.resultType().toString());
 204                             SpirvResult toLoad = getResult(lo.operands().get(0));
 205                             SPIRVId varAddr = toLoad.address();
 206                             SPIRVId result = nextId();
 207                             spirvBlock.add(new SPIRVOpLoad(type, result, varAddr, align(type.getName())));
 208                             addResult(lo.result(), new SpirvResult(type, varAddr, result));
 209                         }
 210                     }
 211                     case SpirvOps.StoreOp so -> {
 212                         SpirvResult var = getResult(so.operands().get(0));
 213                         SPIRVId varAddr = var.address();
 214                         SPIRVId value = getResult(so.operands().get(1)).value();
 215                         spirvBlock.add(new SPIRVOpStore(varAddr, value, align(var.type().getName())));
 216                     }
 217                     case SpirvOps.IAddOp _, SpirvOps.FAddOp _ -> {
 218                         SPIRVId intType = getType("int");
 219                         SPIRVId longType = getType("long");
 220                         SPIRVId floatType = getType("float");
 221                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 222                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 223                         SPIRVId lhsType = spirvType(op.resultType().toString());
 224                         SPIRVId ans = nextId();
 225                         if (lhsType == intType) spirvBlock.add(new SPIRVOpIAdd(intType, ans, lhs, rhs));
 226                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpIAdd(longType, ans, lhs, rhs));
 227                         else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFAdd(floatType, ans, lhs, rhs));
 228                         else unsupported("type", lhsType.getName());
 229                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 230                     }
 231                     case SpirvOps.IMulOp _, SpirvOps.FMulOp _, SpirvOps.IDivOp _, SpirvOps.FDivOp _ -> {
 232                         SPIRVId intType = getType("int");
 233                         SPIRVId longType = getType("long");
 234                         SPIRVId floatType = getType("float");
 235                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 236                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 237                         SPIRVId lhsType = spirvType(op.resultType().toString());
 238                         SPIRVId rhsType = getResult(op.operands().get(1)).type();
 239                         SPIRVId ans = nextId();
 240                         if (lhsType == intType) {
 241                             if (op instanceof SpirvOps.IMulOp) spirvBlock.add(new SPIRVOpIMul(intType, ans, lhs, rhs));
 242                             else if (op instanceof SpirvOps.IDivOp) spirvBlock.add(new SPIRVOpSDiv(intType, ans, lhs, rhs));
 243                         }
 244                         else if (lhsType == longType) {
 245                             SPIRVId rhsId = rhsType == intType ? nextId() : rhs;
 246                             if (rhsType == intType) spirvBlock.add(new SPIRVOpSConvert(longType, rhsId, rhs));
 247                             if (op instanceof SpirvOps.IMulOp) spirvBlock.add(new SPIRVOpIMul(longType, ans, lhs, rhsId));
 248                             else if (op instanceof SpirvOps.IDivOp) spirvBlock.add(new SPIRVOpSDiv(longType, ans, lhs, rhs));
 249                         }
 250                         else if (lhsType == floatType) {
 251                             if (op instanceof SpirvOps.FMulOp) spirvBlock.add(new SPIRVOpFMul(floatType, ans, lhs, rhs));
 252                             else if (op instanceof SpirvOps.FDivOp) spirvBlock.add(new SPIRVOpFDiv(floatType, ans, lhs, rhs));
 253                         }
 254                         else unsupported("type", lhsType);
 255                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 256                     }
 257                     case SpirvOps.ModOp mop -> {
 258                         SPIRVId type = getType(mop.operands().get(0).type().toString());
 259                         SPIRVId lhs = getResult(mop.operands().get(0)).value();
 260                         SPIRVId rhs = getResult(mop.operands().get(1)).value();
 261                         SPIRVId result = nextId();
 262                         spirvBlock.add(new SPIRVOpUMod(type, result, lhs, rhs));
 263                         addResult(mop.result(), new SpirvResult(type, null, result));
 264                     }
 265                     case SpirvOps.IEqualOp eqop -> {
 266                         SPIRVId boolType = getType("bool");
 267                         SPIRVId intType = getType("int");
 268                         SPIRVId longType = getType("long");
 269                         SPIRVId floatType = getType("float");
 270                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 271                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 272                         SPIRVId lhsType = spirvType(op.resultType().toString());
 273                         SPIRVId ans = nextId();
 274                         if (lhsType == intType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
 275                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
 276                         else unsupported("type", lhsType.getName());
 277                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 278                     }
 279                     case SpirvOps.CallOp call -> {
 280                         if (call.callDescriptor().equals(MethodRef.method(JavaType.type(ClassDesc.of("spirvdemo.IntArray")), "get", JavaType.INT, JavaType.LONG)) ||
 281                             call.callDescriptor().equals(MethodRef.method(JavaType.type(ClassDesc.of("spirvdemo.FloatArray")), "get", JavaType.FLOAT, JavaType.LONG))) {
 282                             SPIRVId longType = getType("long");
 283                             String arrayTypeName = call.operands().get(0).type().toString();
 284                             SpirvResult arrayResult = getResult(call.operands().get(0));
 285                             SPIRVId arrayAddr = arrayResult.address();
 286                             SPIRVId arrayType = spirvType(arrayTypeName);
 287                             SPIRVId elementType = spirvElementType(arrayTypeName);
 288                             int nIndexes = call.operands().size() - 1;
 289                             SPIRVId index = getResult(call.operands().get(1)).value();
 290                             SPIRVId array = nextId();
 291                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 292                             SPIRVId resultAddr = nextId();
 293                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, index, new SPIRVMultipleOperands<>()));
 294                             SPIRVId result = nextId();
 295                             spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName())));
 296                             addResult(call.result(), new SpirvResult(elementType, resultAddr, result));
 297                         }
 298                         else if (call.callDescriptor().equals(MethodRef.method(JavaType.type(ClassDesc.of("spirvdemo.IntArray")), "set", JavaType.VOID, JavaType.LONG, JavaType.INT)) ||
 299                                 call.callDescriptor().equals(MethodRef.method(JavaType.type(ClassDesc.of("spirvdemo.FloatArray")), "set", JavaType.VOID, JavaType.LONG, JavaType.FLOAT))) {
 300                             SPIRVId longType = getType("long");
 301                             String arrayTypeName = call.operands().get(0).type().toString();
 302                             SpirvResult arrayResult = getResult(call.operands().get(0));
 303                             SPIRVId arrayAddr = arrayResult.address();
 304                             SPIRVId arrayType = spirvType(arrayTypeName);
 305                             SPIRVId elementType = spirvElementType(arrayTypeName);
 306                             int nIndexes = call.operands().size() - 2;
 307                             int valueIndex = nIndexes + 1;
 308                             SPIRVId index = getResult(call.operands().get(1)).value();
 309                             SPIRVId array = nextId();
 310                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 311                             SPIRVId dest = nextId();
 312                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, array, index, new SPIRVMultipleOperands<>()));
 313                             SPIRVId value = getResult(call.operands().get(valueIndex)).value();
 314                             spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName())));
 315                         }
 316                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "fromArray", IntVector.class, VectorSpecies.class, int[].class, int.class))
 317                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fromArray", FloatVector.class, VectorSpecies.class, float[].class, int.class))) {
 318                             SPIRVId oclExtension = getId("oclExtension");
 319                             SpirvResult speciesResult = getResult(call.operands().get(0));
 320                             SpirvResult arrayResult = getResult(call.operands().get(1));
 321                             String arrayType = arrayResult.type().getName();
 322                             int laneCount = 8;  //TODO: remove hard code, instruction below needs a literal
 323                             String vTypeName = ((ClassType)call.callDescriptor().refType()).toClassName();
 324                             SPIRVId vType = spirvVectorType(vTypeName, laneCount);
 325                             SPIRVId array = arrayResult.value();
 326                             SPIRVId index = getResult(call.operands().get(2)).value();
 327                             SPIRVId vectorIndex = nextId();
 328                             spirvBlock.add(new SPIRVOpSDiv(getType("int"), vectorIndex, index, speciesResult.value()));
 329                             SPIRVId longIndex = nextId();
 330                             spirvBlock.add(new SPIRVOpSConvert(getType("long"), longIndex, vectorIndex));
 331                             SPIRVId vector = nextId();
 332                             SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(longIndex, array, new SPIRVId(laneCount)); // TODO: lanes must be a literal
 333                             spirvBlock.add(new SPIRVOpExtInst(vType, vector, oclExtension, new SPIRVLiteralExtInstInteger(171, "vloadn"), operands));
 334                             addResult(call.result(), new SpirvResult(vType, null, vector));
 335                         }
 336                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "fromMemorySegment", IntVector.class, VectorSpecies.class, MemorySegment.class, long.class, ByteOrder.class))
 337                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fromMemorySegment", FloatVector.class, VectorSpecies.class, MemorySegment.class, long.class, ByteOrder.class))) {
 338                             SPIRVId oclExtension = getId("oclExtension");
 339                             SPIRVId species = getResult(call.operands().get(0)).value();
 340                             SPIRVId lanesLong = nextId();
 341                             spirvBlock.add(new SPIRVOpSConvert(getType("long"), lanesLong, species));
 342                             int laneCount = 8; //TODO: remove hard code, vloadn instruction below needs a literal lane count, get value from env
 343                             SPIRVId segment = getResult(call.operands().get(1)).value();
 344                             String vTypeName = ((ClassType)call.callDescriptor().refType()).toClassName();
 345                             SPIRVId vType = spirvVectorType(vTypeName, laneCount);
 346                             SPIRVId temp = nextId();
 347                             spirvBlock.add(new SPIRVOpConvertPtrToU(getType("long"), temp, segment));
 348                             SPIRVId typedSegment = nextId();
 349                             SPIRVId pointerType = (SPIRVId)map(x -> x.equals(vTypeName), "jdk.incubator.vector.IntVector", "jdk.incubator.vector.FloatVector", getType("ptrInt"), getType("ptrFloat"));
 350                             spirvBlock.add(new SPIRVOpConvertUToPtr(pointerType, typedSegment, temp));
 351                             SPIRVId offset = getResult(call.operands().get(2)).value();
 352                             SPIRVId vectorIndex = nextId();
 353                             spirvBlock.add(new SPIRVOpSDiv(getType("long"), vectorIndex, offset, lanesLong)); // divide by lane count
 354                             SPIRVId finalIndex = nextId();
 355                             SPIRVId vector = nextId();
 356                             SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(vectorIndex, typedSegment, new SPIRVId(laneCount)); // TODO: lanes must be a literal
 357                             spirvBlock.add(new SPIRVOpExtInst(vType, vector, oclExtension, new SPIRVLiteralExtInstInteger(171, "vloadn"), operands));
 358                             addResult(call.result(), new SpirvResult(vType, null, vector));
 359                         }
 360                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "intoArray", void.class, int[].class, int.class))
 361                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "intoArray", void.class, float[].class, int.class))) {
 362                             SPIRVId oclExtension = getId("oclExtension");
 363                             SpirvResult vectorResult = getResult(call.operands().get(0));
 364                             SPIRVId vector = vectorResult.value();
 365                             SPIRVId vectorType = vectorResult.type();
 366                             SpirvResult arrayResult = getResult(call.operands().get(1));
 367                             SPIRVId array = arrayResult.value();
 368                             SPIRVId index = getResult(call.operands().get(2)).value();
 369                             SPIRVId vectorIndex = nextId();
 370                             spirvBlock.add(new SPIRVOpShiftRightArithmetic(getType("int"), vectorIndex, index, vectorExponent(vectorType.getName())));
 371                             SPIRVId longIndex = nextId();
 372                             spirvBlock.add(new SPIRVOpSConvert(getType("long"), longIndex, vectorIndex));
 373                             SPIRVMultipleOperands<SPIRVId> operandsR = new SPIRVMultipleOperands<>(vector, longIndex, array);
 374                             spirvBlock.add(new SPIRVOpExtInst(getType("void"), nextId(), oclExtension, new SPIRVLiteralExtInstInteger(172, "vstoren"), operandsR));
 375                         }
 376                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "intoMemorySegment", void.class, MemorySegment.class, long.class, ByteOrder.class))
 377                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "intoMemorySegment", void.class, MemorySegment.class, long.class, ByteOrder.class))) {
 378                             SPIRVId oclExtension = getId("oclExtension");
 379                             SpirvResult vectorResult = getResult(call.operands().get(0));
 380                             SPIRVId vector = vectorResult.value();
 381                             SPIRVId vectorType = vectorResult.type();
 382                             SpirvResult segmentResult = getResult(call.operands().get(1));;
 383                             SPIRVId segment = segmentResult.value();
 384                             SPIRVId temp = nextId();
 385                             spirvBlock.add(new SPIRVOpConvertPtrToU(getType("long"), temp, segment));
 386                             SPIRVId typedSegment = nextId();
 387                             String vectorElementType = vectorElementType(vectorType).getName();
 388                             SPIRVId pointerType = (SPIRVId)map(x -> x.equals(vectorElementType), "int", "float", getType("ptrInt"), getType("ptrFloat"));
 389                             spirvBlock.add(new SPIRVOpConvertUToPtr(pointerType, typedSegment, temp));
 390                             SPIRVId offset = getResult(call.operands().get(2)).value();
 391                             SPIRVId vectorIndex = nextId();
 392                             int laneCount = laneCount(vectorType.getName());
 393                             spirvBlock.add(new SPIRVOpShiftRightArithmetic(getType("long"), vectorIndex, offset, vectorExponent(vectorType.getName())));
 394                             SPIRVMultipleOperands<SPIRVId> operandsR = new SPIRVMultipleOperands<>(vector, vectorIndex, typedSegment);
 395                             spirvBlock.add(new SPIRVOpExtInst(getId("void"), nextId(), oclExtension, new SPIRVLiteralExtInstInteger(172, "vstoren"), operandsR));
 396                         }
 397                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "reduceLanes", int.class, VectorOperators.Associative.class))
 398                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "reduceLanes", float.class, VectorOperators.Associative.class))) {
 399                             SpirvResult vectorResult = getResult(call.operands().get(0));
 400                             SPIRVId vectorType = vectorResult.type();
 401                             SPIRVId vector = vectorResult.value();
 402                             String vTypeName = vectorType.getName();
 403                             SPIRVId elementType = vectorElementType(vectorType);
 404                             Op reduceOp = ((Op.Result)call.operands().get(1)).op();
 405                             if (reduceOp instanceof SpirvOps.FieldLoadOp flo) {
 406                                 assert flo.fieldDescriptor().refType().equals(JavaType.type(VectorOperators.class));
 407                                 assert flo.fieldDescriptor().name().equals("ADD");
 408                                 String operation = flo.fieldDescriptor().name();
 409                             }
 410                             else unsupported("operation expression", reduceOp.toText());
 411                             String tempTag = nextTempTag();
 412                             SPIRVId temp_0 = nextId(tempTag + 0);
 413                             spirvBlock.add(new SPIRVOpCompositeExtract(elementType, temp_0, vector, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(0))));
 414                             for (int lane = 1; lane < laneCount(vectorType.getName()); lane++) {
 415                                 SPIRVId temp = nextId(tempTag + lane);
 416                                 SPIRVId element = nextId();
 417                                 spirvBlock.add(new SPIRVOpCompositeExtract(elementType, element, vector, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(lane))));
 418                                 if (elementType == getType("int")) {
 419                                     spirvBlock.add(new SPIRVOpIAdd(elementType, temp, getId(tempTag + (lane - 1)), element));
 420                                 }
 421                                 else if (elementType == getType("float")) {
 422                                     spirvBlock.add(new SPIRVOpFAdd(elementType, temp, getId(tempTag + (lane - 1)), element));
 423                                 }
 424                                 else unsupported("type", elementType.getName());
 425                             }
 426                             addResult(call.result(), new SpirvResult(elementType, null, getId(tempTag + (laneCount(vectorType.getName()) - 1))));
 427                         }
 428                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "add", IntVector.class, Vector.class))
 429                               || call.callDescriptor().equals(MethodRef.method(IntVector.class, "mul", IntVector.class, Vector.class))
 430                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "add", FloatVector.class, Vector.class))
 431                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "mul", FloatVector.class, Vector.class))) {
 432                             SPIRVId oclExtension = getId("oclExtension");
 433                             SpirvResult lhsResult = getResult(call.operands().get(0));
 434                             SPIRVId lhsType = lhsResult.type();
 435                             SPIRVId lhs = lhsResult.value();
 436                             SPIRVId rhs = getResult(call.operands().get(1)).value();
 437                             SPIRVId add = nextId();
 438                             if (call.callDescriptor().name().equals("add")) {
 439                                 spirvBlock.add(lhsType.getName().endsWith("int") ? new SPIRVOpIAdd(lhsType, add, lhs, rhs) : new SPIRVOpFAdd(lhsType, add, lhs, rhs));
 440                             }
 441                             else if (call.callDescriptor().name().equals("mul")) {
 442                                 spirvBlock.add(lhsType.getName().endsWith("int") ? new SPIRVOpIMul(lhsType, add, lhs, rhs) : new SPIRVOpFMul(lhsType, add, lhs, rhs));
 443                             }
 444                             addResult(call.result(), new SpirvResult(lhsType, null, add));
 445                         }
 446                         else if (call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fma", FloatVector.class, Vector.class, Vector.class))) {
 447                             SPIRVId oclExtension = getId("oclExtension");
 448                             SpirvResult aResult = getResult(call.operands().get(0));
 449                             SPIRVId vType = aResult.type();
 450                             SPIRVId a = aResult.value();
 451                             SPIRVId b = getResult(call.operands().get(1)).value();
 452                             SPIRVId c = getResult(call.operands().get(2)).value();
 453                             String vTypeStr = vType.getName();
 454                             assert vTypeStr.endsWith("float");
 455                             SPIRVId result  = nextId();
 456                             SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(a, b, c);
 457                             spirvBlock.add(new SPIRVOpExtInst(vType, result, oclExtension, new SPIRVLiteralExtInstInteger(26, "fma"), operands));
 458                             addResult(call.result(), new SpirvResult(vType, null, result));
 459                         }
 460                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "zero", IntVector.class, VectorSpecies.class))
 461                              || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "zero", FloatVector.class, VectorSpecies.class))) {
 462                             SpirvResult speciesResult = getResult(call.operands().get(0));
 463                             SPIRVId vType = spirvType(((ClassType)call.callDescriptor().refType()).toClassName());
 464                             String elementType = vectorElementType(vType).getName();
 465                             SPIRVId value = getId(elementType + "_ZERO");
 466                             int laneCount = laneCount(vType.getName());
 467                             assert laneCount == 8 || laneCount == 16;
 468                             SPIRVId vector = nextId();
 469                             SPIRVMultipleOperands<SPIRVId> operands = spirvOperands(value, laneCount);
 470                             spirvBlock.add(new SPIRVOpCompositeConstruct(vType, vector, operands));
 471                             addResult(call.result(), new SpirvResult(vType, null, vector));
 472                         }
 473                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "lane", int.class, int.class))
 474                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "lane", float.class, int.class)))  {
 475                             SpirvResult lhsResult = getResult(call.operands().get(0));
 476                             SPIRVId lhsType = lhsResult.type();
 477                             SPIRVId lhs = lhsResult.value();
 478                             String vTypeStr = lhsType.getName();
 479                             SPIRVId vType = lhsResult.type();
 480                             SPIRVId elementType = vectorElementType(vType);
 481                             SPIRVId result = nextId();
 482                             Op laneOp = ((Op.Result)call.operands().get(1)).op();
 483                             assert laneOp instanceof SpirvOps.ConstantOp;
 484                             int lane = (int)((SpirvOps.ConstantOp)laneOp).value();
 485                             spirvBlock.add(new SPIRVOpCompositeExtract(elementType, result, lhsResult.value(), new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(lane))));
 486                             addResult(call.result(), new SpirvResult(elementType, null, result));
 487                         }
 488                         else if (call.callDescriptor().equals(MethodRef.method(VectorSpecies.class, "length", int.class))) {
 489                             addResult(call.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT"))); // TODO: remove hardcode
 490                         }
 491                         else unsupported("method", call.callDescriptor());
 492                     }
 493                     case SpirvOps.ConstantOp cop -> {
 494                         SPIRVId type = spirvType(cop.resultType().toString());
 495                         SPIRVId result = nextId();
 496                         Object value = cop.value();
 497                         if (type == getType("int")) {
 498                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentInt(new BigInteger(String.valueOf(value)))));
 499                         }
 500                         else if (type == getType("long")) {
 501                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentLong(new BigInteger(String.valueOf(value)))));
 502                         }
 503                         else if (type == getType("float")) {
 504                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentFloat((float)value)));
 505                         }
 506                         else unsupported("type", cop.resultType());
 507                         addResult(cop.result(), new SpirvResult(type, null, result));
 508                     }
 509                     case SpirvOps.ConvertOp scop -> {
 510                         SPIRVId toType = spirvType(scop.resultType().toString());
 511                         SPIRVId to = nextId();
 512                         SpirvResult valueResult = getResult(scop.operands().get(0));
 513                         SPIRVId from = valueResult.value();
 514                         SPIRVId fromType = valueResult.type();
 515                         if (isIntegerType(fromType)) {
 516                             if (isIntegerType(toType)) {
 517                                 spirvBlock.add(new SPIRVOpSConvert(toType, to, from));
 518                             }
 519                             else if (isFloatType(toType)) {
 520                                 spirvBlock.add(new SPIRVOpConvertSToF(toType, to, from));
 521                             }
 522                             else unsupported("conversion type", scop.resultType());
 523                         }
 524                         else unsupported("conversion type", scop.operands().get(0));
 525                         addResult(scop.result(), new SpirvResult(toType, null, to));
 526                     }
 527                     case SpirvOps.InBoundAccessChainOp iacop -> {
 528                         SPIRVId type = spirvType(iacop.resultType().toString());
 529                         SPIRVId result = nextId();
 530                         SPIRVId object = getResult(iacop.operands().get(0)).value();
 531                         SPIRVId index = getResult(iacop.operands().get(1)).value();
 532                         spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(type, result, object, index, new SPIRVMultipleOperands<>()));
 533                         addResult(iacop.result(), new SpirvResult(type, result, null));
 534                     }
 535                     case SpirvOps.FieldLoadOp flo -> {
 536                         if (flo.operands().size() > 0 && flo.operands().get(0).type().equals(JavaType.type(ClassDesc.of("spirvdemo.GPU$Index")))) {
 537                             SpirvResult result;
 538                             int group = -1;
 539                             int index = -1;
 540                             String fieldName = flo.fieldDescriptor().name();
 541                             switch(fieldName) {
 542                                 case "x": group = 0; index = 0; break;
 543                                 case "y": group = 0; index = 1; break;
 544                                 case "z": group = 0; index = 2; break;
 545                                 case "w": group = 1; index = 0; break;
 546                                 case "h": group = 1; index = 1; break;
 547                                 case "d": group = 1; index = 2; break;
 548                             }
 549                             switch (group) {
 550                                 case 0: result = globalId(index, spirvBlock); break;
 551                                 case 1: result = globalSize(index, spirvBlock); break;
 552                                 default: throw new RuntimeException("Unknown Index field: " + fieldName);
 553                             }
 554                             addResult(flo.result(), result);
 555                         }
 556                         else if (((JavaType)flo.resultType()).equals(JavaType.type(VectorSpecies.class))) {
 557                             addResult(flo.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT")));
 558                         }
 559                         else if (flo.fieldDescriptor().refType().equals(JavaType.type(VectorOperators.class))) {
 560                             // currently ignored
 561                         }
 562                         else if (flo.fieldDescriptor().refType().equals(JavaType.type(ByteOrder.class))) {
 563                             // currently ignored
 564                         }
 565                         else unsupported("field load", ((ClassType)flo.fieldDescriptor().refType()).toClassName() + "." + flo.fieldDescriptor().name());
 566                     }
 567                     case SpirvOps.BranchOp bop -> {
 568                         SPIRVId trueLabel = symbols.getLabel(bop.branch()).getResultId();
 569                         spirvBlock.add(new SPIRVOpBranch(trueLabel));
 570                     }
 571                     case SpirvOps.ConditionalBranchOp cbop -> {
 572                         SPIRVId test = getResult(cbop.operands().get(0)).value();
 573                         SPIRVId trueLabel = symbols.getLabel(cbop.trueBranch()).getResultId();
 574                         SPIRVId falseLabel = symbols.getLabel(cbop.falseBranch()).getResultId();
 575                         spirvBlock.add(new SPIRVOpBranchConditional(test, trueLabel, falseLabel, new SPIRVMultipleOperands<SPIRVLiteralInteger>()));
 576                     }
 577                     case SpirvOps.LtOp ltop -> {
 578                         SPIRVId lhs = getResult(ltop.operands().get(0)).value();
 579                         SPIRVId rhs = getResult(ltop.operands().get(1)).value();
 580                         SPIRVId boolType = getType("bool");
 581                         SPIRVId result = nextId();
 582                         spirvBlock.add(new SPIRVOpSLessThan(boolType, result, lhs, rhs));
 583                         addResult(ltop.result(), new SpirvResult(boolType, null, result));
 584                     }
 585                     case SpirvOps.ReturnOp rop -> {
 586                         if (rop.operands().size() == 0) {
 587                             spirvBlock.add(new SPIRVOpReturn());
 588                         }
 589                         else {
 590                             SPIRVId returnValue = getResult(rop.operands().get(0)).value();
 591                             spirvBlock.add(new SPIRVOpReturnValue(returnValue));
 592                         }
 593                     }
 594                     default -> unsupported("op", op.getClass());
 595                 }
 596             }
 597         }
 598     }
 599 
    /**
     * Emits the module preamble into {@code module}: capabilities, the memory model,
     * the OpenCL.std extended instruction set import, and four built-in input variables.
     * The ids created here are registered by name ("oclExtension", "globalInvocationId",
     * "globalSize", "subgroupSize", "subgroupId") so later code can look them up with getId.
     */
    private void initModule() {
        module.add(new SPIRVOpCapability(SPIRVCapability.Addresses()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Linkage()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Kernel()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Int8()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Int16()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Int64()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Vector16()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Float16()));
        module.add(new SPIRVOpMemoryModel(SPIRVAddressingModel.Physical64(), SPIRVMemoryModel.OpenCL()));

        // Import the OpenCL.std extended instruction set (referenced later by OpExtInst, e.g. fma),
        // then allocate named ids for four built-in variables.
        SPIRVId oclExtension = nextId("oclExtension");
        module.add(new SPIRVOpExtInstImport(oclExtension, new SPIRVLiteralString("OpenCL.std")));

        SPIRVId globalInvocationId = nextId("globalInvocationId");
        SPIRVId globalSize = nextId("globalSize");
        SPIRVId subgroupSize = nextId("subgroupSize");
        SPIRVId subgroupId = nextId("subgroupId");

        // Each built-in gets three decorations: BuiltIn, Constant, and an imported linkage name.
        module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalInvocationId())));
        module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalInvocationId"), SPIRVLinkageType.Import())));
        module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalSize())));
        module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalSize"), SPIRVLinkageType.Import())));
        module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupSize())));
        module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupSize"), SPIRVLinkageType.Import())));
        module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupId())));
        module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupId"), SPIRVLinkageType.Import())));

        // The first getType("ptrV3long") call also declares the "long" and "v3long" types,
        // which globalId/globalSize later fetch by name.
        // NOTE(review): subgroupSize and subgroupId use the same pointer-to-v3long type as the
        // global built-ins; the SPIR-V spec describes SubgroupSize/SubgroupId as 32-bit integer
        // scalars. Confirm before these variables are actually loaded.
        module.add(new SPIRVOpVariable(getType("ptrV3long"), globalInvocationId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
        module.add(new SPIRVOpVariable(getType("ptrV3long"), globalSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
        module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
        module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
    }
 639 
 640     private SPIRVId spirvType(String javaType) {
 641         SPIRVId ans = switch(javaType) {
 642             case "byte" -> getType("byte");
 643             case "short" -> getType("short");
 644             case "int" -> getType("int");
 645             case "long" -> getType("long");
 646             case "float" -> getType("float");
 647             case "double" -> getType("double");
 648             case "int[]" -> getType("int[]");
 649             case "float[]" -> getType("float[]");
 650             case "double[]" -> getType("double[]");
 651             case "long[]" -> getType("long[]");
 652             case "bool" -> getType("bool");
 653             case "spirvdemo.IntArray" -> getType("int[]");
 654             case "spirvdemo.FloatArray" -> getType("float[]");
 655             case "jdk.incubator.vector.IntVector" -> spirvVectorType("IntVector", 8);
 656             case "jdk.incubator.vector.FloatVector" -> spirvVectorType("FloatVector", 8);
 657             case "jdk.incubator.vector.VectorSpecies<java.lang.Integer>" -> getType("int");
 658             case "jdk.incubator.vector.VectorSpecies<java.lang.Long>" -> getType("long");
 659             case "jdk.incubator.vector.VectorSpecies<java.lang.Float>" -> getType("int");
 660             case "VectorSpecies" -> getType("int");
 661             case "void" -> getType("void");
 662             case "spirvdemo.GPU$Index" -> getType("ptrGPUIndex");
 663             case "java.lang.foreign.MemorySegment" -> getType("ptrByte");
 664             default -> null;
 665         };
 666         if (ans == null) unsupported("type", javaType);
 667         return ans;
 668     }
 669 
 670     private SPIRVId spirvElementType(String javaType) {
 671         SPIRVId ans = switch(javaType) {
 672             case "byte[]" -> getType("byte");
 673             case "short[]" -> getType("short");
 674             case "int[]" -> getType("int");
 675             case "long[]" -> getType("long");
 676             case "float[]" -> getType("float");
 677             case "double[]" -> getType("double");
 678             case "boolean[]" -> getType("bool");
 679             case "spirvdemo.IntArray" -> getType("int");
 680             case "spirvdemo.FloatArray" -> getType("float");
 681             case "jdk.incubator.vector.LongVector" -> getType("long");
 682             case "jdk.incubator.vector.FloatVector" -> getType("float");
 683             case "IntVector" -> getType("int");
 684             case "LongVector" -> getType("long");
 685             case "FloatVector" -> getType("float");
 686             case "java.lang.foreign.MemorySegment" -> getType("byte");
 687             default -> null;
 688         };
 689         if (ans == null) unsupported("type", javaType);
 690         return ans;
 691     }
 692 
 693     private SPIRVId vectorElementType(SPIRVId type) {
 694         SPIRVId ans = switch(type.getName()) {
 695             case "v8int" -> getType("int");
 696             case "v16int" -> getType("int");
 697             case "v8long" -> getType("long");
 698             case "v8float" -> getType("float");
 699             case "v16float" -> getType("float");
 700             default -> null;
 701         };
 702         if (ans == null) unsupported("type", type.getName());
 703         return ans;
 704     }
 705 
 706     private SPIRVId spirvVariableType(SPIRVId spirvType) {
 707         SPIRVId ans = switch(spirvType.getName()) {
 708             case "byte" -> getType("ptrByte");
 709             case "short" -> getType("ptrShort");
 710             case "int" -> getType("ptrInt");
 711             case "long" -> getType("ptrLong");
 712             case "float" -> getType("ptrFloat");
 713             case "double" -> getType("ptrDouble");
 714             case "boolean" -> getType("ptrBool");
 715             case "int[]" -> getType("ptrInt[]");
 716             case "long[]" -> getType("ptrLong[]");
 717             case "float[]" -> getType("ptrFloat[]");
 718             case "double[]" -> getType("ptrDouble[]");
 719             case "v8int" -> getType("ptrV8int");
 720             case "v16int" -> getType("ptrV16int");
 721             case "v8long" -> getType("ptrV8long");
 722             case "v8float" -> getType("ptrV8float");
 723             case "v16float" -> getType("ptrV16float");
 724             case "ptrGPUIndex" -> getType("ptrPtrGPUIndex");
 725             case "ptrByte" -> getType("ptrPtrByte");
 726             default -> null;
 727         };
 728         if (ans == null) unsupported("type", spirvType.getName());
 729         return ans;
 730     }
 731 
 732     private SPIRVId spirvVectorType(String javaVectorType, int vectorLength) {
 733         String prefix = "v" + vectorLength;
 734         String elementType = spirvElementType(javaVectorType).getName();
 735         return getType(prefix + elementType);
 736     }
 737 
 738     private int alignment(String spirvType) {
 739         int ans = switch(spirvType) {
 740             case "byte" -> 1;
 741             case "short" -> 2;
 742             case "int" -> 4;
 743             case "long" -> 8;
 744             case "float" -> 4;
 745             case "double" -> 8;
 746             case "boolean" -> 1;
 747             case "v8int" -> 32;
 748             case "v16int" -> 64;
 749             case "v8long" -> 64;
 750             case "v8float" -> 32;
 751             case "v16float" -> 64;
 752             case "ptrGPUIndex" -> 32;
 753             case "int[]" -> 8;
 754             case "long[]" -> 8;
 755             case "float[]" -> 8;
 756             case "double[]" -> 8;
 757             case "ptrByte" -> 8;
 758             case "ptrInt" -> 8;
 759             case "ptrInt[]" -> 8;
 760             case "ptrLong" -> 8;
 761             case "ptrLong[]" -> 8;
 762             case "ptrFloat" -> 8;
 763             case "ptrFloat[]" -> 8;
 764             case "ptrV8int" -> 8;
 765             case "ptrV8float" -> 8;
 766             case "ptrPtrGPUIndex" -> 8;
 767             default -> 0;
 768         };
 769         if (ans == 0) unsupported("type", spirvType);
 770         return ans;
 771     }
 772 
 773     private int laneCount(String vectorType) {
 774         int ans = switch(vectorType) {
 775             case "v8int" -> 8;
 776             case "v8long" -> 8;
 777             case "v8float" -> 8;
 778             case "v16int" -> 16;
 779             case "v16float" -> 16;
 780             default -> 0;
 781         };
 782         if (ans == 0) unsupported("type", vectorType);
 783         return ans;
 784     }
 785 
 786     private SPIRVId vectorExponent(String vectorType) {
 787         SPIRVId ans = null;
 788         switch(vectorType) {
 789             case "v8int" -> ans = getId("int_THREE");
 790             case "v8long" -> ans = getId("int_THREE");
 791             case "v8float" -> ans = getId("int_THREE");
 792             case "v16int" -> ans = getId("int_FOUR");
 793             case "v16float" -> ans = getId("int_FOUR");
 794             default -> unsupported("type", vectorType);
 795         };
 796         return ans;
 797     }
 798 
 799     private Set<String> moduleTypes = new HashSet<>();
 800 
 801     private SPIRVId getType(String name) {
 802         if (!moduleTypes.contains(name)) {
 803             switch (name) {
 804                 case "void" -> module.add(new SPIRVOpTypeVoid(nextId(name)));
 805                 case "bool" -> module.add(new SPIRVOpTypeBool(nextId(name)));
 806                 case "byte" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(8), new SPIRVLiteralInteger(0)));
 807                 case "short" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(16), new SPIRVLiteralInteger(0)));
 808                 case "int" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(32), new SPIRVLiteralInteger(0)));
 809                 case "long" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(64), new SPIRVLiteralInteger(0)));
 810                 case "float" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(32)));
 811                 case "double" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(64)));
 812                 case "ptrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 813                 case "ptrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int")));
 814                 case "ptrLong" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long")));
 815                 case "ptrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float")));
 816                 case "short[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("short")));
 817                 case "int[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("int")));
 818                 case "long[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("long")));
 819                 case "float[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("float")));
 820                 case "double[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("double")));
 821                 case "boolean[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("boolean")));
 822                 case "ptrInt[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int[]")));
 823                 case "ptrLong[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long[]")));
 824                 case "ptrFloat[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float[]")));
 825                 case "spirvdemo.GPUIndex" -> module.add(new SPIRVOpTypeStruct(nextId(name), new SPIRVMultipleOperands<>(getType("long"), getType("long"), getType("long"))));
 826                 case "ptrGPUIndex" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("spirvdemo.GPUIndex")));
 827                 case "ptrCrossGroupByte"-> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 828                 case "ptrPtrGPUIndex" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrGPUIndex")));
 829                 case "ptrPtrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrByte")));
 830                 case "v3long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(3)));
 831                 case "v8int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(8)));
 832                 case "v8long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(8)));
 833                 case "v16int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(16)));
 834                 case "v8float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(8)));
 835                 case "v16float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(16)));
 836                 case "ptrV3long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Input(), getType("v3long")));
 837                 case "ptrV8long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8long")));
 838                 case "ptrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8int")));
 839                 case "ptrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16int")));
 840                 case "ptrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8float")));
 841                 case "ptrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16float")));
 842                 case "ptrPtrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8int")));
 843                 case "ptrPtrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16int")));
 844                 case "ptrPtrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8float")));
 845                 case "ptrPtrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16float")));
 846                 default -> unsupported("type", name);
 847             }
 848             moduleTypes.add(name);
 849         }
 850         return getId(name);
 851     }
 852 
 853     private Set<String> moduleConstants = new HashSet<>();
 854 
 855     private SPIRVId getConst(String name) {
 856         if (!moduleConstants.contains(name)) {
 857             switch (name) {
 858                 case "int_ZERO" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_ZERO"), new SPIRVContextDependentInt(new BigInteger("0"))));
 859                 case "int_ONE" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_ONE"), new SPIRVContextDependentInt(new BigInteger("1"))));
 860                 case "int_TWO" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_TWO"), new SPIRVContextDependentInt(new BigInteger("2"))));
 861                 case "int_EIGHT" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_EIGHT"), new SPIRVContextDependentInt(new BigInteger("8"))));
 862                 default -> unsupported("constant", name);
 863             }
 864             moduleConstants.add(name);
 865         }
 866         return getId(name);
 867     }
 868 
 869     private SPIRVOptionalOperand<SPIRVMemoryAccess> align(int align) {
 870         return new SPIRVOptionalOperand<>(SPIRVMemoryAccess.Aligned(new SPIRVLiteralInteger(align)));
 871     }
 872 
 873     private SPIRVOptionalOperand<SPIRVMemoryAccess> align(String type) {
 874         return align(alignment(type));
 875     }
 876 
 877     private SPIRVMultipleOperands<SPIRVId> spirvOperands(SPIRVId value, int count) {
 878         SPIRVId[] values = new SPIRVId[count];
 879         Arrays.fill(values, value);
 880         return new SPIRVMultipleOperands<>(values);
 881     }
 882 
 883     private SPIRVOptionalOperand<SPIRVMemoryAccess> none() {
 884         return new SPIRVOptionalOperand<>();
 885     }
 886 
 887     private SpirvResult globalSize(int index, SPIRVBlock spirvBlock) {
 888         SPIRVId longType = getType("long");
 889         SPIRVId v3long = getId("v3long");
 890         SPIRVId globalSizeId = getId("globalSize");
 891         SPIRVId globalSizes = nextId();
 892         spirvBlock.add(new SPIRVOpLoad(v3long, globalSizes, globalSizeId, align(32)));
 893         SPIRVId globalSize = nextId();
 894         spirvBlock.add(new SPIRVOpCompositeExtract(longType, globalSize, globalSizes, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
 895         return new SpirvResult(longType, null, globalSize);
 896     }
 897 
 898     private SpirvResult globalId(int index, SPIRVBlock spirvBlock) {
 899         SPIRVId longType = getType("long");
 900         SPIRVId v3long = getId("v3long");
 901         SPIRVId globalInvocationId = getId("globalInvocationId");
 902         SPIRVId globalIds = nextId();
 903         spirvBlock.add(new SPIRVOpLoad(v3long, globalIds, globalInvocationId, align(32)));
 904         SPIRVId globalIndex = nextId();
 905         spirvBlock.add(new SPIRVOpCompositeExtract(longType, globalIndex, globalIds, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
 906         return new SpirvResult(longType, null, globalIndex);
 907     }
 908 
 909     private SPIRVId nextId() {
 910         return module.getNextId();
 911     }
 912 
 913     private SPIRVId nextId(String name) {
 914         SPIRVId ans = nextId();
 915         ans.setName(name);
 916         symbols.putId(name, ans);
 917         module.add(new SPIRVOpName(ans, new SPIRVLiteralString(name)));
 918         return ans;
 919     }
 920 
 921     private static int counter = 0;
 922 
 923     private String nextTempTag() {
 924         counter++;
 925         return "temp_" + counter + "_";
 926     }
 927 
 928     private boolean isIntegerType(SPIRVId type) {
 929         String name = type.getName();
 930         return name.equals("short") || name.equals("int") || name.equals("long");
 931     }
 932 
 933     private boolean isFloatType(SPIRVId type) {
 934         String name = type.getName();
 935         return name.equals("float") || name.equals("double");
 936     }
 937 
 938     private boolean isVectorSpecies(String javaType) {
 939         return javaType.equals("VectorSpecies");
 940     }
 941 
 942     private boolean isVectorType(String javaType) {
 943         return javaType.equals("IntVector") || javaType.equals("FloatVector");
 944     }
 945 
 946     private void addId(String name, SPIRVId id) {
 947         symbols.putId(name, id);
 948     }
 949 
 950     private SPIRVId getId(String name) {
 951         SPIRVId ans = symbols.getId(name);
 952         assert ans != null : name + " not found";
 953         return ans;
 954     }
 955 
 956     private SPIRVId getIdOrNull(String name) {
 957         return symbols.getId(name);
 958     }
 959 
 960     private static Object map(Function<Object, Boolean> test, Object... args) {
 961         int len = args.length;
 962         assert len >= 2 && len % 2 == 0;
 963         int pairs = len / 2;
 964         for (int i = 0; i < pairs; i++) {
 965             if (test.apply(args[i])) return args[i + pairs];
 966         }
 967         throw new RuntimeException("No match: " + args[0]);
 968     }
 969 
 970     private void unsupported(String message, Object value) {
 971         throw new RuntimeException("Unsupported " + message + ": " + value);
 972     }
 973 
 974     private void addResult(Value value, SpirvResult result) {
 975         assert symbols.getResult(value) == null : "result already present";
 976         symbols.putResult(value, result);
 977     }
 978 
 979     private SpirvResult getResult(Value value) {
 980         return symbols.getResult(value);
 981     }
 982 
 983     private static class Symbols {
 984         private final HashMap<Value, SpirvResult> results;
 985         private final HashMap<String, SPIRVId> ids;
 986         private final HashMap<Block, SPIRVBlock> blocks;
 987         private final HashMap<Block, SPIRVOpLabel> labels;
 988 
 989         public Symbols() {
 990             this.results = new HashMap<>();
 991             this.ids = new HashMap<>();
 992             this.blocks = new HashMap<>();
 993             this.labels = new HashMap<>();
 994         }
 995 
 996         public void putId(String name, SPIRVId id) {
 997             ids.put(name, id);
 998         }
 999 
1000         public SPIRVId getId(String name) {
1001             return ids.get(name);
1002         }
1003 
1004         public void putBlock(Block block, SPIRVBlock spirvBlock) {
1005             blocks.put(block, spirvBlock);
1006         }
1007 
1008         public SPIRVBlock getBlock(Block block) {
1009             return blocks.get(block);
1010         }
1011 
1012         public void putLabel(Block block, SPIRVOpLabel spirvLabel) {
1013             labels.put(block, spirvLabel);
1014         }
1015 
1016         public SPIRVOpLabel getLabel(Block block) {
1017             return labels.get(block);
1018         }
1019 
1020         public void putResult(Value value, SpirvResult result) {
1021             results.put(value, result);
1022         }
1023 
1024         public SpirvResult getResult(Value value) {
1025             return results.get(value);
1026         }
1027 
1028         public String toString() {
1029             return String.format("results %s\n\nids %s\n\nblocks %s\nlabels %s\n", results.keySet(), ids.keySet(), blocks.keySet(), labels.keySet());
1030         }
1031     }
1032 }