1 /*
   2  * Copyright (c) 2024 Intel Corporation. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package intel.code.spirv;
  27 
  28 import java.util.List;
  29 import java.util.ArrayList;
  30 import java.util.Arrays;
  31 import java.util.Map;
  32 import java.util.HashMap;
  33 import java.util.Set;
  34 import java.util.HashSet;
  35 import java.util.Optional;
  36 import java.util.function.Function;
  37 import java.io.IOException;
  38 import java.io.File;
  39 import java.io.FileOutputStream;
  40 import java.io.ByteArrayInputStream;
  41 import java.io.ByteArrayOutputStream;
  42 import java.io.PrintStream;
  43 import java.nio.ByteBuffer;
  44 import java.nio.ByteOrder;
  45 import java.nio.channels.FileChannel;
  46 import java.math.BigInteger;
  47 import java.lang.invoke.MethodHandles;
  48 import java.lang.foreign.MemorySegment;
  49 import java.lang.foreign.ValueLayout;
  50 import jdk.incubator.code.Block;
  51 import jdk.incubator.code.Body;
  52 import jdk.incubator.code.Op;
  53 import jdk.incubator.code.Value;
  54 import jdk.incubator.code.op.CoreOp;
  55 import jdk.incubator.code.TypeElement;
  56 import jdk.incubator.code.type.MethodRef;
  57 import jdk.incubator.code.type.ClassType;
  58 import jdk.incubator.code.type.JavaType;
  59 import jdk.incubator.code.type.FunctionType;
  60 import hat.callgraph.CallGraph;
  61 import hat.callgraph.KernelCallGraph;
  62 import hat.callgraph.KernelEntrypoint;
  63 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVHeader;
  64 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVModule;
  65 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVFunction;
  66 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVBlock;
  67 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.*;
  68 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.*;
  69 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler;
  70 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions;
  71 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVByteStreamReader;
  72 import intel.code.spirv.SpirvOp.PhiOp;
  73 
  74 public class SpirvModuleGenerator {
    // Name given when the generator is created; generateModuleInternal also uses it as the
    // function / OpEntryPoint name of the single emitted kernel.
    private final String moduleName;
    // The SPIR-V module under construction; serialized by finalizeModule().
    private final SPIRVModule module;
    // Per-generator symbol table; translateBody records code-model blocks -> SPIR-V blocks/labels
    // here (putBlock/putLabel). NOTE(review): presumably also backs getId/addResult — not visible here.
    private final Symbols symbols;
  78 
  79     public static SpirvModuleGenerator create(String moduleName) {
  80         return new SpirvModuleGenerator(moduleName);
  81     }
  82 
  83     public static MemorySegment generateModule(String moduleName, KernelCallGraph callGraph) {
  84         SpirvModuleGenerator generator = SpirvModuleGenerator.create(moduleName);
  85         for (CallGraph.MethodCall call : callGraph.calls) {
  86             if (call.targetMethodRef != null) {
  87                 try {
  88                     Optional<CoreOp.FuncOp> ofo = call.targetMethodRef.codeModel(MethodHandles.lookup());
  89                     if (ofo.isPresent()) {
  90                         CoreOp.FuncOp fo = ofo.get();
  91                         SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(fo);
  92                     }
  93                 } catch (Exception e) {
  94                     throw new RuntimeException(e);
  95                 }
  96             }
  97         }
  98         KernelEntrypoint kernelEntrypoint = callGraph.entrypoint;
  99         CoreOp.FuncOp funcOp = kernelEntrypoint.funcOpWrapper().op();
 100         String kernelName = funcOp.funcName();
 101         SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(funcOp);
 102         generator.generateFunction(funcOp.funcName(), spirvFunc, true);
 103         return generator.finalizeModule();
 104     }
 105 
 106     public static MemorySegment generateModule(String moduleName, CoreOp.FuncOp func) {
 107         SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(func);
 108         MemorySegment moduleSegment = SpirvModuleGenerator.generateModule(moduleName, spirvFunc);
 109         return moduleSegment;
 110     }
 111 
 112     public static MemorySegment generateModule(String moduleName, SpirvOp.FuncOp func) {
 113         SpirvModuleGenerator generator = new SpirvModuleGenerator(moduleName);
 114         MemorySegment moduleSegment = generator.generateModuleInternal(func);
 115         return moduleSegment;
 116     }
 117 
 118     public static void writeModuleToFile(MemorySegment module, String filepath)  {
 119         ByteBuffer buffer = module.asByteBuffer();
 120         File out = new File(filepath);
 121         try (FileChannel channel = new FileOutputStream(out, false).getChannel()) {
 122             channel.write(buffer);
 123             channel.close();
 124         }
 125         catch (IOException e)  {
 126             throw new RuntimeException(e);
 127         }
 128     }
 129 
 130     private static void writeModuleToFile(SPIRVModule module, String filepath)
 131     {
 132         ByteBuffer buffer = ByteBuffer.allocate(module.getByteCount());
 133         buffer.order(ByteOrder.LITTLE_ENDIAN);
 134         module.close().write(buffer);
 135         buffer.flip();
 136         File out = new File(filepath);
 137         try (FileChannel channel = new FileOutputStream(out, false).getChannel())
 138         {
 139             channel.write(buffer);
 140         }
 141         catch (IOException e)
 142         {
 143             throw new RuntimeException(e);
 144         }
 145     }
 146 
 147     public static String disassembleModule(MemorySegment module) {
 148         SPVByteStreamReader input = new SPVByteStreamReader(new ByteArrayInputStream(module.toArray(ValueLayout.JAVA_BYTE)));
 149         ByteArrayOutputStream out = new ByteArrayOutputStream();
 150         try (PrintStream ps = new PrintStream(out))  {
 151             SPIRVDisassemblerOptions options = new SPIRVDisassemblerOptions(false, false, false, false, true);
 152             Disassembler dis = new Disassembler(input, ps, options);
 153             dis.run();
 154         }
 155         catch (Exception e) {
 156             throw new RuntimeException(e);
 157         }
 158         return new String(out.toByteArray());
 159     }
 160 
 161     public MemorySegment finalizeModule() {
 162         ByteBuffer buffer = ByteBuffer.allocateDirect(module.getByteCount());
 163         buffer.order(ByteOrder.LITTLE_ENDIAN);
 164         module.close().write(buffer);
 165         buffer.flip();
 166         return MemorySegment.ofBuffer(buffer);
 167     }
 168 
 169     private record SpirvResult(SPIRVId type, SPIRVId address, SPIRVId value) {}
 170 
 171     private SpirvModuleGenerator(String moduleName) {
 172         this.moduleName = moduleName;
 173         this.module = new SPIRVModule(new SPIRVHeader(1, 2, 32, 0, 0));
 174         this.symbols = new Symbols();
 175         initModule();
 176     }
 177 
 178     private MemorySegment generateModuleInternal(SpirvOp.FuncOp func) {
 179         generateFunction(moduleName, func, true);
 180         return finalizeModule();
 181     }
 182 
 183     private SPIRVId generateFunction(String fnName, SpirvOp.FuncOp func, boolean isEntryPoint) {
 184         TypeElement returnType = func.invokableType().returnType();
 185         SPIRVId functionId = nextId(fnName);
 186         String signature = func.invokableType().returnType().toString();
 187         List<TypeElement> paramTypes = func.invokableType().parameterTypes();
 188         // build signature string
 189         for (int i = 0; i < paramTypes.size(); i++) {
 190             signature += "_" + paramTypes.get(i).toString();
 191         }
 192         // declare function type if not already present
 193         SPIRVId functionSig = getIdOrNull(signature);
 194         if (functionSig == null) {
 195             SPIRVId[] typeIdsArray = new SPIRVId[paramTypes.size()];
 196             for (int i = 0; i < paramTypes.size(); i++) {
 197                 typeIdsArray[i] = spirvType(paramTypes.get(i).toString());
 198             }
 199             functionSig = nextId(fnName + "Signature");
 200             module.add(new SPIRVOpTypeFunction(functionSig, spirvType(returnType.toString()), new SPIRVMultipleOperands<>(typeIdsArray)));
 201             addId(signature, functionSig);
 202         }
 203         SPIRVId spirvReturnType = spirvType(returnType.toString());
 204         SPIRVFunction function = (SPIRVFunction)module.add(new SPIRVOpFunction(spirvReturnType, functionId, SPIRVFunctionControl.DontInline(), functionSig));
 205         SPIRVOpLabel entryLabel = new SPIRVOpLabel(nextId());
 206         SPIRVBlock entryBlock = (SPIRVBlock)function.add(entryLabel);
 207         SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(getId("globalInvocationId"), getId("globalSize"), getId("subgroupSize"), getId("subgroupId"));
 208         if (isEntryPoint) {
 209             module.add(new SPIRVOpEntryPoint(SPIRVExecutionModel.Kernel(), functionId, new SPIRVLiteralString(fnName), operands));
 210         }
 211         translateBody(func.body(), function, entryBlock);
 212         function.add(new SPIRVOpFunctionEnd());
 213         return functionId;
 214     }
 215 
 216     private void translateBody(Body body, SPIRVFunction function, SPIRVBlock entryBlock) {
 217         int labelNumber = 0;
 218         SPIRVBlock spirvBlock = entryBlock;
 219         for (int bi = 1; bi < body.blocks().size(); bi++)  {
 220             Block block = body.blocks().get(bi);
 221             SPIRVOpLabel blockLabel = new SPIRVOpLabel(nextId());
 222             SPIRVBlock newBlock = (SPIRVBlock)function.add(blockLabel);
 223             symbols.putBlock(block, newBlock);
 224             symbols.putLabel(block, blockLabel);
 225             for (int i = 0; i < block.parameters().size(); i++) {
 226                 Value param = block.parameters().get(i);
 227                 SPIRVId paramId = nextId();
 228                 addResult(param, new SpirvResult(spirvType(param.type().toString()), null, paramId));
 229             }
 230         }
 231         for (int bi = 0; bi < body.blocks().size(); bi++)  {
 232             Block block = body.blocks().get(bi);
 233             if (bi > 0) {
 234                 spirvBlock = symbols.getBlock(block);
 235             }
 236             for (Op op : block.ops()) {
 237                 switch (op)  {
 238                     case SpirvOp.PhiOp phop -> {
 239                         List<PhiOp.Predecessor> inPredecessors = phop.predecessors();
 240                         SPIRVPairIdRefIdRef[] outPredecessors = new SPIRVPairIdRefIdRef[inPredecessors.size()];
 241                         for (int i = 0; i < inPredecessors.size(); i++) {
 242                             PhiOp.Predecessor predecessor = inPredecessors.get(i);
 243                             SPIRVId label = symbols.getLabel(predecessor.block().targetBlock()).getResultId();
 244                             SPIRVId value = getResult(predecessor.value()).value();
 245                             outPredecessors[i] = new SPIRVPairIdRefIdRef(value, label);
 246                         }
 247                         SPIRVId result = nextId();
 248                         SPIRVId type = spirvType(phop.resultType().toString());
 249                         SPIRVOpPhi phiOp = new SPIRVOpPhi(spirvType(phop.resultType().toString()), result, new SPIRVMultipleOperands<>(outPredecessors));
 250                         spirvBlock.add(phiOp);
 251                         addResult(phop.result(), new SpirvResult(type, null, result));
 252                     }
 253                     case SpirvOp.VariableOp vop -> {
 254                         String typeName = vop.varType().toString();
 255                         SPIRVId type = spirvType(typeName);
 256                         SPIRVId varType = spirvVariableType(type);
 257                         SPIRVId var = nextId(vop.varName());
 258                         spirvBlock.add(new SPIRVOpVariable(varType, var, SPIRVStorageClass.Function(), new SPIRVOptionalOperand<>()));
 259                         addResult(vop.result(), new SpirvResult(varType, var, null));
 260                     }
 261                     case SpirvOp.FunctionParameterOp fpo -> {
 262                         SPIRVId result = nextId();
 263                         SPIRVId type = spirvType(fpo.resultType().toString());
 264                         function.add(new SPIRVOpFunctionParameter(type, result));
 265                         module.add(new SPIRVOpDecorate(result, SPIRVDecoration.Alignment(new SPIRVLiteralInteger(8))));
 266                         addResult(fpo.result(), new SpirvResult(type, null, result));
 267                     }
 268                     case SpirvOp.LoadOp lo -> {
 269                         SPIRVId type = spirvType(lo.resultType().toString());
 270                         SpirvResult toLoad = getResult(lo.operands().get(0));
 271                         SPIRVId varAddr = toLoad.address();
 272                         SPIRVId result = nextId();
 273                         spirvBlock.add(new SPIRVOpLoad(type, result, varAddr, align(type.getName())));
 274                         addResult(lo.result(), new SpirvResult(type, varAddr, result));
 275                     }
 276                     case SpirvOp.StoreOp so -> {
 277                         SpirvResult var = getResult(so.operands().get(0));
 278                         SPIRVId varAddr = var.address();
 279                         SPIRVId value = getResult(so.operands().get(1)).value();
 280                         spirvBlock.add(new SPIRVOpStore(varAddr, value, align(var.type().getName())));
 281                     }
 282                     case SpirvOp.IAddOp _, SpirvOp.FAddOp _ -> {
 283                         SPIRVId intType = getType("int");
 284                         SPIRVId longType = getType("long");
 285                         SPIRVId floatType = getType("float");
 286                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 287                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 288                         SPIRVId lhsType = spirvType(op.resultType().toString());
 289                         SPIRVId ans = nextId();
 290                         if (lhsType == intType) spirvBlock.add(new SPIRVOpIAdd(intType, ans, lhs, rhs));
 291                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpIAdd(longType, ans, lhs, rhs));
 292                         else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFAdd(floatType, ans, lhs, rhs));
 293                         else unsupported("type", lhsType.getName());
 294                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 295                     }
 296                     case SpirvOp.ISubOp _, SpirvOp.FSubOp _ -> {
 297                         SPIRVId intType = getType("int");
 298                         SPIRVId longType = getType("long");
 299                         SPIRVId floatType = getType("float");
 300                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 301                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 302                         SPIRVId lhsType = spirvType(op.resultType().toString());
 303                         SPIRVId ans = nextId();
 304                         if (lhsType == intType) spirvBlock.add(new SPIRVOpISub(intType, ans, lhs, rhs));
 305                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpISub(longType, ans, lhs, rhs));
 306                         else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFSub(floatType, ans, lhs, rhs));
 307                         else unsupported("type", lhsType.getName());
 308                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 309                     }
 310                     case SpirvOp.IMulOp _, SpirvOp.FMulOp _, SpirvOp.IDivOp _, SpirvOp.FDivOp _ -> {
 311                         SPIRVId intType = getType("int");
 312                         SPIRVId longType = getType("long");
 313                         SPIRVId floatType = getType("float");
 314                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 315                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 316                         SPIRVId lhsType = spirvType(op.resultType().toString());
 317                         SPIRVId rhsType = getResult(op.operands().get(1)).type();
 318                         SPIRVId ans = nextId();
 319                         if (lhsType == intType) {
 320                             if (op instanceof SpirvOp.IMulOp) spirvBlock.add(new SPIRVOpIMul(intType, ans, lhs, rhs));
 321                             else if (op instanceof SpirvOp.IDivOp) spirvBlock.add(new SPIRVOpSDiv(intType, ans, lhs, rhs));
 322                         }
 323                         else if (lhsType == longType) {
 324                             SPIRVId rhsId = rhsType == intType ? nextId() : rhs;
 325                             if (rhsType == intType) spirvBlock.add(new SPIRVOpSConvert(longType, rhsId, rhs));
 326                             if (op instanceof SpirvOp.IMulOp) spirvBlock.add(new SPIRVOpIMul(longType, ans, lhs, rhsId));
 327                             else if (op instanceof SpirvOp.IDivOp) spirvBlock.add(new SPIRVOpSDiv(longType, ans, lhs, rhs));
 328                         }
 329                         else if (lhsType == floatType) {
 330                             if (op instanceof SpirvOp.FMulOp) spirvBlock.add(new SPIRVOpFMul(floatType, ans, lhs, rhs));
 331                             else if (op instanceof SpirvOp.FDivOp) spirvBlock.add(new SPIRVOpFDiv(floatType, ans, lhs, rhs));
 332                         }
 333                         else unsupported("type", lhsType);
 334                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 335                     }
 336                     case SpirvOp.ModOp mop -> {
 337                         SPIRVId type = getType(mop.operands().get(0).type().toString());
 338                         SPIRVId lhs = getResult(mop.operands().get(0)).value();
 339                         SPIRVId rhs = getResult(mop.operands().get(1)).value();
 340                         SPIRVId result = nextId();
 341                         spirvBlock.add(new SPIRVOpUMod(type, result, lhs, rhs));
 342                         addResult(mop.result(), new SpirvResult(type, null, result));
 343                     }
 344                     case SpirvOp.IEqualOp eqop -> {
 345                         SPIRVId boolType = getType("bool");
 346                         SPIRVId intType = getType("int");
 347                         SPIRVId longType = getType("long");
 348                         SPIRVId floatType = getType("float");
 349                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 350                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 351                         SPIRVId lhsType = spirvType(op.resultType().toString());
 352                         SPIRVId ans = nextId();
 353                         if (lhsType == intType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
 354                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
 355                         else unsupported("type", lhsType.getName());
 356                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 357                     }
 358                     case SpirvOp.GeOp eqop -> {
 359                         SPIRVId boolType = getType("bool");
 360                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 361                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 362                         SPIRVId lhsType = spirvType(op.resultType().toString());
 363                         SPIRVId ans = nextId();
 364                         spirvBlock.add(new SPIRVOpSGreaterThanEqual(boolType, ans, lhs, rhs));
 365                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 366                     }
 367                     case SpirvOp.PtrNotEqualOp neqop -> {
 368                         SPIRVId boolType = getType("bool");
 369                         SPIRVId longType = getType("long");
 370                         SPIRVId lhs = getResult(neqop.operands().get(0)).value();
 371                         SPIRVId rhs = getResult(neqop.operands().get(1)).value();
 372                         SPIRVId ans = nextId();
 373                         SPIRVId lhsLong = nextId();
 374                         SPIRVId rhsLong = nextId();
 375                         spirvBlock.add(new SPIRVOpConvertPtrToU(longType, lhsLong, lhs));
 376                         spirvBlock.add(new SPIRVOpConvertPtrToU(longType, rhsLong, rhs));
 377                         spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhsLong, rhsLong));
 378                         addResult(op.result(), new SpirvResult(boolType, null, ans));
 379                     }
 380                     case SpirvOp.CallOp call -> {
 381                         MethodRef methodRef = call.callDescriptor();
 382                         if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array::array(long)int")))
 383                         {
 384                             SPIRVId longType = getType("long");
 385                             String arrayTypeName = call.operands().get(0).type().toString();
 386                             SpirvResult arrayResult = getResult(call.operands().get(0));
 387                             SPIRVId arrayAddr = arrayResult.address();
 388                             SPIRVId arrayType = spirvType(arrayTypeName);
 389                             SPIRVId elementType = spirvElementType(arrayTypeName);
 390                             int nIndexes = call.operands().size() - 1;
 391                             SPIRVId indexX = getResult(call.operands().get(1)).value();
 392                             SPIRVId array = nextId();
 393                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 394                             SPIRVId temp1 = nextId();
 395                             SPIRVId temp2 = nextId();
 396                             spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
 397                             spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4)));
 398                             SPIRVId elementBase = nextId();
 399                             spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
 400                             SPIRVId resultAddr = nextId();
 401                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, elementBase, indexX, new SPIRVMultipleOperands<>()));
 402                             SPIRVId result = nextId();
 403                             spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName())));
 404                             addResult(call.result(), new SpirvResult(elementType, resultAddr, result));
 405                         }
 406                         else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array2D::array(long)int")) ||
 407                                  methodRef.equals(MethodRef.ofString("hat.buffer.F32Array2D::array(long)float")))
 408                         {
 409                             SPIRVId longType = getType("long");
 410                             String arrayTypeName = call.operands().get(0).type().toString();
 411                             SpirvResult arrayResult = getResult(call.operands().get(0));
 412                             SPIRVId arrayAddr = arrayResult.address();
 413                             SPIRVId arrayType = spirvType(arrayTypeName);
 414                             SPIRVId elementType = spirvElementType(arrayTypeName);
 415                             int nIndexes = call.operands().size() - 1;
 416                             SPIRVId indexX = getResult(call.operands().get(1)).value();
 417                             SPIRVId array = nextId();
 418                             SPIRVId temp1 = nextId();
 419                             SPIRVId temp2 = nextId();
 420                             spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
 421                             spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4)));
 422                             SPIRVId elementBase = nextId();
 423                             spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
 424                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 425                             SPIRVId resultAddr = nextId();
 426                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, elementBase, indexX, new SPIRVMultipleOperands<>()));
 427                             SPIRVId result = nextId();
 428                             spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName())));
 429                             addResult(call.result(), new SpirvResult(elementType, resultAddr, result));
 430                         }
 431                         else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array::array(long, int)void"))) {
 432                             SPIRVId longType = getType("long");
 433                             String arrayTypeName = call.operands().get(0).type().toString();
 434                             SpirvResult arrayResult = getResult(call.operands().get(0));
 435                             SPIRVId arrayAddr = arrayResult.address();
 436                             SPIRVId arrayType = spirvType(arrayTypeName);
 437                             SPIRVId elementType = spirvElementType(arrayTypeName);
 438                             int nIndexes = call.operands().size() - 2;
 439                             int valueIndex = nIndexes + 1;
 440                             SPIRVId indexX = getResult(call.operands().get(1)).value();
 441                             SPIRVId array = nextId();
 442                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 443                             SPIRVId temp1 = nextId();
 444                             SPIRVId temp2 = nextId();
 445                             spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
 446                             spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4)));
 447                             SPIRVId elementBase = nextId();
 448                             spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
 449                             SPIRVId dest = nextId();
 450                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, elementBase, indexX, new SPIRVMultipleOperands<>()));
 451                             SPIRVId value = getResult(call.operands().get(valueIndex)).value();
 452                             spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName())));
 453                         }
 454                         else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array2D::array(long, int)void")) ||
 455                                  methodRef.equals(MethodRef.ofString("hat.buffer.F32Array2D::array(long, float)void"))) {
 456                             SPIRVId longType = getType("long");
 457                             String arrayTypeName = call.operands().get(0).type().toString();
 458                             SpirvResult arrayResult = getResult(call.operands().get(0));
 459                             SPIRVId arrayAddr = arrayResult.address();
 460                             SPIRVId arrayType = spirvType(arrayTypeName);
 461                             SPIRVId elementType = spirvElementType(arrayTypeName);
 462                             int nIndexes = call.operands().size() - 2;
 463                             int valueIndex = nIndexes + 1;
 464                             SPIRVId indexX = getResult(call.operands().get(1)).value();
 465                             SPIRVId array = nextId();
 466                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 467                             SPIRVId temp1 = nextId();
 468                             SPIRVId temp2 = nextId();
 469                             spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
 470                             spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 8)));
 471                             SPIRVId elementBase = nextId();
 472                             spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
 473                             SPIRVId dest = nextId();
 474                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, elementBase, indexX, new SPIRVMultipleOperands<>()));
 475                             SPIRVId value = getResult(call.operands().get(valueIndex)).value();
 476                             spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName())));
 477                         }
 478                         else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array::length()int")) ||
 479                                  methodRef.equals(MethodRef.ofString("hat.buffer.S32Array2D::width()int"))||
 480                                  methodRef.equals(MethodRef.ofString("hat.buffer.F32Array2D::width()int"))) {
 481                             SPIRVId intType = getType("int");
 482                             String arrayTypeName = call.operands().get(0).type().toString();
 483                             SpirvResult arrayResult = getResult(call.operands().get(0));
 484                             SPIRVId arrayAddr = arrayResult.address();
 485                             SPIRVId arrayType = spirvType(arrayTypeName);
 486                             SPIRVId array = nextId();
 487                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 488                             SPIRVId resultAddr = nextId();
 489                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, getConst("int", 0), new SPIRVMultipleOperands<>()));
 490                             SPIRVId result = nextId();
 491                             spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align(arrayType.getName())));
 492                             addResult(call.result(), new SpirvResult(intType, resultAddr, result));
 493                         }
 494                         else if (methodRef.equals(MethodRef.ofString("hat.buffer.S32Array2D::height()int")) ||
 495                                  methodRef.equals(MethodRef.ofString("hat.buffer.F32Array2D::height()int"))) {
 496                             SPIRVId intType = getType("int");
 497                             String arrayTypeName = call.operands().get(0).type().toString();
 498                             SpirvResult arrayResult = getResult(call.operands().get(0));
 499                             SPIRVId arrayAddr = arrayResult.address();
 500                             SPIRVId arrayType = spirvType(arrayTypeName);
 501                             SPIRVId array = nextId();
 502                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 503                             SPIRVId resultAddr = nextId();
 504                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, getConst("int", 1), new SPIRVMultipleOperands<>()));
 505                             SPIRVId result = nextId();
 506                             spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align(arrayType.getName())));
 507                             addResult(call.result(), new SpirvResult(intType, resultAddr, result));
 508                         }
 509                         else if (methodRef.equals(MethodRef.ofString("java.lang.Math::sqrt(double)double"))) {
 510                             SPIRVId floatType = getType("float");
 511                             SPIRVId result = nextId();
 512                             SPIRVId operand = getResult(call.operands().get(0)).value();
 513                             spirvBlock.add(new SPIRVOpExtInst(floatType, result, getId("oclExtension"), new SPIRVLiteralExtInstInteger(61, "sqrt"), new SPIRVMultipleOperands<>(operand)));
 514                             addResult(call.result(), new SpirvResult(floatType, null, result));
 515                         }
 516                         else {
 517                             SPIRVId fnId = getFunctionId(methodRef);
 518                             if (fnId == null) {
 519                                 unsupported("method", methodRef);
 520                             }
 521                             else {
 522                                 FunctionType fnType = methodRef.type();
 523                                 SPIRVId[] args = new SPIRVId[call.operands().size()];
 524                                 for (int i = 0; i < args.length; i++) {
 525                                     SPIRVId argId = getResult(call.operands().get(i)).value();
 526                                     args[i] = argId;
 527                                 }
 528                                 SPIRVId returnType = spirvType(fnType.returnType().toString());
 529                                 SPIRVId callResult = nextId();
 530                                 spirvBlock.add(new SPIRVOpFunctionCall(returnType, callResult, fnId, new SPIRVMultipleOperands<>(args)));
 531                                 addResult(call.result(), new SpirvResult(returnType, null, callResult));
 532                             }
 533                         }
 534                     }
 535                     case SpirvOp.ConstantOp cop -> {
 536                         SPIRVId type = spirvType(cop.resultType().toString());
 537                         SPIRVId result = nextId();
 538                         Object value = cop.value();
 539                         if (type == getType("int")) {
 540                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentInt(new BigInteger(String.valueOf(value)))));
 541                         }
 542                         else if (type == getType("long")) {
 543                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentLong(new BigInteger(String.valueOf(value)))));
 544                         }
 545                         else if (type == getType("float")) {
 546                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentFloat((float)value)));
 547                         }
 548                         else if (type == getType("bool")) {
 549                             module.add(((boolean)value) ? new SPIRVOpConstantTrue(type, result) : new SPIRVOpConstantFalse(type, result));
 550                         }
 551                         else if (type == getType("java.lang.Object")) {
 552                             module.add(new SPIRVOpConstantNull(type, result));
 553                         }
 554                         else if (type == getType("int[]")) {
 555                             module.add(new SPIRVOpConstantNull(type, result));
 556                         }
 557                         else unsupported("type", cop.resultType());
 558                         addResult(cop.result(), new SpirvResult(type, null, result));
 559                     }
 560                     case SpirvOp.ConvertOp scop -> {
 561                         SPIRVId toType = spirvType(scop.resultType().toString());
 562                         SPIRVId to = nextId();
 563                         SpirvResult valueResult = getResult(scop.operands().get(0));
 564                         SPIRVId from = valueResult.value();
 565                         SPIRVId fromType = valueResult.type();
 566                         if (isIntegerType(fromType)) {
 567                             if (isIntegerType(toType)) {
 568                                 spirvBlock.add(new SPIRVOpSConvert(toType, to, from));
 569                             }
 570                             else if (isFloatType(toType)) {
 571                                 spirvBlock.add(new SPIRVOpConvertSToF(toType, to, from));
 572                             }
 573                             else unsupported("conversion type", scop.resultType());
 574                         }
 575                         else if (isFloatType(fromType)) {
 576                             if (isIntegerType(toType)) {
 577                                 spirvBlock.add(new SPIRVOpConvertFToS(toType, to, from));
 578                             }
 579                             else if (isFloatType(toType)) {
 580                                 spirvBlock.add(new SPIRVOpFConvert(toType, to, from));
 581                             }
 582                             else unsupported("conversion type", scop.resultType());
 583                         }
 584                         else unsupported("conversion type", scop.operands().get(0));
 585                         addResult(scop.result(), new SpirvResult(toType, null, to));
 586                     }
 587                     case SpirvOp.InBoundsAccessChainOp iacop -> {
 588                         SPIRVId type = spirvType(iacop.resultType().toString());
 589                         SPIRVId result = nextId();
 590                         SPIRVId object = getResult(iacop.operands().get(0)).value();
 591                         SPIRVId index = getResult(iacop.operands().get(1)).value();
 592                         spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(type, result, object, index, new SPIRVMultipleOperands<>()));
 593                         addResult(iacop.result(), new SpirvResult(type, result, null));
 594                     }
 595                     case SpirvOp.FieldLoadOp flo -> {
 596                         if (flo.operands().size() > 0 && (flo.operands().get(0).type().equals(JavaType.ofString("hat.KernelContext")))) {
 597                             SpirvResult result;
 598                             int group = -1;
 599                             int index = -1;
 600                             String fieldName = flo.fieldDescriptor().name();
 601                             switch (fieldName) {
 602                                 case "x": group = 0; index = 0; break;
 603                                 case "y": group = 0; index = 1; break;
 604                                 case "z": group = 0; index = 2; break;
 605                                 case "maxX": group = 1; index = 0; break;
 606                                 case "maxY": group = 1; index = 1; break;
 607                                 case "maxZ": group = 1; index = 2; break;
 608                             }
 609                             switch (group) {
 610                                 case 0: result = globalId(index, spirvBlock); break;
 611                                 case 1: result = globalSize(index, spirvBlock); break;
 612                                 default: throw new RuntimeException("Unknown Index field: " + fieldName);
 613                             }
 614                             addResult(flo.result(), result);
 615                         }
 616                         else if (flo.operands().get(0).type().equals(JavaType.ofString("hat.KernelContext"))) {
 617                             String fieldName = flo.fieldDescriptor().name();
 618                             SPIRVId fieldIndex = switch (fieldName) {
 619                                 case "x" -> getConst("long", 0);
 620                                 case "maxX" -> getConst("long", 1);
 621                                 default -> throw new RuntimeException("Unknown field: " + fieldName);
 622                             };
 623                             SPIRVId intType = getType("int");
 624                             String contextTypeName = flo.operands().get(0).type().toString();
 625                             SpirvResult kernalContext = getResult(flo.operands().get(0));
 626                             SPIRVId contextAddr = kernalContext.address();
 627                             SPIRVId contextType = spirvType(contextTypeName);
 628                             SPIRVId context = nextId();
 629                             spirvBlock.add(new SPIRVOpLoad(contextType, context, contextAddr, align(contextType.getName())));
 630                             SPIRVId fieldType = intType;
 631                             SPIRVId resultAddr = nextId();
 632                             spirvBlock.add(new SPIRVOpInBoundsAccessChain(getType("ptrInt"), resultAddr, context, new SPIRVMultipleOperands<>(fieldIndex)));
 633                             SPIRVId result = nextId();
 634                             spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align("int")));
 635                             addResult(flo.result(), new SpirvResult(intType, resultAddr, result));
 636                         }
 637                         else if (flo.fieldDescriptor().refType().equals(JavaType.type(ByteOrder.class))) {
 638                             // currently ignored
 639                         }
 640                         else unsupported("field load", ((ClassType)flo.fieldDescriptor().refType()).toClassName() + "." + flo.fieldDescriptor().name());
 641                     }
 642                     case SpirvOp.BranchOp bop -> {
 643                         SPIRVId label = symbols.getLabel(bop.branch().targetBlock()).getResultId();
 644                         Block.Reference target = bop.branch();
 645                         spirvBlock.add(new SPIRVOpBranch(label));
 646                     }
 647                     case SpirvOp.ConditionalBranchOp cbop -> {
 648                         SPIRVId test = getResult(cbop.operands().get(0)).value();
 649                         SPIRVId trueLabel = symbols.getLabel(cbop.trueBranch().targetBlock()).getResultId();
 650                         SPIRVId falseLabel = symbols.getLabel(cbop.falseBranch().targetBlock()).getResultId();
 651                         spirvBlock.add(new SPIRVOpBranchConditional(test, trueLabel, falseLabel, new SPIRVMultipleOperands<SPIRVLiteralInteger>()));
 652                     }
 653                     case SpirvOp.LtOp ltop -> {
 654                         SpirvResult lhs = getResult(ltop.operands().get(0));
 655                         SpirvResult rhs = getResult(ltop.operands().get(1));
 656                         SPIRVId boolType = getType("bool");
 657                         SPIRVId result = nextId();
 658                         String operandType = lhs.type().getName();
 659                         SPIRVInstruction sop = switch (operandType) {
 660                             case "float" -> new SPIRVOpFUnordLessThan(boolType, result, lhs.value(), rhs.value());
 661                             case "int" -> new SPIRVOpSLessThan(boolType, result, lhs.value(), rhs.value());
 662                             case "long" -> new SPIRVOpSLessThan(boolType, result, lhs.value(), rhs.value());
 663                             default -> throw new RuntimeException("Unsupported type: " + lhs.type().getName());
 664                         };
 665                         spirvBlock.add(sop);
 666                         addResult(ltop.result(), new SpirvResult(boolType, null, result));
 667                     }
 668                     case SpirvOp.ReturnOp rop -> {
 669                         spirvBlock.add(new SPIRVOpReturn());
 670                     }
 671                     case SpirvOp.ReturnValueOp rop -> {
 672                         SPIRVId returnValue = getResult(rop.operands().get(0)).value();
 673                         spirvBlock.add(new SPIRVOpReturnValue(returnValue));
 674                     }
 675                     default -> unsupported("op", op.getClass());
 676                 }
 677             }
 678         } // end bi
 679     }
 680 
 681     private SPIRVId getFunctionId(MethodRef methodRef) {
 682         SPIRVId fnId = symbols.getId(methodRef.toString());
 683         if (fnId == null) {
 684             try {
 685                 Optional<CoreOp.FuncOp> optJFuncOp = methodRef.codeModel(MethodHandles.lookup());
 686                 if (optJFuncOp.isPresent()) {
 687                     CoreOp.FuncOp jFuncOp = optJFuncOp.get();
 688                     SpirvOp.FuncOp sFuncOp = TranslateToSpirvModel.translateFunction(jFuncOp);
 689                     fnId = generateFunction(jFuncOp.funcName(), sFuncOp, false);
 690                     symbols.putId(methodRef.toString(), fnId);
 691                 }
 692             }
 693             catch (ReflectiveOperationException e) {
 694                 throw new RuntimeException(e);
 695             }
 696         }
 697         return fnId;
 698     }
 699 
    /**
     * Emits the fixed module preamble. It consists of the capabilities, the memory model, the
     * OpenCL.std extended-instruction-set import, and four imported built-in input variables:
     * global invocation id, global size, subgroup size and subgroup id.
     *
     * <p>Statement order matters. Ids are handed out by {@code nextId} in call order, and
     * instructions are appended to {@code module} in the order written here. The
     * {@code getType("ptrV3long")} calls at the end also append the long/v3long/ptrV3long type
     * declarations the first time they run.
     */
    private void initModule() {
        // Capabilities: physical addresses, linkage (the built-ins below use Import linkage),
        // kernel execution, and 8-bit / 64-bit integer types.
        // NOTE(review): getType() can also emit double and 8-/16-component vectors. In SPIR-V
        // those normally require the Float64 and Vector16 capabilities. Confirm they are declared
        // elsewhere if such types are used.
        module.add(new SPIRVOpCapability(SPIRVCapability.Addresses()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Linkage()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Kernel()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Int8()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Int64()));
        module.add(new SPIRVOpMemoryModel(SPIRVAddressingModel.Physical64(), SPIRVMemoryModel.OpenCL()));

        // Import the OpenCL.std extended instruction set and register its id as "oclExtension",
        // so OpExtInst calls (e.g. sqrt) can reference it. The built-in variables are NOT provided
        // by this import; they come from the BuiltIn decorations below.
        SPIRVId oclExtension = nextId("oclExtension");
        module.add(new SPIRVOpExtInstImport(oclExtension, new SPIRVLiteralString("OpenCL.std")));
        symbols.putId("oclExtension", oclExtension);

        // Reserve named ids for the four built-in variables before decorating them.
        SPIRVId globalInvocationId = nextId("globalInvocationId");
        SPIRVId globalSize = nextId("globalSize");
        SPIRVId subgroupSize = nextId("subgroupSize");
        SPIRVId subgroupId = nextId("subgroupId");

        // Each built-in gets three decorations: its BuiltIn kind, Constant, and an Import linkage
        // attribute named "spirv_BuiltIn<Name>".
        module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalInvocationId())));
        module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalInvocationId"), SPIRVLinkageType.Import())));
        module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalSize())));
        module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalSize"), SPIRVLinkageType.Import())));
        module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupSize())));
        module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupSize"), SPIRVLinkageType.Import())));
        module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupId())));
        module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupId"), SPIRVLinkageType.Import())));

        // Declare the variables in the Input storage class, each typed as pointer-to-v3long.
        // NOTE(review): the OpenCL SPIR-V environment describes SubgroupSize and SubgroupId as
        // scalar 32-bit integers, not 3-component vectors. Confirm ptrV3long is intended for
        // those two.
        module.add(new SPIRVOpVariable(getType("ptrV3long"), globalInvocationId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
        module.add(new SPIRVOpVariable(getType("ptrV3long"), globalSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
        module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
        module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
    }
 737 
 738     private SPIRVId spirvType(String inputType) {
 739         String javaType = inputType.replaceAll("\\$", ".");
 740         SPIRVId ans = switch(javaType) {
 741             case "byte" -> getType("byte");
 742             case "short" -> getType("short");
 743             case "int" -> getType("int");
 744             case "long" -> getType("long");
 745             case "float" -> getType("float");
 746             case "double" -> getType("double");
 747             case "byte[]" -> getType("byte[]");
 748             case "int[]" -> getType("int[]");
 749             case "float[]" -> getType("float[]");
 750             case "double[]" -> getType("double[]");
 751             case "long[]" -> getType("long[]");
 752             case "bool" -> getType("bool");
 753             case "boolean" -> getType("bool");
 754             case "java.lang.Object" -> getType("java.lang.Object");
 755             case "hat.buffer.S32Array" -> getType("int[]");
 756             case "hat.buffer.S32Array2D" -> getType("int[]");
 757             case "hat.buffer.F32Array2D" -> getType("float[]");
 758             case "void" -> getType("void");
 759             case "hat.KernelContext" -> getType("ptrKernelContext");
 760             case "java.lang.foreign.MemorySegment" -> getType("ptrByte");
 761             default -> null;
 762         };
 763         if (ans == null) unsupported("type", javaType);
 764         return ans;
 765     }
 766 
 767     private SPIRVId spirvElementType(String inputType) {
 768         String javaType = inputType.replaceAll("\\$", ".");
 769         SPIRVId ans = switch(javaType) {
 770             case "byte[]" -> getType("byte");
 771             case "short[]" -> getType("short");
 772             case "int[]" -> getType("int");
 773             case "long[]" -> getType("long");
 774             case "float[]" -> getType("float");
 775             case "double[]" -> getType("double");
 776             case "boolean[]" -> getType("bool");
 777             case "hat.buffer.S32Array" -> getType("int");
 778             case "hat.buffer.S32Array2D" -> getType("int");
 779             case "hat.buffer.F32Array2D" -> getType("float");
 780             case "java.lang.foreign.MemorySegment" -> getType("byte");
 781             default -> null;
 782         };
 783         if (ans == null) unsupported("type", javaType);
 784         return ans;
 785     }
 786 
 787     private SPIRVId vectorElementType(SPIRVId type) {
 788         SPIRVId ans = switch(type.getName()) {
 789             case "v8int" -> getType("int");
 790             case "v16int" -> getType("int");
 791             case "v8long" -> getType("long");
 792             case "v8float" -> getType("float");
 793             case "v16float" -> getType("float");
 794             default -> null;
 795         };
 796         if (ans == null) unsupported("type", type.getName());
 797         return ans;
 798     }
 799 
 800     private SPIRVId spirvVariableType(SPIRVId spirvType) {
 801         SPIRVId ans = switch(spirvType.getName()) {
 802             case "bool" -> getType("ptrBool");
 803             case "byte" -> getType("ptrByte");
 804             case "short" -> getType("ptrShort");
 805             case "int" -> getType("ptrInt");
 806             case "long" -> getType("ptrLong");
 807             case "float" -> getType("ptrFloat");
 808             case "double" -> getType("ptrDouble");
 809             case "boolean" -> getType("ptrBool");
 810             case "byte[]" -> getType("ptrByte[]");
 811             case "int[]" -> getType("ptrInt[]");
 812             case "long[]" -> getType("ptrLong[]");
 813             case "float[]" -> getType("ptrFloat[]");
 814             case "double[]" -> getType("ptrDouble[]");
 815             case "v8int" -> getType("ptrV8int");
 816             case "v16int" -> getType("ptrV16int");
 817             case "v8long" -> getType("ptrV8long");
 818             case "v8float" -> getType("ptrV8float");
 819             case "v16float" -> getType("ptrV16float");
 820             case "ptrKernelContext" -> getType("ptrPtrKernelContext");
 821             case "hat.KernelContext" -> getType("ptrKernelContext");
 822             case "ptrByte" -> getType("ptrPtrByte");
 823             default -> null;
 824         };
 825         if (ans == null) unsupported("type", spirvType.getName());
 826         return ans;
 827     }
 828 
 829     private SPIRVId spirvVectorType(String javaVectorType, int vectorLength) {
 830         String prefix = "v" + vectorLength;
 831         String elementType = spirvElementType(javaVectorType).getName();
 832         return getType(prefix + elementType);
 833     }
 834 
 835     private int alignment(String inputType) {
 836         String spirvType = inputType.replaceAll("\\$", ".");
 837         int ans = switch(spirvType) {
 838             case "bool" -> 1;
 839             case "byte" -> 1;
 840             case "short" -> 2;
 841             case "int" -> 4;
 842             case "long" -> 8;
 843             case "float" -> 4;
 844             case "double" -> 8;
 845             case "boolean" -> 1;
 846             case "v8int" -> 32;
 847             case "v16int" -> 64;
 848             case "v8long" -> 64;
 849             case "v8float" -> 32;
 850             case "v16float" -> 64;
 851             case "hat.KernelContext" -> 32;
 852             case "ptrKernelContext" -> 32;
 853             case "byte[]" -> 8;
 854             case "int[]" -> 8;
 855             case "long[]" -> 8;
 856             case "float[]" -> 8;
 857             case "double[]" -> 8;
 858             case "ptrBool" -> 8;
 859             case "ptrByte" -> 8;
 860             case "ptrInt" -> 8;
 861             case "ptrByte[]" -> 8;
 862             case "ptrInt[]" -> 8;
 863             case "ptrLong" -> 8;
 864             case "ptrLong[]" -> 8;
 865             case "ptrFloat" -> 8;
 866             case "ptrFloat[]" -> 8;
 867             case "ptrV8int" -> 8;
 868             case "ptrV8float" -> 8;
 869             case "ptrPtrKernelContext" -> 8;
 870             default -> 0;
 871         };
 872         if (ans == 0) unsupported("type", spirvType);
 873         return ans;
 874     }
 875 
 876     private Set<String> moduleTypes = new HashSet<>();
 877 
 878     private SPIRVId getType(String inputName) {
 879         String name = inputName.replaceAll("\\$", ".");
 880         if (!moduleTypes.contains(name)) {
 881             switch (name) {
 882                 case "void" -> module.add(new SPIRVOpTypeVoid(nextId(name)));
 883                 case "bool" -> module.add(new SPIRVOpTypeBool(nextId(name)));
 884                 case "boolean" -> module.add(new SPIRVOpTypeBool(nextId(name)));
 885                 case "byte" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(8), new SPIRVLiteralInteger(0)));
 886                 case "short" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(16), new SPIRVLiteralInteger(0)));
 887                 case "int" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(32), new SPIRVLiteralInteger(0)));
 888                 case "long" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(64), new SPIRVLiteralInteger(0)));
 889                 case "float" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(32)));
 890                 case "double" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(64)));
 891                 case "ptrBool" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("bool")));
 892                 case "ptrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 893                 case "ptrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int")));
 894                 case "ptrLong" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long")));
 895                 case "ptrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float")));
 896                 case "byte[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 897                 case "short[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("short")));
 898                 case "int[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("int")));
 899                 case "long[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("long")));
 900                 case "float[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("float")));
 901                 case "double[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("double")));
 902                 case "boolean[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("boolean")));
 903                 case "ptrByte[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("byte[]")));
 904                 case "ptrInt[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int[]")));
 905                 case "ptrLong[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long[]")));
 906                 case "ptrFloat[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float[]")));
 907                 case "java.lang.Object" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("void")));
 908                 case "hat.KernelContext" -> module.add(new SPIRVOpTypeStruct(nextId(name), new SPIRVMultipleOperands<>(getType("int"), getType("int"))));
 909                 case "ptrKernelContext" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("hat.KernelContext")));
 910                 case "ptrCrossGroupByte"-> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 911                 case "ptrPtrKernelContext" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrKernelContext")));
 912                 case "ptrPtrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrByte")));
 913                 case "ptrPtrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrInt")));
 914                 case "ptrPtrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrFloat")));
 915                 case "v3long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(3)));
 916                 case "v8int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(8)));
 917                 case "v8long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(8)));
 918                 case "v16int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(16)));
 919                 case "v8float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(8)));
 920                 case "v16float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(16)));
 921                 case "ptrV3long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Input(), getType("v3long")));
 922                 case "ptrV8long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8long")));
 923                 case "ptrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8int")));
 924                 case "ptrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16int")));
 925                 case "ptrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8float")));
 926                 case "ptrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16float")));
 927                 case "ptrPtrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8int")));
 928                 case "ptrPtrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16int")));
 929                 case "ptrPtrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8float")));
 930                 case "ptrPtrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16float")));
 931                 default -> unsupported("type", name);
 932             }
 933             moduleTypes.add(name);
 934         }
 935         return getId(name);
 936     }
 937 
 938     private Set<String> moduleConstants = new HashSet<>();
 939 
 940     private SPIRVId getConst(String typeName, long value) {
 941         String name = typeName + "_" + value;
 942         if (!moduleConstants.contains(name)) {
 943             String valueStr = String.valueOf(value);
 944             switch (typeName) {
 945                 case "int" -> module.add(new SPIRVOpConstant(getType(typeName), nextId(name), new SPIRVContextDependentInt(new BigInteger(valueStr))));
 946                 case "long" -> module.add(new SPIRVOpConstant(getType(typeName), nextId(name), new SPIRVContextDependentLong(new BigInteger(valueStr))));
 947                 case "boolean" -> module.add(value == 0 ? new SPIRVOpConstantFalse(getType(typeName), nextId(name)) : new SPIRVOpConstantTrue(getType(typeName), nextId(name)));
 948                 default -> unsupported("constant", name);
 949             };
 950             moduleConstants.add(name);
 951         }
 952         return getId(name);
 953     }
 954 
 955     private SPIRVOptionalOperand<SPIRVMemoryAccess> align(int align) {
 956         return new SPIRVOptionalOperand<>(SPIRVMemoryAccess.Aligned(new SPIRVLiteralInteger(align)));
 957     }
 958 
 959     private SPIRVOptionalOperand<SPIRVMemoryAccess> align(String type) {
 960         return align(alignment(type));
 961     }
 962 
 963     private SPIRVMultipleOperands<SPIRVId> spirvOperands(SPIRVId value, int count) {
 964         SPIRVId[] values = new SPIRVId[count];
 965         Arrays.fill(values, value);
 966         return new SPIRVMultipleOperands<>(values);
 967     }
 968 
 969     private SPIRVOptionalOperand<SPIRVMemoryAccess> none() {
 970         return new SPIRVOptionalOperand<>();
 971     }
 972 
 973     private SpirvResult globalSize(int index, SPIRVBlock spirvBlock) {
 974         SPIRVId intType = getType("int");
 975         SPIRVId longType = getType("long");
 976         SPIRVId v3long = getId("v3long");
 977         SPIRVId globalSizeId = getId("globalSize");
 978         SPIRVId globalSizes = nextId();
 979         spirvBlock.add(new SPIRVOpLoad(v3long, globalSizes, globalSizeId, align(32)));
 980         SPIRVId longSize = nextId();
 981         SPIRVId globalSize = nextId();
 982         spirvBlock.add(new SPIRVOpCompositeExtract(longType, longSize, globalSizes, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
 983         spirvBlock.add(new SPIRVOpSConvert(intType, globalSize, longSize));
 984         return new SpirvResult(intType, null, globalSize);
 985     }
 986 
 987     private SpirvResult globalId(int index, SPIRVBlock spirvBlock) {
 988         SPIRVId intType = getType("int");
 989         SPIRVId longType = getType("long");
 990         SPIRVId v3long = getId("v3long");
 991         SPIRVId globalInvocationId = getId("globalInvocationId");
 992         SPIRVId globalIds = nextId();
 993         spirvBlock.add(new SPIRVOpLoad(v3long, globalIds, globalInvocationId, align(32)));
 994         SPIRVId longIndex = nextId();
 995         SPIRVId globalIndex = nextId();
 996         spirvBlock.add(new SPIRVOpCompositeExtract(longType, longIndex, globalIds, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
 997         spirvBlock.add(new SPIRVOpSConvert(intType, globalIndex, longIndex));
 998         return new SpirvResult(intType, null, globalIndex);
 999     }
1000 
1001    private SpirvResult flatIndex(SPIRVId sizeX, SPIRVId sizeY, SPIRVId sizeZ, SPIRVId indexX, SPIRVId indexY, SPIRVId indexZ, SPIRVBlock spirvBlock)
1002     {
1003         SPIRVId longType = getType("long");
1004         SPIRVId xTerm0 = nextId();
1005         SPIRVId xTerm1 = nextId();
1006         SPIRVId yTerm = nextId();
1007         SPIRVId flat0 = nextId();
1008         SPIRVId flat1 = nextId();
1009         spirvBlock.add(new SPIRVOpIMul(longType, xTerm0, sizeY, sizeZ));
1010         spirvBlock.add(new SPIRVOpIMul(longType, xTerm1, xTerm0, indexX));
1011         spirvBlock.add(new SPIRVOpIMul(longType, yTerm, sizeZ, indexY));
1012         spirvBlock.add(new SPIRVOpIAdd(longType, flat0, xTerm1, yTerm));
1013         spirvBlock.add(new SPIRVOpIAdd(longType, flat1, flat0, indexZ));
1014         return new SpirvResult(longType, null, flat1);
1015     }
1016 
1017     private SPIRVId nextId() {
1018         return module.getNextId();
1019     }
1020 
1021     private SPIRVId nextId(String name) {
1022         SPIRVId ans = nextId();
1023         ans.setName(name);
1024         symbols.putId(name, ans);
1025         module.add(new SPIRVOpName(ans, new SPIRVLiteralString(name)));
1026         return ans;
1027     }
1028 
1029     private static int counter = 0;
1030 
1031     private String nextTempTag() {
1032         counter++;
1033         return "temp_" + counter + "_";
1034     }
1035 
1036     private boolean isIntegerType(SPIRVId type) {
1037         String name = type.getName();
1038         return name.equals("byte") || name.equals("short") || name.equals("int") || name.equals("long");
1039     }
1040 
1041     private boolean isFloatType(SPIRVId type) {
1042         String name = type.getName();
1043         return name.equals("float") || name.equals("double");
1044     }
1045 
1046     private boolean isVectorSpecies(String javaType) {
1047         return javaType.equals("VectorSpecies");
1048     }
1049 
1050     private boolean isVectorType(String javaType) {
1051         return javaType.equals("IntVector") || javaType.equals("FloatVector");
1052     }
1053 
1054     private void addId(String name, SPIRVId id) {
1055         symbols.putId(name, id);
1056     }
1057 
1058     private SPIRVId getId(String name) {
1059         SPIRVId ans = symbols.getId(name);
1060         assert ans != null : name + " not found";
1061         return ans;
1062     }
1063 
1064     private SPIRVId getIdOrNull(String name) {
1065         return symbols.getId(name);
1066     }
1067 
1068     private static Object map(Function<Object, Boolean> test, Object... args) {
1069         int len = args.length;
1070         assert len >= 2 && len % 2 == 0;
1071         int pairs = len / 2;
1072         for (int i = 0; i < pairs; i++) {
1073             if (test.apply(args[i])) return args[i + pairs];
1074         }
1075         throw new RuntimeException("No match: " + args[0]);
1076     }
1077 
1078     private void unsupported(String message, Object value) {
1079         throw new RuntimeException("Unsupported " + message + ": " + value);
1080     }
1081 
1082     private void addResult(Value value, SpirvResult result) {
1083         assert symbols.getResult(value) == null : "result already present";
1084         symbols.putResult(value, result);
1085     }
1086 
1087     private SpirvResult getResult(Value value) {
1088         return symbols.getResult(value);
1089     }
1090 
1091     private static class Symbols {
1092         private final HashMap<Value, SpirvResult> results;
1093         private final HashMap<String, SPIRVId> ids;
1094         private final HashMap<Block, SPIRVBlock> blocks;
1095         private final HashMap<Block, SPIRVOpLabel> labels;
1096 
1097         public Symbols() {
1098             this.results = new HashMap<>();
1099             this.ids = new HashMap<>();
1100             this.blocks = new HashMap<>();
1101             this.labels = new HashMap<>();
1102         }
1103 
1104         public void putId(String name, SPIRVId id) {
1105             ids.put(name, id);
1106         }
1107 
1108         public SPIRVId getId(String name) {
1109             return ids.get(name);
1110         }
1111 
1112         public void putBlock(Block block, SPIRVBlock spirvBlock) {
1113             blocks.put(block, spirvBlock);
1114         }
1115 
1116         public SPIRVBlock getBlock(Block block) {
1117             return blocks.get(block);
1118         }
1119 
1120         public void putLabel(Block block, SPIRVOpLabel spirvLabel) {
1121             labels.put(block, spirvLabel);
1122         }
1123 
1124         public SPIRVOpLabel getLabel(Block block) {
1125             return labels.get(block);
1126         }
1127 
1128         public void putResult(Value value, SpirvResult result) {
1129             results.put(value, result);
1130         }
1131 
1132         public SpirvResult getResult(Value value) {
1133             return results.get(value);
1134         }
1135 
1136         public String toString() {
1137             return String.format("results %s\n\nids %s\n\nblocks %s\nlabels %s\n", results.keySet(), ids.keySet(), blocks.keySet(), labels.keySet());
1138         }
1139     }
1140 
1141     public static void debug(String message, Object... args) {
1142         System.out.println(String.format(message, args));
1143     }
1144 }