1 /*
   2  * Copyright (c) 2024 Intel Corporation. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package intel.code.spirv;
  27 
  28 import java.util.List;
  29 import java.util.Arrays;
  30 import java.util.HashMap;
  31 import java.util.Set;
  32 import java.util.HashSet;
  33 import java.util.function.Function;
  34 import java.io.IOException;
  35 import java.io.File;
  36 import java.io.FileOutputStream;
  37 import java.io.ByteArrayInputStream;
  38 import java.io.ByteArrayOutputStream;
  39 import java.io.PrintStream;
  40 import java.nio.ByteBuffer;
  41 import java.nio.ByteOrder;
  42 import java.nio.channels.FileChannel;
  43 import java.math.BigInteger;
  44 import jdk.incubator.vector.VectorSpecies;
  45 import jdk.incubator.vector.VectorOperators;
  46 import jdk.incubator.vector.Vector;
  47 import jdk.incubator.vector.IntVector;
  48 import jdk.incubator.vector.FloatVector;
  49 import java.lang.foreign.MemorySegment;
  50 import java.lang.foreign.ValueLayout;
  51 import java.lang.reflect.code.Block;
  52 import java.lang.reflect.code.Body;
  53 import java.lang.reflect.code.Op;
  54 import java.lang.reflect.code.Value;
  55 import java.lang.reflect.code.op.CoreOp;
  56 import java.lang.reflect.code.TypeElement;
  57 import java.lang.reflect.code.type.MethodRef;
  58 import java.lang.reflect.code.type.ClassType;
  59 import java.lang.reflect.code.type.JavaType;
  60 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVHeader;
  61 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVModule;
  62 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVFunction;
  63 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVBlock;
  64 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.*;
  65 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.*;
  66 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler;
  67 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions;
  68 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVByteStreamReader;
  69 
  70 public class SpirvModuleGenerator {
  71     public static MemorySegment generateModule(String moduleName, CoreOp.FuncOp func) {
  72         SpirvOps.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(func);
  73         MemorySegment module = SpirvModuleGenerator.generateModule(moduleName, spirvFunc);
  74         return module;
  75     }
  76 
  77     public static MemorySegment generateModule(String moduleName, SpirvOps.FuncOp func) {
  78         return new SpirvModuleGenerator().generateModuleInternal(moduleName, func);
  79     }
  80 
  81     public static void writeModuleToFile(MemorySegment module, String filepath)  {
  82         ByteBuffer buffer = module.asByteBuffer();
  83         File out = new File(filepath);
  84         try (FileChannel channel = new FileOutputStream(out, false).getChannel()) {
  85             channel.write(buffer);
  86         }
  87         catch (IOException e)  {
  88             throw new RuntimeException(e);
  89         }
  90     }
  91 
  92     public static String disassembleModule(MemorySegment module) {
  93         SPVByteStreamReader input = new SPVByteStreamReader(new ByteArrayInputStream(module.toArray(ValueLayout.JAVA_BYTE)));
  94         ByteArrayOutputStream out = new ByteArrayOutputStream();
  95         try (PrintStream ps = new PrintStream(out))  {
  96             SPIRVDisassemblerOptions options = new SPIRVDisassemblerOptions(false, false, false, false, true);
  97             Disassembler dis = new Disassembler(input, ps, options);
  98             dis.run();
  99         }
 100         catch (Exception e) {
 101             throw new RuntimeException(e);
 102         }
 103         return new String(out.toByteArray());
 104     }
 105 
 106     private record SpirvResult(SPIRVId type, SPIRVId address, SPIRVId value) {}
 107 
 108     private final SPIRVModule module;
 109     private final Symbols symbols;
 110 
 111     private SpirvModuleGenerator() {
 112         this.module = new SPIRVModule(new SPIRVHeader(1, 2, 32, 0, 0));
 113         this.symbols = new Symbols();
 114     }
 115 
 116     private MemorySegment generateModuleInternal(String moduleName, SpirvOps.FuncOp func) {
 117         initModule();
 118         generateFunction(moduleName, moduleName, func);
 119         ByteBuffer buffer = ByteBuffer.allocateDirect(module.getByteCount());
 120         buffer.order(ByteOrder.LITTLE_ENDIAN);
 121         module.close().write(buffer);
 122         buffer.flip();
 123         return MemorySegment.ofBuffer(buffer);
 124     }
 125 
 126     private void generateFunction(String moduleName, String fnName, SpirvOps.FuncOp func) {
 127         TypeElement returnType = func.invokableType().returnType();
 128         SPIRVId functionID = nextId(fnName);
 129         String signature = func.invokableType().returnType().toString();
 130         List<TypeElement> paramTypes = func.invokableType().parameterTypes();
 131         // build signature string
 132         for (int i = 0; i < paramTypes.size(); i++) {
 133             signature += "_" + paramTypes.get(i).toString();
 134         }
 135         // declare function type if not already present
 136         SPIRVId functionSig = getIdOrNull(signature);
 137         if (functionSig == null) {
 138             SPIRVId[] typeIdsArray = new SPIRVId[paramTypes.size()];
 139             for (int i = 0; i < paramTypes.size(); i++) {
 140                 typeIdsArray[i] = spirvType(paramTypes.get(i).toString());
 141             }
 142             functionSig = nextId(fnName + "Signature");
 143             module.add(new SPIRVOpTypeFunction(functionSig, spirvType(returnType.toString()), new SPIRVMultipleOperands<>(typeIdsArray)));
 144             addId(signature, functionSig);
 145         }
 146         // declare function as modeule entry point
 147         SPIRVId spirvReturnType = spirvType(returnType.toString());
 148         SPIRVFunction function = (SPIRVFunction)module.add(new SPIRVOpFunction(spirvReturnType, functionID, SPIRVFunctionControl.DontInline(), functionSig));
 149         SPIRVOpLabel entryPoint = new SPIRVOpLabel(nextId());
 150         SPIRVBlock entryBlock = (SPIRVBlock)function.add(entryPoint);
 151         SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(getId("globalInvocationId"), getId("globalSize"), getId("subgroupSize"), getId("subgroupId"));
 152         module.add(new SPIRVOpEntryPoint(SPIRVExecutionModel.Kernel(), functionID, new SPIRVLiteralString(fnName), operands));
 153 
 154         translateBody(func.body(), function, entryBlock);
 155         function.add(new SPIRVOpFunctionEnd());
 156     }
 157 
 158     private void translateBody(Body body, SPIRVFunction function, SPIRVBlock entryBlock) {
 159         int labelNumber = 0;
 160         SPIRVBlock spirvBlock = entryBlock;
 161         for (int bi = 1; bi < body.blocks().size(); bi++)  {
 162             Block block = body.blocks().get(bi);
 163             String blockName = String.valueOf(block.hashCode());
 164             SPIRVOpLabel blockLabel = new SPIRVOpLabel(nextId());
 165             SPIRVBlock newBlock = (SPIRVBlock)function.add(blockLabel);
 166             symbols.putBlock(block, newBlock);
 167             symbols.putLabel(block, blockLabel);
 168         }
 169         for (Value param : body.entryBlock().parameters()) {
 170             SPIRVId paramId = nextId();
 171             addResult(param, new SpirvResult(spirvType(param.type().toString()), null, paramId));
 172         }
 173         for (int bi = 0; bi < body.blocks().size(); bi++)  {
 174             Block block = body.blocks().get(bi);
 175             if (bi > 0) {
 176                 spirvBlock = symbols.getBlock(block);
 177             }
 178             List<Op> ops = block.ops();
 179             for (Op op : block.ops()) {
 180                 // debug("---------- spirv op = %s", op.toText());
 181                 switch (op)  {
 182                     case SpirvOps.VariableOp vop -> {
 183                         String typeName = vop.varType().toString();
 184                         SPIRVId type = spirvType(typeName);
 185                         SPIRVId varType = spirvVariableType(type);
 186                         SPIRVId var = nextId(vop.varName());
 187                         spirvBlock.add(new SPIRVOpVariable(varType, var, SPIRVStorageClass.Function(), new SPIRVOptionalOperand<>()));
 188                         addResult(vop.result(), new SpirvResult(varType, var, null));
 189                     }
 190                     case SpirvOps.FunctionParameterOp fpo -> {
 191                         SPIRVId result = nextId();
 192                         SPIRVId type = spirvType(fpo.resultType().toString());
 193                         function.add(new SPIRVOpFunctionParameter(type, result));
 194                         addResult(fpo.result(), new SpirvResult(type, null, result));
 195                     }
 196                     case SpirvOps.LoadOp lo -> {
 197                         if (((JavaType)lo.resultType()).equals(JavaType.type(VectorSpecies.class))) {
 198                             addResult(lo.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT")));
 199                         }
 200                         else {
 201                             SPIRVId type = spirvType(lo.resultType().toString());
 202                             SpirvResult toLoad = getResult(lo.operands().get(0));
 203                             SPIRVId varAddr = toLoad.address();
 204                             SPIRVId result = nextId();
 205                             spirvBlock.add(new SPIRVOpLoad(type, result, varAddr, align(type.getName())));
 206                             addResult(lo.result(), new SpirvResult(type, varAddr, result));
 207                         }
 208                     }
 209                     case SpirvOps.StoreOp so -> {
 210                         SpirvResult var = getResult(so.operands().get(0));
 211                         SPIRVId varAddr = var.address();
 212                         SPIRVId value = getResult(so.operands().get(1)).value();
 213                         spirvBlock.add(new SPIRVOpStore(varAddr, value, align(var.type().getName())));
 214                     }
 215                     case SpirvOps.IAddOp _, SpirvOps.FAddOp _ -> {
 216                         SPIRVId intType = getType("int");
 217                         SPIRVId longType = getType("long");
 218                         SPIRVId floatType = getType("float");
 219                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 220                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 221                         SPIRVId lhsType = spirvType(op.resultType().toString());
 222                         SPIRVId ans = nextId();
 223                         if (lhsType == intType) spirvBlock.add(new SPIRVOpIAdd(intType, ans, lhs, rhs));
 224                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpIAdd(longType, ans, lhs, rhs));
 225                         else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFAdd(floatType, ans, lhs, rhs));
 226                         else unsupported("type", lhsType.getName());
 227                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 228                     }
 229                     case SpirvOps.IMulOp _, SpirvOps.FMulOp _, SpirvOps.IDivOp _, SpirvOps.FDivOp _ -> {
 230                         SPIRVId intType = getType("int");
 231                         SPIRVId longType = getType("long");
 232                         SPIRVId floatType = getType("float");
 233                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 234                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 235                         SPIRVId lhsType = spirvType(op.resultType().toString());
 236                         SPIRVId rhsType = getResult(op.operands().get(1)).type();
 237                         SPIRVId ans = nextId();
 238                         if (lhsType == intType) {
 239                             if (op instanceof SpirvOps.IMulOp) spirvBlock.add(new SPIRVOpIMul(intType, ans, lhs, rhs));
 240                             else if (op instanceof SpirvOps.IDivOp) spirvBlock.add(new SPIRVOpSDiv(intType, ans, lhs, rhs));
 241                         }
 242                         else if (lhsType == longType) {
 243                             SPIRVId rhsId = rhsType == intType ? nextId() : rhs;
 244                             if (rhsType == intType) spirvBlock.add(new SPIRVOpSConvert(longType, rhsId, rhs));
 245                             if (op instanceof SpirvOps.IMulOp) spirvBlock.add(new SPIRVOpIMul(longType, ans, lhs, rhsId));
 246                             else if (op instanceof SpirvOps.IDivOp) spirvBlock.add(new SPIRVOpSDiv(longType, ans, lhs, rhs));
 247                         }
 248                         else if (lhsType == floatType) {
 249                             if (op instanceof SpirvOps.FMulOp) spirvBlock.add(new SPIRVOpFMul(floatType, ans, lhs, rhs));
 250                             else if (op instanceof SpirvOps.FDivOp) spirvBlock.add(new SPIRVOpFDiv(floatType, ans, lhs, rhs));
 251                         }
 252                         else unsupported("type", lhsType);
 253                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 254                     }
 255                     case SpirvOps.ModOp mop -> {
 256                         SPIRVId type = getType(mop.operands().get(0).type().toString());
 257                         SPIRVId lhs = getResult(mop.operands().get(0)).value();
 258                         SPIRVId rhs = getResult(mop.operands().get(1)).value();
 259                         SPIRVId result = nextId();
 260                         spirvBlock.add(new SPIRVOpUMod(type, result, lhs, rhs));
 261                         addResult(mop.result(), new SpirvResult(type, null, result));
 262                     }
 263                     case SpirvOps.IEqualOp eqop -> {
 264                         SPIRVId boolType = getType("bool");
 265                         SPIRVId intType = getType("int");
 266                         SPIRVId longType = getType("long");
 267                         SPIRVId floatType = getType("float");
 268                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 269                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 270                         SPIRVId lhsType = spirvType(op.resultType().toString());
 271                         SPIRVId ans = nextId();
 272                         if (lhsType == intType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
 273                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
 274                         else unsupported("type", lhsType.getName());
 275                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 276                     }
 277                     case SpirvOps.CallOp call -> {
 278                         if (call.callDescriptor().equals(MethodRef.ofString("spirvdemo.IntArray::get(long)int")) ||
 279                             call.callDescriptor().equals(MethodRef.ofString("spirvdemo.FloatArray::get(long)float"))) {
 280                             SPIRVId longType = getType("long");
 281                             String arrayTypeName = call.operands().get(0).type().toString();
 282                             SpirvResult arrayResult = getResult(call.operands().get(0));
 283                             SPIRVId arrayAddr = arrayResult.address();
 284                             SPIRVId arrayType = spirvType(arrayTypeName);
 285                             SPIRVId elementType = spirvElementType(arrayTypeName);
 286                             int nIndexes = call.operands().size() - 1;
 287                             SPIRVId index = getResult(call.operands().get(1)).value();
 288                             SPIRVId array = nextId();
 289                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 290                             SPIRVId resultAddr = nextId();
 291                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, index, new SPIRVMultipleOperands<>()));
 292                             SPIRVId result = nextId();
 293                             spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName())));
 294                             addResult(call.result(), new SpirvResult(elementType, resultAddr, result));
 295                         }
 296                         else if (call.callDescriptor().equals(MethodRef.ofString("spirvdemo.IntArray::set(long, int)void")) ||
 297                                 call.callDescriptor().equals(MethodRef.ofString("spirvdemo.FloatArray::set(long, float)void"))) {
 298                             SPIRVId longType = getType("long");
 299                             String arrayTypeName = call.operands().get(0).type().toString();
 300                             SpirvResult arrayResult = getResult(call.operands().get(0));
 301                             SPIRVId arrayAddr = arrayResult.address();
 302                             SPIRVId arrayType = spirvType(arrayTypeName);
 303                             SPIRVId elementType = spirvElementType(arrayTypeName);
 304                             int nIndexes = call.operands().size() - 2;
 305                             int valueIndex = nIndexes + 1;
 306                             SPIRVId index = getResult(call.operands().get(1)).value();
 307                             SPIRVId array = nextId();
 308                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 309                             SPIRVId dest = nextId();
 310                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, array, index, new SPIRVMultipleOperands<>()));
 311                             SPIRVId value = getResult(call.operands().get(valueIndex)).value();
 312                             spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName())));
 313                         }
 314                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "fromArray", IntVector.class, VectorSpecies.class, int[].class, int.class))
 315                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fromArray", FloatVector.class, VectorSpecies.class, float[].class, int.class))) {
 316                             SPIRVId oclExtension = getId("oclExtension");
 317                             SpirvResult speciesResult = getResult(call.operands().get(0));
 318                             SpirvResult arrayResult = getResult(call.operands().get(1));
 319                             String arrayType = arrayResult.type().getName();
 320                             int laneCount = 8;  //TODO: remove hard code, instruction below needs a literal
 321                             String vTypeName = ((ClassType)call.callDescriptor().refType()).toClassName();
 322                             SPIRVId vType = spirvVectorType(vTypeName, laneCount);
 323                             SPIRVId array = arrayResult.value();
 324                             SPIRVId index = getResult(call.operands().get(2)).value();
 325                             SPIRVId vectorIndex = nextId();
 326                             spirvBlock.add(new SPIRVOpSDiv(getType("int"), vectorIndex, index, speciesResult.value()));
 327                             SPIRVId longIndex = nextId();
 328                             spirvBlock.add(new SPIRVOpSConvert(getType("long"), longIndex, vectorIndex));
 329                             SPIRVId vector = nextId();
 330                             SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(longIndex, array, new SPIRVId(laneCount)); // TODO: lanes must be a literal
 331                             spirvBlock.add(new SPIRVOpExtInst(vType, vector, oclExtension, new SPIRVLiteralExtInstInteger(171, "vloadn"), operands));
 332                             addResult(call.result(), new SpirvResult(vType, null, vector));
 333                         }
 334                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "fromMemorySegment", IntVector.class, VectorSpecies.class, MemorySegment.class, long.class, ByteOrder.class))
 335                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fromMemorySegment", FloatVector.class, VectorSpecies.class, MemorySegment.class, long.class, ByteOrder.class))) {
 336                             SPIRVId oclExtension = getId("oclExtension");
 337                             SPIRVId species = getResult(call.operands().get(0)).value();
 338                             SPIRVId lanesLong = nextId();
 339                             spirvBlock.add(new SPIRVOpSConvert(getType("long"), lanesLong, species));
 340                             int laneCount = 8; //TODO: remove hard code, vloadn instruction below needs a literal lane count, get value from env
 341                             SPIRVId segment = getResult(call.operands().get(1)).value();
 342                             String vTypeName = ((ClassType)call.callDescriptor().refType()).toClassName();
 343                             SPIRVId vType = spirvVectorType(vTypeName, laneCount);
 344                             SPIRVId temp = nextId();
 345                             spirvBlock.add(new SPIRVOpConvertPtrToU(getType("long"), temp, segment));
 346                             SPIRVId typedSegment = nextId();
 347                             SPIRVId pointerType = (SPIRVId)map(x -> x.equals(vTypeName), "jdk.incubator.vector.IntVector", "jdk.incubator.vector.FloatVector", getType("ptrInt"), getType("ptrFloat"));
 348                             spirvBlock.add(new SPIRVOpConvertUToPtr(pointerType, typedSegment, temp));
 349                             SPIRVId offset = getResult(call.operands().get(2)).value();
 350                             SPIRVId vectorIndex = nextId();
 351                             spirvBlock.add(new SPIRVOpSDiv(getType("long"), vectorIndex, offset, lanesLong)); // divide by lane count
 352                             SPIRVId finalIndex = nextId();
 353                             SPIRVId vector = nextId();
 354                             SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(vectorIndex, typedSegment, new SPIRVId(laneCount)); // TODO: lanes must be a literal
 355                             spirvBlock.add(new SPIRVOpExtInst(vType, vector, oclExtension, new SPIRVLiteralExtInstInteger(171, "vloadn"), operands));
 356                             addResult(call.result(), new SpirvResult(vType, null, vector));
 357                         }
 358                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "intoArray", void.class, int[].class, int.class))
 359                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "intoArray", void.class, float[].class, int.class))) {
 360                             SPIRVId oclExtension = getId("oclExtension");
 361                             SpirvResult vectorResult = getResult(call.operands().get(0));
 362                             SPIRVId vector = vectorResult.value();
 363                             SPIRVId vectorType = vectorResult.type();
 364                             SpirvResult arrayResult = getResult(call.operands().get(1));
 365                             SPIRVId array = arrayResult.value();
 366                             SPIRVId index = getResult(call.operands().get(2)).value();
 367                             SPIRVId vectorIndex = nextId();
 368                             spirvBlock.add(new SPIRVOpShiftRightArithmetic(getType("int"), vectorIndex, index, vectorExponent(vectorType.getName())));
 369                             SPIRVId longIndex = nextId();
 370                             spirvBlock.add(new SPIRVOpSConvert(getType("long"), longIndex, vectorIndex));
 371                             SPIRVMultipleOperands<SPIRVId> operandsR = new SPIRVMultipleOperands<>(vector, longIndex, array);
 372                             spirvBlock.add(new SPIRVOpExtInst(getType("void"), nextId(), oclExtension, new SPIRVLiteralExtInstInteger(172, "vstoren"), operandsR));
 373                         }
 374                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "intoMemorySegment", void.class, MemorySegment.class, long.class, ByteOrder.class))
 375                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "intoMemorySegment", void.class, MemorySegment.class, long.class, ByteOrder.class))) {
 376                             SPIRVId oclExtension = getId("oclExtension");
 377                             SpirvResult vectorResult = getResult(call.operands().get(0));
 378                             SPIRVId vector = vectorResult.value();
 379                             SPIRVId vectorType = vectorResult.type();
 380                             SpirvResult segmentResult = getResult(call.operands().get(1));;
 381                             SPIRVId segment = segmentResult.value();
 382                             SPIRVId temp = nextId();
 383                             spirvBlock.add(new SPIRVOpConvertPtrToU(getType("long"), temp, segment));
 384                             SPIRVId typedSegment = nextId();
 385                             String vectorElementType = vectorElementType(vectorType).getName();
 386                             SPIRVId pointerType = (SPIRVId)map(x -> x.equals(vectorElementType), "int", "float", getType("ptrInt"), getType("ptrFloat"));
 387                             spirvBlock.add(new SPIRVOpConvertUToPtr(pointerType, typedSegment, temp));
 388                             SPIRVId offset = getResult(call.operands().get(2)).value();
 389                             SPIRVId vectorIndex = nextId();
 390                             int laneCount = laneCount(vectorType.getName());
 391                             spirvBlock.add(new SPIRVOpShiftRightArithmetic(getType("long"), vectorIndex, offset, vectorExponent(vectorType.getName())));
 392                             SPIRVMultipleOperands<SPIRVId> operandsR = new SPIRVMultipleOperands<>(vector, vectorIndex, typedSegment);
 393                             spirvBlock.add(new SPIRVOpExtInst(getId("void"), nextId(), oclExtension, new SPIRVLiteralExtInstInteger(172, "vstoren"), operandsR));
 394                         }
 395                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "reduceLanes", int.class, VectorOperators.Associative.class))
 396                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "reduceLanes", float.class, VectorOperators.Associative.class))) {
 397                             SpirvResult vectorResult = getResult(call.operands().get(0));
 398                             SPIRVId vectorType = vectorResult.type();
 399                             SPIRVId vector = vectorResult.value();
 400                             String vTypeName = vectorType.getName();
 401                             SPIRVId elementType = vectorElementType(vectorType);
 402                             Op reduceOp = ((Op.Result)call.operands().get(1)).op();
 403                             if (reduceOp instanceof SpirvOps.FieldLoadOp flo) {
 404                                 assert flo.fieldDescriptor().refType().equals(JavaType.type(VectorOperators.class));
 405                                 assert flo.fieldDescriptor().name().equals("ADD");
 406                                 String operation = flo.fieldDescriptor().name();
 407                             }
 408                             else unsupported("operation expression", reduceOp.toText());
 409                             String tempTag = nextTempTag();
 410                             SPIRVId temp_0 = nextId(tempTag + 0);
 411                             spirvBlock.add(new SPIRVOpCompositeExtract(elementType, temp_0, vector, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(0))));
 412                             for (int lane = 1; lane < laneCount(vectorType.getName()); lane++) {
 413                                 SPIRVId temp = nextId(tempTag + lane);
 414                                 SPIRVId element = nextId();
 415                                 spirvBlock.add(new SPIRVOpCompositeExtract(elementType, element, vector, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(lane))));
 416                                 if (elementType == getType("int")) {
 417                                     spirvBlock.add(new SPIRVOpIAdd(elementType, temp, getId(tempTag + (lane - 1)), element));
 418                                 }
 419                                 else if (elementType == getType("float")) {
 420                                     spirvBlock.add(new SPIRVOpFAdd(elementType, temp, getId(tempTag + (lane - 1)), element));
 421                                 }
 422                                 else unsupported("type", elementType.getName());
 423                             }
 424                             addResult(call.result(), new SpirvResult(elementType, null, getId(tempTag + (laneCount(vectorType.getName()) - 1))));
 425                         }
 426                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "add", IntVector.class, Vector.class))
 427                               || call.callDescriptor().equals(MethodRef.method(IntVector.class, "mul", IntVector.class, Vector.class))
 428                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "add", FloatVector.class, Vector.class))
 429                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "mul", FloatVector.class, Vector.class))) {
 430                             SPIRVId oclExtension = getId("oclExtension");
 431                             SpirvResult lhsResult = getResult(call.operands().get(0));
 432                             SPIRVId lhsType = lhsResult.type();
 433                             SPIRVId lhs = lhsResult.value();
 434                             SPIRVId rhs = getResult(call.operands().get(1)).value();
 435                             SPIRVId add = nextId();
 436                             if (call.callDescriptor().name().equals("add")) {
 437                                 spirvBlock.add(lhsType.getName().endsWith("int") ? new SPIRVOpIAdd(lhsType, add, lhs, rhs) : new SPIRVOpFAdd(lhsType, add, lhs, rhs));
 438                             }
 439                             else if (call.callDescriptor().name().equals("mul")) {
 440                                 spirvBlock.add(lhsType.getName().endsWith("int") ? new SPIRVOpIMul(lhsType, add, lhs, rhs) : new SPIRVOpFMul(lhsType, add, lhs, rhs));
 441                             }
 442                             addResult(call.result(), new SpirvResult(lhsType, null, add));
 443                         }
 444                         else if (call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fma", FloatVector.class, Vector.class, Vector.class))) {
 445                             SPIRVId oclExtension = getId("oclExtension");
 446                             SpirvResult aResult = getResult(call.operands().get(0));
 447                             SPIRVId vType = aResult.type();
 448                             SPIRVId a = aResult.value();
 449                             SPIRVId b = getResult(call.operands().get(1)).value();
 450                             SPIRVId c = getResult(call.operands().get(2)).value();
 451                             String vTypeStr = vType.getName();
 452                             assert vTypeStr.endsWith("float");
 453                             SPIRVId result  = nextId();
 454                             SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(a, b, c);
 455                             spirvBlock.add(new SPIRVOpExtInst(vType, result, oclExtension, new SPIRVLiteralExtInstInteger(26, "fma"), operands));
 456                             addResult(call.result(), new SpirvResult(vType, null, result));
 457                         }
 458                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "zero", IntVector.class, VectorSpecies.class))
 459                              || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "zero", FloatVector.class, VectorSpecies.class))) {
 460                             SpirvResult speciesResult = getResult(call.operands().get(0));
 461                             SPIRVId vType = spirvType(((ClassType)call.callDescriptor().refType()).toClassName());
 462                             String elementType = vectorElementType(vType).getName();
 463                             SPIRVId value = getId(elementType + "_ZERO");
 464                             int laneCount = laneCount(vType.getName());
 465                             assert laneCount == 8 || laneCount == 16;
 466                             SPIRVId vector = nextId();
 467                             SPIRVMultipleOperands<SPIRVId> operands = spirvOperands(value, laneCount);
 468                             spirvBlock.add(new SPIRVOpCompositeConstruct(vType, vector, operands));
 469                             addResult(call.result(), new SpirvResult(vType, null, vector));
 470                         }
 471                         else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "lane", int.class, int.class))
 472                               || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "lane", float.class, int.class)))  {
 473                             SpirvResult lhsResult = getResult(call.operands().get(0));
 474                             SPIRVId lhsType = lhsResult.type();
 475                             SPIRVId lhs = lhsResult.value();
 476                             String vTypeStr = lhsType.getName();
 477                             SPIRVId vType = lhsResult.type();
 478                             SPIRVId elementType = vectorElementType(vType);
 479                             SPIRVId result = nextId();
 480                             Op laneOp = ((Op.Result)call.operands().get(1)).op();
 481                             assert laneOp instanceof SpirvOps.ConstantOp;
 482                             int lane = (int)((SpirvOps.ConstantOp)laneOp).value();
 483                             spirvBlock.add(new SPIRVOpCompositeExtract(elementType, result, lhsResult.value(), new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(lane))));
 484                             addResult(call.result(), new SpirvResult(elementType, null, result));
 485                         }
 486                         else if (call.callDescriptor().equals(MethodRef.method(VectorSpecies.class, "length", int.class))) {
 487                             addResult(call.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT"))); // TODO: remove hardcode
 488                         }
 489                         else unsupported("method", call.callDescriptor());
 490                     }
 491                     case SpirvOps.ConstantOp cop -> {
 492                         SPIRVId type = spirvType(cop.resultType().toString());
 493                         SPIRVId result = nextId();
 494                         Object value = cop.value();
 495                         if (type == getType("int")) {
 496                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentInt(new BigInteger(String.valueOf(value)))));
 497                         }
 498                         else if (type == getType("long")) {
 499                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentLong(new BigInteger(String.valueOf(value)))));
 500                         }
 501                         else if (type == getType("float")) {
 502                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentFloat((float)value)));
 503                         }
 504                         else unsupported("type", cop.resultType());
 505                         addResult(cop.result(), new SpirvResult(type, null, result));
 506                     }
 507                     case SpirvOps.ConvertOp scop -> {
 508                         SPIRVId toType = spirvType(scop.resultType().toString());
 509                         SPIRVId to = nextId();
 510                         SpirvResult valueResult = getResult(scop.operands().get(0));
 511                         SPIRVId from = valueResult.value();
 512                         SPIRVId fromType = valueResult.type();
 513                         if (isIntegerType(fromType)) {
 514                             if (isIntegerType(toType)) {
 515                                 spirvBlock.add(new SPIRVOpSConvert(toType, to, from));
 516                             }
 517                             else if (isFloatType(toType)) {
 518                                 spirvBlock.add(new SPIRVOpConvertSToF(toType, to, from));
 519                             }
 520                             else unsupported("conversion type", scop.resultType());
 521                         }
 522                         else unsupported("conversion type", scop.operands().get(0));
 523                         addResult(scop.result(), new SpirvResult(toType, null, to));
 524                     }
 525                     case SpirvOps.InBoundAccessChainOp iacop -> {
 526                         SPIRVId type = spirvType(iacop.resultType().toString());
 527                         SPIRVId result = nextId();
 528                         SPIRVId object = getResult(iacop.operands().get(0)).value();
 529                         SPIRVId index = getResult(iacop.operands().get(1)).value();
 530                         spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(type, result, object, index, new SPIRVMultipleOperands<>()));
 531                         addResult(iacop.result(), new SpirvResult(type, result, null));
 532                     }
 533                     case SpirvOps.FieldLoadOp flo -> {
 534                         if (flo.operands().size() > 0 && flo.operands().get(0).type().equals(JavaType.ofString("spirvdemo.GPU$Index"))) {
 535                             SpirvResult result;
 536                             int group = -1;
 537                             int index = -1;
 538                             String fieldName = flo.fieldDescriptor().name();
 539                             switch(fieldName) {
 540                                 case "x": group = 0; index = 0; break;
 541                                 case "y": group = 0; index = 1; break;
 542                                 case "z": group = 0; index = 2; break;
 543                                 case "w": group = 1; index = 0; break;
 544                                 case "h": group = 1; index = 1; break;
 545                                 case "d": group = 1; index = 2; break;
 546                             }
 547                             switch (group) {
 548                                 case 0: result = globalId(index, spirvBlock); break;
 549                                 case 1: result = globalSize(index, spirvBlock); break;
 550                                 default: throw new RuntimeException("Unknown Index field: " + fieldName);
 551                             }
 552                             addResult(flo.result(), result);
 553                         }
 554                         else if (((JavaType)flo.resultType()).equals(JavaType.type(VectorSpecies.class))) {
 555                             addResult(flo.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT")));
 556                         }
 557                         else if (flo.fieldDescriptor().refType().equals(JavaType.type(VectorOperators.class))) {
 558                             // currently ignored
 559                         }
 560                         else if (flo.fieldDescriptor().refType().equals(JavaType.type(ByteOrder.class))) {
 561                             // currently ignored
 562                         }
 563                         else unsupported("field load", ((ClassType)flo.fieldDescriptor().refType()).toClassName() + "." + flo.fieldDescriptor().name());
 564                     }
 565                     case SpirvOps.BranchOp bop -> {
 566                         SPIRVId trueLabel = symbols.getLabel(bop.branch()).getResultId();
 567                         spirvBlock.add(new SPIRVOpBranch(trueLabel));
 568                     }
 569                     case SpirvOps.ConditionalBranchOp cbop -> {
 570                         SPIRVId test = getResult(cbop.operands().get(0)).value();
 571                         SPIRVId trueLabel = symbols.getLabel(cbop.trueBranch()).getResultId();
 572                         SPIRVId falseLabel = symbols.getLabel(cbop.falseBranch()).getResultId();
 573                         spirvBlock.add(new SPIRVOpBranchConditional(test, trueLabel, falseLabel, new SPIRVMultipleOperands<SPIRVLiteralInteger>()));
 574                     }
 575                     case SpirvOps.LtOp ltop -> {
 576                         SPIRVId lhs = getResult(ltop.operands().get(0)).value();
 577                         SPIRVId rhs = getResult(ltop.operands().get(1)).value();
 578                         SPIRVId boolType = getType("bool");
 579                         SPIRVId result = nextId();
 580                         spirvBlock.add(new SPIRVOpSLessThan(boolType, result, lhs, rhs));
 581                         addResult(ltop.result(), new SpirvResult(boolType, null, result));
 582                     }
 583                     case SpirvOps.ReturnOp rop -> {
 584                         if (rop.operands().size() == 0) {
 585                             spirvBlock.add(new SPIRVOpReturn());
 586                         }
 587                         else {
 588                             SPIRVId returnValue = getResult(rop.operands().get(0)).value();
 589                             spirvBlock.add(new SPIRVOpReturnValue(returnValue));
 590                         }
 591                     }
 592                     default -> unsupported("op", op.getClass());
 593                 }
 594             }
 595         }
 596     }
 597 
 598     private void initModule() {
 599         module.add(new SPIRVOpCapability(SPIRVCapability.Addresses()));
 600         module.add(new SPIRVOpCapability(SPIRVCapability.Linkage()));
 601         module.add(new SPIRVOpCapability(SPIRVCapability.Kernel()));
 602         module.add(new SPIRVOpCapability(SPIRVCapability.Int8()));
 603         module.add(new SPIRVOpCapability(SPIRVCapability.Int16()));
 604         module.add(new SPIRVOpCapability(SPIRVCapability.Int64()));
 605         module.add(new SPIRVOpCapability(SPIRVCapability.Vector16()));
 606         module.add(new SPIRVOpCapability(SPIRVCapability.Float16()));
 607         module.add(new SPIRVOpMemoryModel(SPIRVAddressingModel.Physical64(), SPIRVMemoryModel.OpenCL()));
 608 
 609         // OpenCL extension provides built-in variables suitable for kernel programming
 610         // Import extention and declare fourn variables
 611         SPIRVId oclExtension = nextId("oclExtension");
 612         module.add(new SPIRVOpExtInstImport(oclExtension, new SPIRVLiteralString("OpenCL.std")));
 613 
 614         SPIRVId globalInvocationId = nextId("globalInvocationId");
 615         SPIRVId globalSize = nextId("globalSize");
 616         SPIRVId subgroupSize = nextId("subgroupSize");
 617         SPIRVId subgroupId = nextId("subgroupId");
 618 
 619         module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalInvocationId())));
 620         module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.Constant()));
 621         module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalInvocationId"), SPIRVLinkageType.Import())));
 622         module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalSize())));
 623         module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.Constant()));
 624         module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalSize"), SPIRVLinkageType.Import())));
 625         module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupSize())));
 626         module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.Constant()));
 627         module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupSize"), SPIRVLinkageType.Import())));
 628         module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupId())));
 629         module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.Constant()));
 630         module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupId"), SPIRVLinkageType.Import())));
 631 
 632         module.add(new SPIRVOpVariable(getType("ptrV3long"), globalInvocationId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
 633         module.add(new SPIRVOpVariable(getType("ptrV3long"), globalSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
 634         module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
 635         module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
 636     }
 637 
 638     private SPIRVId spirvType(String javaType) {
 639         SPIRVId ans = switch(javaType) {
 640             case "byte" -> getType("byte");
 641             case "short" -> getType("short");
 642             case "int" -> getType("int");
 643             case "long" -> getType("long");
 644             case "float" -> getType("float");
 645             case "double" -> getType("double");
 646             case "int[]" -> getType("int[]");
 647             case "float[]" -> getType("float[]");
 648             case "double[]" -> getType("double[]");
 649             case "long[]" -> getType("long[]");
 650             case "bool" -> getType("bool");
 651             case "spirvdemo.IntArray" -> getType("int[]");
 652             case "spirvdemo.FloatArray" -> getType("float[]");
 653             case "jdk.incubator.vector.IntVector" -> spirvVectorType("IntVector", 8);
 654             case "jdk.incubator.vector.FloatVector" -> spirvVectorType("FloatVector", 8);
 655             case "jdk.incubator.vector.VectorSpecies<java.lang.Integer>" -> getType("int");
 656             case "jdk.incubator.vector.VectorSpecies<java.lang.Long>" -> getType("long");
 657             case "jdk.incubator.vector.VectorSpecies<java.lang.Float>" -> getType("int");
 658             case "VectorSpecies" -> getType("int");
 659             case "void" -> getType("void");
 660             case "spirvdemo.GPU$Index" -> getType("ptrGPUIndex");
 661             case "java.lang.foreign.MemorySegment" -> getType("ptrByte");
 662             default -> null;
 663         };
 664         if (ans == null) unsupported("type", javaType);
 665         return ans;
 666     }
 667 
 668     private SPIRVId spirvElementType(String javaType) {
 669         SPIRVId ans = switch(javaType) {
 670             case "byte[]" -> getType("byte");
 671             case "short[]" -> getType("short");
 672             case "int[]" -> getType("int");
 673             case "long[]" -> getType("long");
 674             case "float[]" -> getType("float");
 675             case "double[]" -> getType("double");
 676             case "boolean[]" -> getType("bool");
 677             case "spirvdemo.IntArray" -> getType("int");
 678             case "spirvdemo.FloatArray" -> getType("float");
 679             case "jdk.incubator.vector.LongVector" -> getType("long");
 680             case "jdk.incubator.vector.FloatVector" -> getType("float");
 681             case "IntVector" -> getType("int");
 682             case "LongVector" -> getType("long");
 683             case "FloatVector" -> getType("float");
 684             case "java.lang.foreign.MemorySegment" -> getType("byte");
 685             default -> null;
 686         };
 687         if (ans == null) unsupported("type", javaType);
 688         return ans;
 689     }
 690 
 691     private SPIRVId vectorElementType(SPIRVId type) {
 692         SPIRVId ans = switch(type.getName()) {
 693             case "v8int" -> getType("int");
 694             case "v16int" -> getType("int");
 695             case "v8long" -> getType("long");
 696             case "v8float" -> getType("float");
 697             case "v16float" -> getType("float");
 698             default -> null;
 699         };
 700         if (ans == null) unsupported("type", type.getName());
 701         return ans;
 702     }
 703 
 704     private SPIRVId spirvVariableType(SPIRVId spirvType) {
 705         SPIRVId ans = switch(spirvType.getName()) {
 706             case "byte" -> getType("ptrByte");
 707             case "short" -> getType("ptrShort");
 708             case "int" -> getType("ptrInt");
 709             case "long" -> getType("ptrLong");
 710             case "float" -> getType("ptrFloat");
 711             case "double" -> getType("ptrDouble");
 712             case "boolean" -> getType("ptrBool");
 713             case "int[]" -> getType("ptrInt[]");
 714             case "long[]" -> getType("ptrLong[]");
 715             case "float[]" -> getType("ptrFloat[]");
 716             case "double[]" -> getType("ptrDouble[]");
 717             case "v8int" -> getType("ptrV8int");
 718             case "v16int" -> getType("ptrV16int");
 719             case "v8long" -> getType("ptrV8long");
 720             case "v8float" -> getType("ptrV8float");
 721             case "v16float" -> getType("ptrV16float");
 722             case "ptrGPUIndex" -> getType("ptrPtrGPUIndex");
 723             case "ptrByte" -> getType("ptrPtrByte");
 724             default -> null;
 725         };
 726         if (ans == null) unsupported("type", spirvType.getName());
 727         return ans;
 728     }
 729 
 730     private SPIRVId spirvVectorType(String javaVectorType, int vectorLength) {
 731         String prefix = "v" + vectorLength;
 732         String elementType = spirvElementType(javaVectorType).getName();
 733         return getType(prefix + elementType);
 734     }
 735 
 736     private int alignment(String spirvType) {
 737         int ans = switch(spirvType) {
 738             case "byte" -> 1;
 739             case "short" -> 2;
 740             case "int" -> 4;
 741             case "long" -> 8;
 742             case "float" -> 4;
 743             case "double" -> 8;
 744             case "boolean" -> 1;
 745             case "v8int" -> 32;
 746             case "v16int" -> 64;
 747             case "v8long" -> 64;
 748             case "v8float" -> 32;
 749             case "v16float" -> 64;
 750             case "ptrGPUIndex" -> 32;
 751             case "int[]" -> 8;
 752             case "long[]" -> 8;
 753             case "float[]" -> 8;
 754             case "double[]" -> 8;
 755             case "ptrByte" -> 8;
 756             case "ptrInt" -> 8;
 757             case "ptrInt[]" -> 8;
 758             case "ptrLong" -> 8;
 759             case "ptrLong[]" -> 8;
 760             case "ptrFloat" -> 8;
 761             case "ptrFloat[]" -> 8;
 762             case "ptrV8int" -> 8;
 763             case "ptrV8float" -> 8;
 764             case "ptrPtrGPUIndex" -> 8;
 765             default -> 0;
 766         };
 767         if (ans == 0) unsupported("type", spirvType);
 768         return ans;
 769     }
 770 
 771     private int laneCount(String vectorType) {
 772         int ans = switch(vectorType) {
 773             case "v8int" -> 8;
 774             case "v8long" -> 8;
 775             case "v8float" -> 8;
 776             case "v16int" -> 16;
 777             case "v16float" -> 16;
 778             default -> 0;
 779         };
 780         if (ans == 0) unsupported("type", vectorType);
 781         return ans;
 782     }
 783 
 784     private SPIRVId vectorExponent(String vectorType) {
 785         SPIRVId ans = null;
 786         switch(vectorType) {
 787             case "v8int" -> ans = getId("int_THREE");
 788             case "v8long" -> ans = getId("int_THREE");
 789             case "v8float" -> ans = getId("int_THREE");
 790             case "v16int" -> ans = getId("int_FOUR");
 791             case "v16float" -> ans = getId("int_FOUR");
 792             default -> unsupported("type", vectorType);
 793         };
 794         return ans;
 795     }
 796 
 797     private Set<String> moduleTypes = new HashSet<>();
 798 
 799     private SPIRVId getType(String name) {
 800         if (!moduleTypes.contains(name)) {
 801             switch (name) {
 802                 case "void" -> module.add(new SPIRVOpTypeVoid(nextId(name)));
 803                 case "bool" -> module.add(new SPIRVOpTypeBool(nextId(name)));
 804                 case "byte" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(8), new SPIRVLiteralInteger(0)));
 805                 case "short" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(16), new SPIRVLiteralInteger(0)));
 806                 case "int" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(32), new SPIRVLiteralInteger(0)));
 807                 case "long" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(64), new SPIRVLiteralInteger(0)));
 808                 case "float" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(32)));
 809                 case "double" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(64)));
 810                 case "ptrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 811                 case "ptrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int")));
 812                 case "ptrLong" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long")));
 813                 case "ptrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float")));
 814                 case "short[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("short")));
 815                 case "int[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("int")));
 816                 case "long[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("long")));
 817                 case "float[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("float")));
 818                 case "double[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("double")));
 819                 case "boolean[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("boolean")));
 820                 case "ptrInt[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int[]")));
 821                 case "ptrLong[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long[]")));
 822                 case "ptrFloat[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float[]")));
 823                 case "spirvdemo.GPUIndex" -> module.add(new SPIRVOpTypeStruct(nextId(name), new SPIRVMultipleOperands<>(getType("long"), getType("long"), getType("long"))));
 824                 case "ptrGPUIndex" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("spirvdemo.GPUIndex")));
 825                 case "ptrCrossGroupByte"-> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 826                 case "ptrPtrGPUIndex" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrGPUIndex")));
 827                 case "ptrPtrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrByte")));
 828                 case "v3long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(3)));
 829                 case "v8int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(8)));
 830                 case "v8long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(8)));
 831                 case "v16int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(16)));
 832                 case "v8float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(8)));
 833                 case "v16float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(16)));
 834                 case "ptrV3long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Input(), getType("v3long")));
 835                 case "ptrV8long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8long")));
 836                 case "ptrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8int")));
 837                 case "ptrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16int")));
 838                 case "ptrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8float")));
 839                 case "ptrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16float")));
 840                 case "ptrPtrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8int")));
 841                 case "ptrPtrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16int")));
 842                 case "ptrPtrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8float")));
 843                 case "ptrPtrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16float")));
 844                 default -> unsupported("type", name);
 845             }
 846             moduleTypes.add(name);
 847         }
 848         return getId(name);
 849     }
 850 
 851     private Set<String> moduleConstants = new HashSet<>();
 852 
 853     private SPIRVId getConst(String name) {
 854         if (!moduleConstants.contains(name)) {
 855             switch (name) {
 856                 case "int_ZERO" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_ZERO"), new SPIRVContextDependentInt(new BigInteger("0"))));
 857                 case "int_ONE" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_ONE"), new SPIRVContextDependentInt(new BigInteger("1"))));
 858                 case "int_TWO" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_TWO"), new SPIRVContextDependentInt(new BigInteger("2"))));
 859                 case "int_EIGHT" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_EIGHT"), new SPIRVContextDependentInt(new BigInteger("8"))));
 860                 default -> unsupported("constant", name);
 861             }
 862             moduleConstants.add(name);
 863         }
 864         return getId(name);
 865     }
 866 
 867     private SPIRVOptionalOperand<SPIRVMemoryAccess> align(int align) {
 868         return new SPIRVOptionalOperand<>(SPIRVMemoryAccess.Aligned(new SPIRVLiteralInteger(align)));
 869     }
 870 
 871     private SPIRVOptionalOperand<SPIRVMemoryAccess> align(String type) {
 872         return align(alignment(type));
 873     }
 874 
 875     private SPIRVMultipleOperands<SPIRVId> spirvOperands(SPIRVId value, int count) {
 876         SPIRVId[] values = new SPIRVId[count];
 877         Arrays.fill(values, value);
 878         return new SPIRVMultipleOperands<>(values);
 879     }
 880 
 881     private SPIRVOptionalOperand<SPIRVMemoryAccess> none() {
 882         return new SPIRVOptionalOperand<>();
 883     }
 884 
 885     private SpirvResult globalSize(int index, SPIRVBlock spirvBlock) {
 886         SPIRVId longType = getType("long");
 887         SPIRVId v3long = getId("v3long");
 888         SPIRVId globalSizeId = getId("globalSize");
 889         SPIRVId globalSizes = nextId();
 890         spirvBlock.add(new SPIRVOpLoad(v3long, globalSizes, globalSizeId, align(32)));
 891         SPIRVId globalSize = nextId();
 892         spirvBlock.add(new SPIRVOpCompositeExtract(longType, globalSize, globalSizes, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
 893         return new SpirvResult(longType, null, globalSize);
 894     }
 895 
 896     private SpirvResult globalId(int index, SPIRVBlock spirvBlock) {
 897         SPIRVId longType = getType("long");
 898         SPIRVId v3long = getId("v3long");
 899         SPIRVId globalInvocationId = getId("globalInvocationId");
 900         SPIRVId globalIds = nextId();
 901         spirvBlock.add(new SPIRVOpLoad(v3long, globalIds, globalInvocationId, align(32)));
 902         SPIRVId globalIndex = nextId();
 903         spirvBlock.add(new SPIRVOpCompositeExtract(longType, globalIndex, globalIds, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
 904         return new SpirvResult(longType, null, globalIndex);
 905     }
 906 
 907     private SPIRVId nextId() {
 908         return module.getNextId();
 909     }
 910 
 911     private SPIRVId nextId(String name) {
 912         SPIRVId ans = nextId();
 913         ans.setName(name);
 914         symbols.putId(name, ans);
 915         module.add(new SPIRVOpName(ans, new SPIRVLiteralString(name)));
 916         return ans;
 917     }
 918 
 919     private static int counter = 0;
 920 
 921     private String nextTempTag() {
 922         counter++;
 923         return "temp_" + counter + "_";
 924     }
 925 
 926     private boolean isIntegerType(SPIRVId type) {
 927         String name = type.getName();
 928         return name.equals("short") || name.equals("int") || name.equals("long");
 929     }
 930 
 931     private boolean isFloatType(SPIRVId type) {
 932         String name = type.getName();
 933         return name.equals("float") || name.equals("double");
 934     }
 935 
 936     private boolean isVectorSpecies(String javaType) {
 937         return javaType.equals("VectorSpecies");
 938     }
 939 
 940     private boolean isVectorType(String javaType) {
 941         return javaType.equals("IntVector") || javaType.equals("FloatVector");
 942     }
 943 
 944     private void addId(String name, SPIRVId id) {
 945         symbols.putId(name, id);
 946     }
 947 
 948     private SPIRVId getId(String name) {
 949         SPIRVId ans = symbols.getId(name);
 950         assert ans != null : name + " not found";
 951         return ans;
 952     }
 953 
 954     private SPIRVId getIdOrNull(String name) {
 955         return symbols.getId(name);
 956     }
 957 
 958     private static Object map(Function<Object, Boolean> test, Object... args) {
 959         int len = args.length;
 960         assert len >= 2 && len % 2 == 0;
 961         int pairs = len / 2;
 962         for (int i = 0; i < pairs; i++) {
 963             if (test.apply(args[i])) return args[i + pairs];
 964         }
 965         throw new RuntimeException("No match: " + args[0]);
 966     }
 967 
 968     private void unsupported(String message, Object value) {
 969         throw new RuntimeException("Unsupported " + message + ": " + value);
 970     }
 971 
 972     private void addResult(Value value, SpirvResult result) {
 973         assert symbols.getResult(value) == null : "result already present";
 974         symbols.putResult(value, result);
 975     }
 976 
 977     private SpirvResult getResult(Value value) {
 978         return symbols.getResult(value);
 979     }
 980 
 981     private static class Symbols {
 982         private final HashMap<Value, SpirvResult> results;
 983         private final HashMap<String, SPIRVId> ids;
 984         private final HashMap<Block, SPIRVBlock> blocks;
 985         private final HashMap<Block, SPIRVOpLabel> labels;
 986 
 987         public Symbols() {
 988             this.results = new HashMap<>();
 989             this.ids = new HashMap<>();
 990             this.blocks = new HashMap<>();
 991             this.labels = new HashMap<>();
 992         }
 993 
 994         public void putId(String name, SPIRVId id) {
 995             ids.put(name, id);
 996         }
 997 
 998         public SPIRVId getId(String name) {
 999             return ids.get(name);
1000         }
1001 
1002         public void putBlock(Block block, SPIRVBlock spirvBlock) {
1003             blocks.put(block, spirvBlock);
1004         }
1005 
1006         public SPIRVBlock getBlock(Block block) {
1007             return blocks.get(block);
1008         }
1009 
1010         public void putLabel(Block block, SPIRVOpLabel spirvLabel) {
1011             labels.put(block, spirvLabel);
1012         }
1013 
1014         public SPIRVOpLabel getLabel(Block block) {
1015             return labels.get(block);
1016         }
1017 
1018         public void putResult(Value value, SpirvResult result) {
1019             results.put(value, result);
1020         }
1021 
1022         public SpirvResult getResult(Value value) {
1023             return results.get(value);
1024         }
1025 
1026         public String toString() {
1027             return String.format("results %s\n\nids %s\n\nblocks %s\nlabels %s\n", results.keySet(), ids.keySet(), blocks.keySet(), labels.keySet());
1028         }
1029     }
1030 }