1 /*
   2  * Copyright (c) 2024 Intel Corporation. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package intel.code.spirv;
  27 
  28 import java.util.List;
  29 import java.util.ArrayList;
  30 import java.util.Arrays;
  31 import java.util.Map;
  32 import java.util.HashMap;
  33 import java.util.Set;
  34 import java.util.HashSet;
  35 import java.util.Optional;
  36 import java.util.function.Function;
  37 import java.io.IOException;
  38 import java.io.File;
  39 import java.io.FileOutputStream;
  40 import java.io.ByteArrayInputStream;
  41 import java.io.ByteArrayOutputStream;
  42 import java.io.PrintStream;
  43 import java.nio.ByteBuffer;
  44 import java.nio.ByteOrder;
  45 import java.nio.channels.FileChannel;
  46 import java.math.BigInteger;
  47 import java.lang.constant.ClassDesc;
  48 import java.lang.invoke.MethodHandles;
  49 import java.lang.foreign.MemorySegment;
  50 import java.lang.foreign.ValueLayout;
  51 import jdk.incubator.code.Block;
  52 import jdk.incubator.code.Body;
  53 import jdk.incubator.code.Op;
  54 import jdk.incubator.code.Value;
  55 import jdk.incubator.code.TypeElement;
  56 import jdk.incubator.code.type.MethodRef;
  57 import jdk.incubator.code.type.ClassType;
  58 import jdk.incubator.code.dialect.core.FunctionType;
  59 import jdk.incubator.code.dialect.java.JavaType;
  60 import jdk.incubator.code.dialect.java.CoreOp;
  61 import hat.buffer.*;
  62 import hat.callgraph.CallGraph;
  63 import hat.callgraph.KernelCallGraph;
  64 import hat.callgraph.KernelEntrypoint;
  65 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVHeader;
  66 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVModule;
  67 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVFunction;
  68 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVBlock;
  69 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.*;
  70 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.*;
  71 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler;
  72 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions;
  73 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVByteStreamReader;
  74 import intel.code.spirv.SpirvOp.PhiOp;
  75 
  76 public class SpirvModuleGenerator {
    /** Name given to this generator; {@code generateModuleInternal} uses it as the entry-point function name. */
    private final String moduleName;
    /** The SPIR-V module under construction; encoded to bytes by {@code finalizeModule()}. */
    private final SPIRVModule module;
    /** Per-generator symbol table; {@code translateBody} records code-model blocks and their SPIR-V blocks/labels here. */
    private final Symbols symbols;
  80 
  81     public static SpirvModuleGenerator create(String moduleName) {
  82         return new SpirvModuleGenerator(moduleName);
  83     }
  84 
  85     public static MemorySegment generateModule(String moduleName, KernelCallGraph callGraph) {
  86         SpirvModuleGenerator generator = SpirvModuleGenerator.create(moduleName);
  87         for (CallGraph.MethodCall call : callGraph.calls) {
  88             if (call.targetMethodRef != null) {
  89                 try {
  90                     Optional<CoreOp.FuncOp> ofo = call.targetMethodRef.codeModel(MethodHandles.lookup());
  91                     if (ofo.isPresent()) {
  92                         CoreOp.FuncOp fo = ofo.get();
  93                         SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(fo);
  94                     }
  95                 } catch (Exception e) {
  96                     throw new RuntimeException(e);
  97                 }
  98             }
  99         }
 100         KernelEntrypoint kernelEntrypoint = callGraph.entrypoint;
 101         CoreOp.FuncOp funcOp = kernelEntrypoint.funcOpWrapper().op();
 102         String kernelName = funcOp.funcName();
 103         SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(funcOp);
 104         generator.generateFunction(funcOp.funcName(), spirvFunc, true);
 105         return generator.finalizeModule();
 106     }
 107 
 108     public static MemorySegment generateModule(String moduleName, CoreOp.FuncOp func) {
 109         SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(func);
 110         MemorySegment moduleSegment = SpirvModuleGenerator.generateModule(moduleName, spirvFunc);
 111         return moduleSegment;
 112     }
 113 
 114     public static MemorySegment generateModule(String moduleName, SpirvOp.FuncOp func) {
 115         SpirvModuleGenerator generator = new SpirvModuleGenerator(moduleName);
 116         MemorySegment moduleSegment = generator.generateModuleInternal(func);
 117         return moduleSegment;
 118     }
 119 
 120     public static void writeModuleToFile(MemorySegment module, String filepath)  {
 121         ByteBuffer buffer = module.asByteBuffer();
 122         File out = new File(filepath);
 123         try (FileChannel channel = new FileOutputStream(out, false).getChannel()) {
 124             channel.write(buffer);
 125             channel.close();
 126         }
 127         catch (IOException e)  {
 128             throw new RuntimeException(e);
 129         }
 130     }
 131 
 132     private static void writeModuleToFile(SPIRVModule module, String filepath)
 133     {
 134         ByteBuffer buffer = ByteBuffer.allocate(module.getByteCount());
 135         buffer.order(ByteOrder.LITTLE_ENDIAN);
 136         module.close().write(buffer);
 137         buffer.flip();
 138         File out = new File(filepath);
 139         try (FileChannel channel = new FileOutputStream(out, false).getChannel())
 140         {
 141             channel.write(buffer);
 142         }
 143         catch (IOException e)
 144         {
 145             throw new RuntimeException(e);
 146         }
 147     }
 148 
 149     public static String disassembleModule(MemorySegment module) {
 150         SPVByteStreamReader input = new SPVByteStreamReader(new ByteArrayInputStream(module.toArray(ValueLayout.JAVA_BYTE)));
 151         ByteArrayOutputStream out = new ByteArrayOutputStream();
 152         try (PrintStream ps = new PrintStream(out))  {
 153             SPIRVDisassemblerOptions options = new SPIRVDisassemblerOptions(false, false, false, false, true);
 154             Disassembler dis = new Disassembler(input, ps, options);
 155             dis.run();
 156         }
 157         catch (Exception e) {
 158             throw new RuntimeException(e);
 159         }
 160         return new String(out.toByteArray());
 161     }
 162 
 163     public MemorySegment finalizeModule() {
 164         ByteBuffer buffer = ByteBuffer.allocateDirect(module.getByteCount());
 165         buffer.order(ByteOrder.LITTLE_ENDIAN);
 166         module.close().write(buffer);
 167         buffer.flip();
 168         return MemorySegment.ofBuffer(buffer);
 169     }
 170 
 171     private record SpirvResult(SPIRVId type, SPIRVId address, SPIRVId value) {}
 172 
 173     private SpirvModuleGenerator(String moduleName) {
 174         this.moduleName = moduleName;
 175         this.module = new SPIRVModule(new SPIRVHeader(1, 2, 32, 0, 0));
 176         this.symbols = new Symbols();
 177         initModule();
 178     }
 179 
 180     private MemorySegment generateModuleInternal(SpirvOp.FuncOp func) {
 181         generateFunction(moduleName, func, true);
 182         return finalizeModule();
 183     }
 184 
 185     private SPIRVId generateFunction(String fnName, SpirvOp.FuncOp func, boolean isEntryPoint) {
 186         TypeElement returnType = func.invokableType().returnType();
 187         SPIRVId functionId = nextId(fnName);
 188         String signature = func.invokableType().returnType().toString();
 189         List<TypeElement> paramTypes = func.invokableType().parameterTypes();
 190         // build signature string
 191         for (int i = 0; i < paramTypes.size(); i++) {
 192             signature += "_" + paramTypes.get(i).toString();
 193         }
 194         // declare function type if not already present
 195         SPIRVId functionSig = getIdOrNull(signature);
 196         if (functionSig == null) {
 197             SPIRVId[] typeIdsArray = new SPIRVId[paramTypes.size()];
 198             for (int i = 0; i < paramTypes.size(); i++) {
 199                 typeIdsArray[i] = spirvType(paramTypes.get(i).toString());
 200             }
 201             functionSig = nextId(fnName + "Signature");
 202             module.add(new SPIRVOpTypeFunction(functionSig, spirvType(returnType.toString()), new SPIRVMultipleOperands<>(typeIdsArray)));
 203             addId(signature, functionSig);
 204         }
 205         SPIRVId spirvReturnType = spirvType(returnType.toString());
 206         SPIRVFunction function = (SPIRVFunction)module.add(new SPIRVOpFunction(spirvReturnType, functionId, SPIRVFunctionControl.DontInline(), functionSig));
 207         SPIRVOpLabel entryLabel = new SPIRVOpLabel(nextId());
 208         SPIRVBlock entryBlock = (SPIRVBlock)function.add(entryLabel);
 209         SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(getId("globalInvocationId"), getId("globalSize"), getId("subgroupSize"), getId("subgroupId"));
 210         if (isEntryPoint) {
 211             module.add(new SPIRVOpEntryPoint(SPIRVExecutionModel.Kernel(), functionId, new SPIRVLiteralString(fnName), operands));
 212         }
 213         translateBody(func.body(), function, entryBlock);
 214         function.add(new SPIRVOpFunctionEnd());
 215         return functionId;
 216     }
 217 
 218     private void translateBody(Body body, SPIRVFunction function, SPIRVBlock entryBlock) {
 219         int labelNumber = 0;
 220         SPIRVBlock spirvBlock = entryBlock;
 221         for (int bi = 1; bi < body.blocks().size(); bi++)  {
 222             Block block = body.blocks().get(bi);
 223             SPIRVOpLabel blockLabel = new SPIRVOpLabel(nextId());
 224             SPIRVBlock newBlock = (SPIRVBlock)function.add(blockLabel);
 225             symbols.putBlock(block, newBlock);
 226             symbols.putLabel(block, blockLabel);
 227             for (int i = 0; i < block.parameters().size(); i++) {
 228                 Value param = block.parameters().get(i);
 229                 SPIRVId paramId = nextId();
 230                 addResult(param, new SpirvResult(spirvType(param.type().toString()), null, paramId));
 231             }
 232         }
 233         for (int bi = 0; bi < body.blocks().size(); bi++)  {
 234             Block block = body.blocks().get(bi);
 235             if (bi > 0) {
 236                 spirvBlock = symbols.getBlock(block);
 237             }
 238             for (Op op : block.ops()) {
 239                 switch (op)  {
 240                     case SpirvOp.PhiOp phop -> {
 241                         List<PhiOp.Predecessor> inPredecessors = phop.predecessors();
 242                         SPIRVPairIdRefIdRef[] outPredecessors = new SPIRVPairIdRefIdRef[inPredecessors.size()];
 243                         for (int i = 0; i < inPredecessors.size(); i++) {
 244                             PhiOp.Predecessor predecessor = inPredecessors.get(i);
 245                             SPIRVId label = symbols.getLabel(predecessor.block().targetBlock()).getResultId();
 246                             SPIRVId value = getResult(predecessor.value()).value();
 247                             outPredecessors[i] = new SPIRVPairIdRefIdRef(value, label);
 248                         }
 249                         SPIRVId result = nextId();
 250                         SPIRVId type = spirvType(phop.resultType().toString());
 251                         SPIRVOpPhi phiOp = new SPIRVOpPhi(spirvType(phop.resultType().toString()), result, new SPIRVMultipleOperands<>(outPredecessors));
 252                         spirvBlock.add(phiOp);
 253                         addResult(phop.result(), new SpirvResult(type, null, result));
 254                     }
 255                     case SpirvOp.VariableOp vop -> {
 256                         String typeName = vop.varType().toString();
 257                         SPIRVId type = spirvType(typeName);
 258                         SPIRVId varType = spirvVariableType(type);
 259                         SPIRVId var = nextId(vop.varName());
 260                         spirvBlock.add(new SPIRVOpVariable(varType, var, SPIRVStorageClass.Function(), new SPIRVOptionalOperand<>()));
 261                         addResult(vop.result(), new SpirvResult(varType, var, null));
 262                     }
 263                     case SpirvOp.FunctionParameterOp fpo -> {
 264                         SPIRVId result = nextId();
 265                         SPIRVId type = spirvType(fpo.resultType().toString());
 266                         function.add(new SPIRVOpFunctionParameter(type, result));
 267                         module.add(new SPIRVOpDecorate(result, SPIRVDecoration.Alignment(new SPIRVLiteralInteger(8))));
 268                         addResult(fpo.result(), new SpirvResult(type, null, result));
 269                     }
 270                     case SpirvOp.LoadOp lo -> {
 271                         SPIRVId type = spirvType(lo.resultType().toString());
 272                         SpirvResult toLoad = getResult(lo.operands().get(0));
 273                         SPIRVId varAddr = toLoad.address();
 274                         SPIRVId result = nextId();
 275                         spirvBlock.add(new SPIRVOpLoad(type, result, varAddr, align(type.getName())));
 276                         addResult(lo.result(), new SpirvResult(type, varAddr, result));
 277                     }
 278                     case SpirvOp.StoreOp so -> {
 279                         SpirvResult var = getResult(so.operands().get(0));
 280                         SPIRVId varAddr = var.address();
 281                         SPIRVId value = getResult(so.operands().get(1)).value();
 282                         spirvBlock.add(new SPIRVOpStore(varAddr, value, align(var.type().getName())));
 283                     }
 284                     case SpirvOp.IAddOp _, SpirvOp.FAddOp _ -> {
 285                         SPIRVId intType = getType("int");
 286                         SPIRVId longType = getType("long");
 287                         SPIRVId floatType = getType("float");
 288                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 289                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 290                         SPIRVId lhsType = spirvType(op.resultType().toString());
 291                         SPIRVId ans = nextId();
 292                         if (lhsType == intType) spirvBlock.add(new SPIRVOpIAdd(intType, ans, lhs, rhs));
 293                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpIAdd(longType, ans, lhs, rhs));
 294                         else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFAdd(floatType, ans, lhs, rhs));
 295                         else unsupported("type", lhsType.getName());
 296                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 297                     }
 298                     case SpirvOp.ISubOp _, SpirvOp.FSubOp _ -> {
 299                         SPIRVId intType = getType("int");
 300                         SPIRVId longType = getType("long");
 301                         SPIRVId floatType = getType("float");
 302                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 303                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 304                         SPIRVId lhsType = spirvType(op.resultType().toString());
 305                         SPIRVId ans = nextId();
 306                         if (lhsType == intType) spirvBlock.add(new SPIRVOpISub(intType, ans, lhs, rhs));
 307                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpISub(longType, ans, lhs, rhs));
 308                         else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFSub(floatType, ans, lhs, rhs));
 309                         else unsupported("type", lhsType.getName());
 310                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 311                     }
 312                     case SpirvOp.IMulOp _, SpirvOp.FMulOp _, SpirvOp.IDivOp _, SpirvOp.FDivOp _ -> {
 313                         SPIRVId intType = getType("int");
 314                         SPIRVId longType = getType("long");
 315                         SPIRVId floatType = getType("float");
 316                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 317                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 318                         SPIRVId lhsType = spirvType(op.resultType().toString());
 319                         SPIRVId rhsType = getResult(op.operands().get(1)).type();
 320                         SPIRVId ans = nextId();
 321                         if (lhsType == intType) {
 322                             if (op instanceof SpirvOp.IMulOp) spirvBlock.add(new SPIRVOpIMul(intType, ans, lhs, rhs));
 323                             else if (op instanceof SpirvOp.IDivOp) spirvBlock.add(new SPIRVOpSDiv(intType, ans, lhs, rhs));
 324                         }
 325                         else if (lhsType == longType) {
 326                             SPIRVId rhsId = rhsType == intType ? nextId() : rhs;
 327                             if (rhsType == intType) spirvBlock.add(new SPIRVOpSConvert(longType, rhsId, rhs));
 328                             if (op instanceof SpirvOp.IMulOp) spirvBlock.add(new SPIRVOpIMul(longType, ans, lhs, rhsId));
 329                             else if (op instanceof SpirvOp.IDivOp) spirvBlock.add(new SPIRVOpSDiv(longType, ans, lhs, rhs));
 330                         }
 331                         else if (lhsType == floatType) {
 332                             if (op instanceof SpirvOp.FMulOp) spirvBlock.add(new SPIRVOpFMul(floatType, ans, lhs, rhs));
 333                             else if (op instanceof SpirvOp.FDivOp) spirvBlock.add(new SPIRVOpFDiv(floatType, ans, lhs, rhs));
 334                         }
 335                         else unsupported("type", lhsType);
 336                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 337                     }
 338                     case SpirvOp.ModOp mop -> {
 339                         SPIRVId type = getType(mop.operands().get(0).type().toString());
 340                         SPIRVId lhs = getResult(mop.operands().get(0)).value();
 341                         SPIRVId rhs = getResult(mop.operands().get(1)).value();
 342                         SPIRVId result = nextId();
 343                         spirvBlock.add(new SPIRVOpUMod(type, result, lhs, rhs));
 344                         addResult(mop.result(), new SpirvResult(type, null, result));
 345                     }
 346                     case SpirvOp.IEqualOp eqop -> {
 347                         SPIRVId boolType = getType("bool");
 348                         SPIRVId intType = getType("int");
 349                         SPIRVId longType = getType("long");
 350                         SPIRVId floatType = getType("float");
 351                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 352                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 353                         SPIRVId lhsType = spirvType(op.resultType().toString());
 354                         SPIRVId ans = nextId();
 355                         if (lhsType == intType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
 356                         else if (lhsType == longType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
 357                         else unsupported("type", lhsType.getName());
 358                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 359                     }
 360                     case SpirvOp.GeOp eqop -> {
 361                         SPIRVId boolType = getType("bool");
 362                         SPIRVId lhs = getResult(op.operands().get(0)).value();
 363                         SPIRVId rhs = getResult(op.operands().get(1)).value();
 364                         SPIRVId lhsType = spirvType(op.resultType().toString());
 365                         SPIRVId ans = nextId();
 366                         spirvBlock.add(new SPIRVOpSGreaterThanEqual(boolType, ans, lhs, rhs));
 367                         addResult(op.result(), new SpirvResult(lhsType, null, ans));
 368                     }
 369                     case SpirvOp.PtrNotEqualOp neqop -> {
 370                         SPIRVId boolType = getType("bool");
 371                         SPIRVId longType = getType("long");
 372                         SPIRVId lhs = getResult(neqop.operands().get(0)).value();
 373                         SPIRVId rhs = getResult(neqop.operands().get(1)).value();
 374                         SPIRVId ans = nextId();
 375                         SPIRVId lhsLong = nextId();
 376                         SPIRVId rhsLong = nextId();
 377                         spirvBlock.add(new SPIRVOpConvertPtrToU(longType, lhsLong, lhs));
 378                         spirvBlock.add(new SPIRVOpConvertPtrToU(longType, rhsLong, rhs));
 379                         spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhsLong, rhsLong));
 380                         addResult(op.result(), new SpirvResult(boolType, null, ans));
 381                     }
 382                     case SpirvOp.CallOp call -> {
 383                         MethodRef methodRef = call.callDescriptor();
 384                         if (methodRef.equals(MethodRef.method(S32Array.class, "array", int.class, long.class)))
 385                         {
 386                             SPIRVId longType = getType("long");
 387                             String arrayTypeName = call.operands().get(0).type().toString();
 388                             SpirvResult arrayResult = getResult(call.operands().get(0));
 389                             SPIRVId arrayAddr = arrayResult.address();
 390                             SPIRVId arrayType = spirvType(arrayTypeName);
 391                             SPIRVId elementType = spirvElementType(arrayTypeName);
 392                             int nIndexes = call.operands().size() - 1;
 393                             SPIRVId indexX = getResult(call.operands().get(1)).value();
 394                             SPIRVId array = nextId();
 395                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 396                             SPIRVId temp1 = nextId();
 397                             SPIRVId temp2 = nextId();
 398                             spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
 399                             spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4)));
 400                             SPIRVId elementBase = nextId();
 401                             spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
 402                             SPIRVId resultAddr = nextId();
 403                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, elementBase, indexX, new SPIRVMultipleOperands<>()));
 404                             SPIRVId result = nextId();
 405                             spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName())));
 406                             addResult(call.result(), new SpirvResult(elementType, resultAddr, result));
 407                         }
 408                         else if (methodRef.equals(MethodRef.method(S32Array2D.class, "array", int.class, long.class) ||
 409                                  methodRef.equals(MethodRef.method(F32Array2D.class, "array", float.class, long.class))
 410                         {
 411                             SPIRVId longType = getType("long");
 412                             String arrayTypeName = call.operands().get(0).type().toString();
 413                             SpirvResult arrayResult = getResult(call.operands().get(0));
 414                             SPIRVId arrayAddr = arrayResult.address();
 415                             SPIRVId arrayType = spirvType(arrayTypeName);
 416                             SPIRVId elementType = spirvElementType(arrayTypeName);
 417                             int nIndexes = call.operands().size() - 1;
 418                             SPIRVId indexX = getResult(call.operands().get(1)).value();
 419                             SPIRVId array = nextId();
 420                             SPIRVId temp1 = nextId();
 421                             SPIRVId temp2 = nextId();
 422                             spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
 423                             spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4)));
 424                             SPIRVId elementBase = nextId();
 425                             spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
 426                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 427                             SPIRVId resultAddr = nextId();
 428                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, elementBase, indexX, new SPIRVMultipleOperands<>()));
 429                             SPIRVId result = nextId();
 430                             spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName())));
 431                             addResult(call.result(), new SpirvResult(elementType, resultAddr, result));
 432                         }
 433                         else if (methodRef.equals(MethodRef.method(S32Array.class, "array", void.class, long.class, int.class))) {
 434                             SPIRVId longType = getType("long");
 435                             String arrayTypeName = call.operands().get(0).type().toString();
 436                             SpirvResult arrayResult = getResult(call.operands().get(0));
 437                             SPIRVId arrayAddr = arrayResult.address();
 438                             SPIRVId arrayType = spirvType(arrayTypeName);
 439                             SPIRVId elementType = spirvElementType(arrayTypeName);
 440                             int nIndexes = call.operands().size() - 2;
 441                             int valueIndex = nIndexes + 1;
 442                             SPIRVId indexX = getResult(call.operands().get(1)).value();
 443                             SPIRVId array = nextId();
 444                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 445                             SPIRVId temp1 = nextId();
 446                             SPIRVId temp2 = nextId();
 447                             spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
 448                             spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4)));
 449                             SPIRVId elementBase = nextId();
 450                             spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
 451                             SPIRVId dest = nextId();
 452                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, elementBase, indexX, new SPIRVMultipleOperands<>()));
 453                             SPIRVId value = getResult(call.operands().get(valueIndex)).value();
 454                             spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName())));
 455                         }
 456                         else if (methodRef.equals(MethodRef.method(S32Array2D.class, "array", void.class, long.class, int.class)) ||
 457                                  methodRef.equals(MethodRef.method(F32Array2D.class, "array", void.class, long.class, float.class))) {
 458                             SPIRVId longType = getType("long");
 459                             String arrayTypeName = call.operands().get(0).type().toString();
 460                             SpirvResult arrayResult = getResult(call.operands().get(0));
 461                             SPIRVId arrayAddr = arrayResult.address();
 462                             SPIRVId arrayType = spirvType(arrayTypeName);
 463                             SPIRVId elementType = spirvElementType(arrayTypeName);
 464                             int nIndexes = call.operands().size() - 2;
 465                             int valueIndex = nIndexes + 1;
 466                             SPIRVId indexX = getResult(call.operands().get(1)).value();
 467                             SPIRVId array = nextId();
 468                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 469                             SPIRVId temp1 = nextId();
 470                             SPIRVId temp2 = nextId();
 471                             spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
 472                             spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 8)));
 473                             SPIRVId elementBase = nextId();
 474                             spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
 475                             SPIRVId dest = nextId();
 476                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, elementBase, indexX, new SPIRVMultipleOperands<>()));
 477                             SPIRVId value = getResult(call.operands().get(valueIndex)).value();
 478                             spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName())));
 479                         }
 480                         else if (methodRef.equals(MethodRef.method(S32Array.class, "length", int.class)) ||
 481                                  methodRef.equals(MethodRef.method(S32Array2D.class, "width", int.class)) ||
 482                                  methodRef.equals(MethodRef.method(F32Array2D.class, "width", int.class))) {
 483                             SPIRVId intType = getType("int");
 484                             String arrayTypeName = call.operands().get(0).type().toString();
 485                             SpirvResult arrayResult = getResult(call.operands().get(0));
 486                             SPIRVId arrayAddr = arrayResult.address();
 487                             SPIRVId arrayType = spirvType(arrayTypeName);
 488                             SPIRVId array = nextId();
 489                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 490                             SPIRVId resultAddr = nextId();
 491                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, getConst("int", 0), new SPIRVMultipleOperands<>()));
 492                             SPIRVId result = nextId();
 493                             spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align(arrayType.getName())));
 494                             addResult(call.result(), new SpirvResult(intType, resultAddr, result));
 495                         }
 496                         else if (methodRef.equals(MethodRef.method(S32Array2D.class, "height", int.class)) ||
 497                                  methodRef.equals(MethodRef.method(F32Array2D.class, "height", int.class))) {
 498                             SPIRVId intType = getType("int");
 499                             String arrayTypeName = call.operands().get(0).type().toString();
 500                             SpirvResult arrayResult = getResult(call.operands().get(0));
 501                             SPIRVId arrayAddr = arrayResult.address();
 502                             SPIRVId arrayType = spirvType(arrayTypeName);
 503                             SPIRVId array = nextId();
 504                             spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
 505                             SPIRVId resultAddr = nextId();
 506                             spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, getConst("int", 1), new SPIRVMultipleOperands<>()));
 507                             SPIRVId result = nextId();
 508                             spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align(arrayType.getName())));
 509                             addResult(call.result(), new SpirvResult(intType, resultAddr, result));
 510                         }
 511                         else if (methodRef.equals(MethodRef.method(Math.class, "sqrt", double.class, double.class))) {
 512                             SPIRVId floatType = getType("float");
 513                             SPIRVId result = nextId();
 514                             SPIRVId operand = getResult(call.operands().get(0)).value();
 515                             spirvBlock.add(new SPIRVOpExtInst(floatType, result, getId("oclExtension"), new SPIRVLiteralExtInstInteger(61, "sqrt"), new SPIRVMultipleOperands<>(operand)));
 516                             addResult(call.result(), new SpirvResult(floatType, null, result));
 517                         }
 518                         else {
 519                             SPIRVId fnId = getFunctionId(methodRef);
 520                             if (fnId == null) {
 521                                 unsupported("method", methodRef);
 522                             }
 523                             else {
 524                                 FunctionType fnType = methodRef.type();
 525                                 SPIRVId[] args = new SPIRVId[call.operands().size()];
 526                                 for (int i = 0; i < args.length; i++) {
 527                                     SPIRVId argId = getResult(call.operands().get(i)).value();
 528                                     args[i] = argId;
 529                                 }
 530                                 SPIRVId returnType = spirvType(fnType.returnType().toString());
 531                                 SPIRVId callResult = nextId();
 532                                 spirvBlock.add(new SPIRVOpFunctionCall(returnType, callResult, fnId, new SPIRVMultipleOperands<>(args)));
 533                                 addResult(call.result(), new SpirvResult(returnType, null, callResult));
 534                             }
 535                         }
 536                     }
 537                     case SpirvOp.ConstantOp cop -> {
 538                         SPIRVId type = spirvType(cop.resultType().toString());
 539                         SPIRVId result = nextId();
 540                         Object value = cop.value();
 541                         if (type == getType("int")) {
 542                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentInt(new BigInteger(String.valueOf(value)))));
 543                         }
 544                         else if (type == getType("long")) {
 545                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentLong(new BigInteger(String.valueOf(value)))));
 546                         }
 547                         else if (type == getType("float")) {
 548                             module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentFloat((float)value)));
 549                         }
 550                         else if (type == getType("bool")) {
 551                             module.add(((boolean)value) ? new SPIRVOpConstantTrue(type, result) : new SPIRVOpConstantFalse(type, result));
 552                         }
 553                         else if (type == getType("java.lang.Object")) {
 554                             module.add(new SPIRVOpConstantNull(type, result));
 555                         }
 556                         else if (type == getType("int[]")) {
 557                             module.add(new SPIRVOpConstantNull(type, result));
 558                         }
 559                         else unsupported("type", cop.resultType());
 560                         addResult(cop.result(), new SpirvResult(type, null, result));
 561                     }
 562                     case SpirvOp.ConvertOp scop -> {
 563                         SPIRVId toType = spirvType(scop.resultType().toString());
 564                         SPIRVId to = nextId();
 565                         SpirvResult valueResult = getResult(scop.operands().get(0));
 566                         SPIRVId from = valueResult.value();
 567                         SPIRVId fromType = valueResult.type();
 568                         if (isIntegerType(fromType)) {
 569                             if (isIntegerType(toType)) {
 570                                 spirvBlock.add(new SPIRVOpSConvert(toType, to, from));
 571                             }
 572                             else if (isFloatType(toType)) {
 573                                 spirvBlock.add(new SPIRVOpConvertSToF(toType, to, from));
 574                             }
 575                             else unsupported("conversion type", scop.resultType());
 576                         }
 577                         else if (isFloatType(fromType)) {
 578                             if (isIntegerType(toType)) {
 579                                 spirvBlock.add(new SPIRVOpConvertFToS(toType, to, from));
 580                             }
 581                             else if (isFloatType(toType)) {
 582                                 spirvBlock.add(new SPIRVOpFConvert(toType, to, from));
 583                             }
 584                             else unsupported("conversion type", scop.resultType());
 585                         }
 586                         else unsupported("conversion type", scop.operands().get(0));
 587                         addResult(scop.result(), new SpirvResult(toType, null, to));
 588                     }
 589                     case SpirvOp.InBoundsAccessChainOp iacop -> {
 590                         SPIRVId type = spirvType(iacop.resultType().toString());
 591                         SPIRVId result = nextId();
 592                         SPIRVId object = getResult(iacop.operands().get(0)).value();
 593                         SPIRVId index = getResult(iacop.operands().get(1)).value();
 594                         spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(type, result, object, index, new SPIRVMultipleOperands<>()));
 595                         addResult(iacop.result(), new SpirvResult(type, result, null));
 596                     }
 597                     case SpirvOp.FieldLoadOp flo -> {
 598                         if (flo.operands().size() > 0 && (flo.operands().get(0).type().equals(KernelContext.class))) {
 599                             SpirvResult result;
 600                             int group = -1;
 601                             int index = -1;
 602                             String fieldName = flo.fieldDescriptor().name();
 603                             switch (fieldName) {
 604                                 case "x": group = 0; index = 0; break;
 605                                 case "y": group = 0; index = 1; break;
 606                                 case "z": group = 0; index = 2; break;
 607                                 case "maxX": group = 1; index = 0; break;
 608                                 case "maxY": group = 1; index = 1; break;
 609                                 case "maxZ": group = 1; index = 2; break;
 610                             }
 611                             switch (group) {
 612                                 case 0: result = globalId(index, spirvBlock); break;
 613                                 case 1: result = globalSize(index, spirvBlock); break;
 614                                 default: throw new RuntimeException("Unknown Index field: " + fieldName);
 615                             }
 616                             addResult(flo.result(), result);
 617                         }
 618                         else if (flo.operands().get(0).type().equals(JavaType.type(KernelContext.class))) {
 619                             String fieldName = flo.fieldDescriptor().name();
 620                             SPIRVId fieldIndex = switch (fieldName) {
 621                                 case "x" -> getConst("long", 0);
 622                                 case "maxX" -> getConst("long", 1);
 623                                 default -> throw new RuntimeException("Unknown field: " + fieldName);
 624                             };
 625                             SPIRVId intType = getType("int");
 626                             String contextTypeName = flo.operands().get(0).type().toString();
 627                             SpirvResult kernalContext = getResult(flo.operands().get(0));
 628                             SPIRVId contextAddr = kernalContext.address();
 629                             SPIRVId contextType = spirvType(contextTypeName);
 630                             SPIRVId context = nextId();
 631                             spirvBlock.add(new SPIRVOpLoad(contextType, context, contextAddr, align(contextType.getName())));
 632                             SPIRVId fieldType = intType;
 633                             SPIRVId resultAddr = nextId();
 634                             spirvBlock.add(new SPIRVOpInBoundsAccessChain(getType("ptrInt"), resultAddr, context, new SPIRVMultipleOperands<>(fieldIndex)));
 635                             SPIRVId result = nextId();
 636                             spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align("int")));
 637                             addResult(flo.result(), new SpirvResult(intType, resultAddr, result));
 638                         }
 639                         else if (flo.fieldDescriptor().refType().equals(JavaType.type(ByteOrder.class))) {
 640                             // currently ignored
 641                         }
 642                         else unsupported("field load", ((ClassType)flo.fieldDescriptor().refType()).toClassName() + "." + flo.fieldDescriptor().name());
 643                     }
 644                     case SpirvOp.BranchOp bop -> {
 645                         SPIRVId label = symbols.getLabel(bop.branch().targetBlock()).getResultId();
 646                         Block.Reference target = bop.branch();
 647                         spirvBlock.add(new SPIRVOpBranch(label));
 648                     }
 649                     case SpirvOp.ConditionalBranchOp cbop -> {
 650                         SPIRVId test = getResult(cbop.operands().get(0)).value();
 651                         SPIRVId trueLabel = symbols.getLabel(cbop.trueBranch().targetBlock()).getResultId();
 652                         SPIRVId falseLabel = symbols.getLabel(cbop.falseBranch().targetBlock()).getResultId();
 653                         spirvBlock.add(new SPIRVOpBranchConditional(test, trueLabel, falseLabel, new SPIRVMultipleOperands<SPIRVLiteralInteger>()));
 654                     }
 655                     case SpirvOp.LtOp ltop -> {
 656                         SpirvResult lhs = getResult(ltop.operands().get(0));
 657                         SpirvResult rhs = getResult(ltop.operands().get(1));
 658                         SPIRVId boolType = getType("bool");
 659                         SPIRVId result = nextId();
 660                         String operandType = lhs.type().getName();
 661                         SPIRVInstruction sop = switch (operandType) {
 662                             case "float" -> new SPIRVOpFUnordLessThan(boolType, result, lhs.value(), rhs.value());
 663                             case "int" -> new SPIRVOpSLessThan(boolType, result, lhs.value(), rhs.value());
 664                             case "long" -> new SPIRVOpSLessThan(boolType, result, lhs.value(), rhs.value());
 665                             default -> throw new RuntimeException("Unsupported type: " + lhs.type().getName());
 666                         };
 667                         spirvBlock.add(sop);
 668                         addResult(ltop.result(), new SpirvResult(boolType, null, result));
 669                     }
 670                     case SpirvOp.ReturnOp rop -> {
 671                         spirvBlock.add(new SPIRVOpReturn());
 672                     }
 673                     case SpirvOp.ReturnValueOp rop -> {
 674                         SPIRVId returnValue = getResult(rop.operands().get(0)).value();
 675                         spirvBlock.add(new SPIRVOpReturnValue(returnValue));
 676                     }
 677                     default -> unsupported("op", op.getClass());
 678                 }
 679             }
 680         } // end bi
 681     }
 682 
 683     private SPIRVId getFunctionId(MethodRef methodRef) {
 684         SPIRVId fnId = symbols.getId(methodRef.toString());
 685         if (fnId == null) {
 686             try {
 687                 Optional<CoreOp.FuncOp> optJFuncOp = methodRef.codeModel(MethodHandles.lookup());
 688                 if (optJFuncOp.isPresent()) {
 689                     CoreOp.FuncOp jFuncOp = optJFuncOp.get();
 690                     SpirvOp.FuncOp sFuncOp = TranslateToSpirvModel.translateFunction(jFuncOp);
 691                     fnId = generateFunction(jFuncOp.funcName(), sFuncOp, false);
 692                     symbols.putId(methodRef.toString(), fnId);
 693                 }
 694             }
 695             catch (ReflectiveOperationException e) {
 696                 throw new RuntimeException(e);
 697             }
 698         }
 699         return fnId;
 700     }
 701 
    /**
     * Emits the module preamble: capabilities, the memory model, the OpenCL.std
     * extended-instruction-set import, and four imported built-in input variables.
     *
     * <p>Declared capabilities are Addresses, Linkage, Kernel, Int8 and Int64. The memory model
     * uses the Physical64 addressing model with the OpenCL memory model.
     */
    private void initModule() {
        module.add(new SPIRVOpCapability(SPIRVCapability.Addresses()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Linkage()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Kernel()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Int8()));
        module.add(new SPIRVOpCapability(SPIRVCapability.Int64()));
        module.add(new SPIRVOpMemoryModel(SPIRVAddressingModel.Physical64(), SPIRVMemoryModel.OpenCL()));

        // Import the OpenCL.std extended instruction set (referenced later by OpExtInst, e.g. sqrt)
        // and register it by name so emitters can find it via getId("oclExtension").
        SPIRVId oclExtension = nextId("oclExtension");
        module.add(new SPIRVOpExtInstImport(oclExtension, new SPIRVLiteralString("OpenCL.std")));
        symbols.putId("oclExtension", oclExtension);

        // Allocate ids for the four built-in variables up front; decorations reference them
        // before the OpVariable declarations below.
        SPIRVId globalInvocationId = nextId("globalInvocationId");
        SPIRVId globalSize = nextId("globalSize");
        SPIRVId subgroupSize = nextId("subgroupSize");
        SPIRVId subgroupId = nextId("subgroupId");

        // Each built-in gets its BuiltIn kind, is marked Constant, and is imported by linkage name.
        module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalInvocationId())));
        module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalInvocationId"), SPIRVLinkageType.Import())));
        module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalSize())));
        module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalSize"), SPIRVLinkageType.Import())));
        module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupSize())));
        module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupSize"), SPIRVLinkageType.Import())));
        module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupId())));
        module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.Constant()));
        module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupId"), SPIRVLinkageType.Import())));

        // All four are Input-storage variables of type ptrV3long (pointer to a 3-element long vector).
        // NOTE(review): the OpenCL SPIR-V environment describes SubgroupSize and SubgroupId as scalar
        // 32-bit integers, not 3-element vectors — confirm these two declarations are accepted.
        module.add(new SPIRVOpVariable(getType("ptrV3long"), globalInvocationId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
        module.add(new SPIRVOpVariable(getType("ptrV3long"), globalSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
        module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
        module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
    }
 739 
 740     private SPIRVId spirvType(String inputType) {
 741         String javaType = inputType.replaceAll("\\$", ".");
 742         SPIRVId ans = switch(javaType) {
 743             case "byte" -> getType("byte");
 744             case "short" -> getType("short");
 745             case "int" -> getType("int");
 746             case "long" -> getType("long");
 747             case "float" -> getType("float");
 748             case "double" -> getType("double");
 749             case "byte[]" -> getType("byte[]");
 750             case "int[]" -> getType("int[]");
 751             case "float[]" -> getType("float[]");
 752             case "double[]" -> getType("double[]");
 753             case "long[]" -> getType("long[]");
 754             case "bool" -> getType("bool");
 755             case "boolean" -> getType("bool");
 756             case "java.lang.Object" -> getType("java.lang.Object");
 757             case "hat.buffer.S32Array" -> getType("int[]");
 758             case "hat.buffer.S32Array2D" -> getType("int[]");
 759             case "hat.buffer.F32Array2D" -> getType("float[]");
 760             case "void" -> getType("void");
 761             case "hat.KernelContext" -> getType("ptrKernelContext");
 762             case "java.lang.foreign.MemorySegment" -> getType("ptrByte");
 763             default -> null;
 764         };
 765         if (ans == null) unsupported("type", javaType);
 766         return ans;
 767     }
 768 
 769     private SPIRVId spirvElementType(String inputType) {
 770         String javaType = inputType.replaceAll("\\$", ".");
 771         SPIRVId ans = switch(javaType) {
 772             case "byte[]" -> getType("byte");
 773             case "short[]" -> getType("short");
 774             case "int[]" -> getType("int");
 775             case "long[]" -> getType("long");
 776             case "float[]" -> getType("float");
 777             case "double[]" -> getType("double");
 778             case "boolean[]" -> getType("bool");
 779             case "hat.buffer.S32Array" -> getType("int");
 780             case "hat.buffer.S32Array2D" -> getType("int");
 781             case "hat.buffer.F32Array2D" -> getType("float");
 782             case "java.lang.foreign.MemorySegment" -> getType("byte");
 783             default -> null;
 784         };
 785         if (ans == null) unsupported("type", javaType);
 786         return ans;
 787     }
 788 
 789     private SPIRVId vectorElementType(SPIRVId type) {
 790         SPIRVId ans = switch(type.getName()) {
 791             case "v8int" -> getType("int");
 792             case "v16int" -> getType("int");
 793             case "v8long" -> getType("long");
 794             case "v8float" -> getType("float");
 795             case "v16float" -> getType("float");
 796             default -> null;
 797         };
 798         if (ans == null) unsupported("type", type.getName());
 799         return ans;
 800     }
 801 
 802     private SPIRVId spirvVariableType(SPIRVId spirvType) {
 803         SPIRVId ans = switch(spirvType.getName()) {
 804             case "bool" -> getType("ptrBool");
 805             case "byte" -> getType("ptrByte");
 806             case "short" -> getType("ptrShort");
 807             case "int" -> getType("ptrInt");
 808             case "long" -> getType("ptrLong");
 809             case "float" -> getType("ptrFloat");
 810             case "double" -> getType("ptrDouble");
 811             case "boolean" -> getType("ptrBool");
 812             case "byte[]" -> getType("ptrByte[]");
 813             case "int[]" -> getType("ptrInt[]");
 814             case "long[]" -> getType("ptrLong[]");
 815             case "float[]" -> getType("ptrFloat[]");
 816             case "double[]" -> getType("ptrDouble[]");
 817             case "v8int" -> getType("ptrV8int");
 818             case "v16int" -> getType("ptrV16int");
 819             case "v8long" -> getType("ptrV8long");
 820             case "v8float" -> getType("ptrV8float");
 821             case "v16float" -> getType("ptrV16float");
 822             case "ptrKernelContext" -> getType("ptrPtrKernelContext");
 823             case "hat.KernelContext" -> getType("ptrKernelContext");
 824             case "ptrByte" -> getType("ptrPtrByte");
 825             default -> null;
 826         };
 827         if (ans == null) unsupported("type", spirvType.getName());
 828         return ans;
 829     }
 830 
 831     private SPIRVId spirvVectorType(String javaVectorType, int vectorLength) {
 832         String prefix = "v" + vectorLength;
 833         String elementType = spirvElementType(javaVectorType).getName();
 834         return getType(prefix + elementType);
 835     }
 836 
 837     private int alignment(String inputType) {
 838         String spirvType = inputType.replaceAll("\\$", ".");
 839         int ans = switch(spirvType) {
 840             case "bool" -> 1;
 841             case "byte" -> 1;
 842             case "short" -> 2;
 843             case "int" -> 4;
 844             case "long" -> 8;
 845             case "float" -> 4;
 846             case "double" -> 8;
 847             case "boolean" -> 1;
 848             case "v8int" -> 32;
 849             case "v16int" -> 64;
 850             case "v8long" -> 64;
 851             case "v8float" -> 32;
 852             case "v16float" -> 64;
 853             case "hat.KernelContext" -> 32;
 854             case "ptrKernelContext" -> 32;
 855             case "byte[]" -> 8;
 856             case "int[]" -> 8;
 857             case "long[]" -> 8;
 858             case "float[]" -> 8;
 859             case "double[]" -> 8;
 860             case "ptrBool" -> 8;
 861             case "ptrByte" -> 8;
 862             case "ptrInt" -> 8;
 863             case "ptrByte[]" -> 8;
 864             case "ptrInt[]" -> 8;
 865             case "ptrLong" -> 8;
 866             case "ptrLong[]" -> 8;
 867             case "ptrFloat" -> 8;
 868             case "ptrFloat[]" -> 8;
 869             case "ptrV8int" -> 8;
 870             case "ptrV8float" -> 8;
 871             case "ptrPtrKernelContext" -> 8;
 872             default -> 0;
 873         };
 874         if (ans == 0) unsupported("type", spirvType);
 875         return ans;
 876     }
 877 
 878     private Set<String> moduleTypes = new HashSet<>();
 879 
 880     private SPIRVId getType(String inputName) {
 881         String name = inputName.replaceAll("\\$", ".");
 882         if (!moduleTypes.contains(name)) {
 883             switch (name) {
 884                 case "void" -> module.add(new SPIRVOpTypeVoid(nextId(name)));
 885                 case "bool" -> module.add(new SPIRVOpTypeBool(nextId(name)));
 886                 case "boolean" -> module.add(new SPIRVOpTypeBool(nextId(name)));
 887                 case "byte" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(8), new SPIRVLiteralInteger(0)));
 888                 case "short" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(16), new SPIRVLiteralInteger(0)));
 889                 case "int" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(32), new SPIRVLiteralInteger(0)));
 890                 case "long" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(64), new SPIRVLiteralInteger(0)));
 891                 case "float" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(32)));
 892                 case "double" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(64)));
 893                 case "ptrBool" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("bool")));
 894                 case "ptrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 895                 case "ptrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int")));
 896                 case "ptrLong" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long")));
 897                 case "ptrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float")));
 898                 case "byte[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 899                 case "short[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("short")));
 900                 case "int[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("int")));
 901                 case "long[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("long")));
 902                 case "float[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("float")));
 903                 case "double[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("double")));
 904                 case "boolean[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("boolean")));
 905                 case "ptrByte[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("byte[]")));
 906                 case "ptrInt[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int[]")));
 907                 case "ptrLong[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long[]")));
 908                 case "ptrFloat[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float[]")));
 909                 case "java.lang.Object" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("void")));
 910                 case "hat.KernelContext" -> module.add(new SPIRVOpTypeStruct(nextId(name), new SPIRVMultipleOperands<>(getType("int"), getType("int"))));
 911                 case "ptrKernelContext" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("hat.KernelContext")));
 912                 case "ptrCrossGroupByte"-> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
 913                 case "ptrPtrKernelContext" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrKernelContext")));
 914                 case "ptrPtrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrByte")));
 915                 case "ptrPtrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrInt")));
 916                 case "ptrPtrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrFloat")));
 917                 case "v3long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(3)));
 918                 case "v8int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(8)));
 919                 case "v8long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(8)));
 920                 case "v16int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(16)));
 921                 case "v8float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(8)));
 922                 case "v16float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(16)));
 923                 case "ptrV3long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Input(), getType("v3long")));
 924                 case "ptrV8long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8long")));
 925                 case "ptrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8int")));
 926                 case "ptrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16int")));
 927                 case "ptrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8float")));
 928                 case "ptrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16float")));
 929                 case "ptrPtrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8int")));
 930                 case "ptrPtrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16int")));
 931                 case "ptrPtrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8float")));
 932                 case "ptrPtrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16float")));
 933                 default -> unsupported("type", name);
 934             }
 935             moduleTypes.add(name);
 936         }
 937         return getId(name);
 938     }
 939 
 940     private Set<String> moduleConstants = new HashSet<>();
 941 
 942     private SPIRVId getConst(String typeName, long value) {
 943         String name = typeName + "_" + value;
 944         if (!moduleConstants.contains(name)) {
 945             String valueStr = String.valueOf(value);
 946             switch (typeName) {
 947                 case "int" -> module.add(new SPIRVOpConstant(getType(typeName), nextId(name), new SPIRVContextDependentInt(new BigInteger(valueStr))));
 948                 case "long" -> module.add(new SPIRVOpConstant(getType(typeName), nextId(name), new SPIRVContextDependentLong(new BigInteger(valueStr))));
 949                 case "boolean" -> module.add(value == 0 ? new SPIRVOpConstantFalse(getType(typeName), nextId(name)) : new SPIRVOpConstantTrue(getType(typeName), nextId(name)));
 950                 default -> unsupported("constant", name);
 951             };
 952             moduleConstants.add(name);
 953         }
 954         return getId(name);
 955     }
 956 
 957     private SPIRVOptionalOperand<SPIRVMemoryAccess> align(int align) {
 958         return new SPIRVOptionalOperand<>(SPIRVMemoryAccess.Aligned(new SPIRVLiteralInteger(align)));
 959     }
 960 
 961     private SPIRVOptionalOperand<SPIRVMemoryAccess> align(String type) {
 962         return align(alignment(type));
 963     }
 964 
 965     private SPIRVMultipleOperands<SPIRVId> spirvOperands(SPIRVId value, int count) {
 966         SPIRVId[] values = new SPIRVId[count];
 967         Arrays.fill(values, value);
 968         return new SPIRVMultipleOperands<>(values);
 969     }
 970 
 971     private SPIRVOptionalOperand<SPIRVMemoryAccess> none() {
 972         return new SPIRVOptionalOperand<>();
 973     }
 974 
 975     private SpirvResult globalSize(int index, SPIRVBlock spirvBlock) {
 976         SPIRVId intType = getType("int");
 977         SPIRVId longType = getType("long");
 978         SPIRVId v3long = getId("v3long");
 979         SPIRVId globalSizeId = getId("globalSize");
 980         SPIRVId globalSizes = nextId();
 981         spirvBlock.add(new SPIRVOpLoad(v3long, globalSizes, globalSizeId, align(32)));
 982         SPIRVId longSize = nextId();
 983         SPIRVId globalSize = nextId();
 984         spirvBlock.add(new SPIRVOpCompositeExtract(longType, longSize, globalSizes, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
 985         spirvBlock.add(new SPIRVOpSConvert(intType, globalSize, longSize));
 986         return new SpirvResult(intType, null, globalSize);
 987     }
 988 
 989     private SpirvResult globalId(int index, SPIRVBlock spirvBlock) {
 990         SPIRVId intType = getType("int");
 991         SPIRVId longType = getType("long");
 992         SPIRVId v3long = getId("v3long");
 993         SPIRVId globalInvocationId = getId("globalInvocationId");
 994         SPIRVId globalIds = nextId();
 995         spirvBlock.add(new SPIRVOpLoad(v3long, globalIds, globalInvocationId, align(32)));
 996         SPIRVId longIndex = nextId();
 997         SPIRVId globalIndex = nextId();
 998         spirvBlock.add(new SPIRVOpCompositeExtract(longType, longIndex, globalIds, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
 999         spirvBlock.add(new SPIRVOpSConvert(intType, globalIndex, longIndex));
1000         return new SpirvResult(intType, null, globalIndex);
1001     }
1002 
1003    private SpirvResult flatIndex(SPIRVId sizeX, SPIRVId sizeY, SPIRVId sizeZ, SPIRVId indexX, SPIRVId indexY, SPIRVId indexZ, SPIRVBlock spirvBlock)
1004     {
1005         SPIRVId longType = getType("long");
1006         SPIRVId xTerm0 = nextId();
1007         SPIRVId xTerm1 = nextId();
1008         SPIRVId yTerm = nextId();
1009         SPIRVId flat0 = nextId();
1010         SPIRVId flat1 = nextId();
1011         spirvBlock.add(new SPIRVOpIMul(longType, xTerm0, sizeY, sizeZ));
1012         spirvBlock.add(new SPIRVOpIMul(longType, xTerm1, xTerm0, indexX));
1013         spirvBlock.add(new SPIRVOpIMul(longType, yTerm, sizeZ, indexY));
1014         spirvBlock.add(new SPIRVOpIAdd(longType, flat0, xTerm1, yTerm));
1015         spirvBlock.add(new SPIRVOpIAdd(longType, flat1, flat0, indexZ));
1016         return new SpirvResult(longType, null, flat1);
1017     }
1018 
1019     private SPIRVId nextId() {
1020         return module.getNextId();
1021     }
1022 
1023     private SPIRVId nextId(String name) {
1024         SPIRVId ans = nextId();
1025         ans.setName(name);
1026         symbols.putId(name, ans);
1027         module.add(new SPIRVOpName(ans, new SPIRVLiteralString(name)));
1028         return ans;
1029     }
1030 
1031     private static int counter = 0;
1032 
1033     private String nextTempTag() {
1034         counter++;
1035         return "temp_" + counter + "_";
1036     }
1037 
1038     private boolean isIntegerType(SPIRVId type) {
1039         String name = type.getName();
1040         return name.equals("byte") || name.equals("short") || name.equals("int") || name.equals("long");
1041     }
1042 
1043     private boolean isFloatType(SPIRVId type) {
1044         String name = type.getName();
1045         return name.equals("float") || name.equals("double");
1046     }
1047 
1048     private boolean isVectorSpecies(String javaType) {
1049         return javaType.equals("VectorSpecies");
1050     }
1051 
1052     private boolean isVectorType(String javaType) {
1053         return javaType.equals("IntVector") || javaType.equals("FloatVector");
1054     }
1055 
1056     private void addId(String name, SPIRVId id) {
1057         symbols.putId(name, id);
1058     }
1059 
1060     private SPIRVId getId(String name) {
1061         SPIRVId ans = symbols.getId(name);
1062         assert ans != null : name + " not found";
1063         return ans;
1064     }
1065 
1066     private SPIRVId getIdOrNull(String name) {
1067         return symbols.getId(name);
1068     }
1069 
1070     private static Object map(Function<Object, Boolean> test, Object... args) {
1071         int len = args.length;
1072         assert len >= 2 && len % 2 == 0;
1073         int pairs = len / 2;
1074         for (int i = 0; i < pairs; i++) {
1075             if (test.apply(args[i])) return args[i + pairs];
1076         }
1077         throw new RuntimeException("No match: " + args[0]);
1078     }
1079 
1080     private void unsupported(String message, Object value) {
1081         throw new RuntimeException("Unsupported " + message + ": " + value);
1082     }
1083 
1084     private void addResult(Value value, SpirvResult result) {
1085         assert symbols.getResult(value) == null : "result already present";
1086         symbols.putResult(value, result);
1087     }
1088 
1089     private SpirvResult getResult(Value value) {
1090         return symbols.getResult(value);
1091     }
1092 
1093     private static class Symbols {
1094         private final HashMap<Value, SpirvResult> results;
1095         private final HashMap<String, SPIRVId> ids;
1096         private final HashMap<Block, SPIRVBlock> blocks;
1097         private final HashMap<Block, SPIRVOpLabel> labels;
1098 
1099         public Symbols() {
1100             this.results = new HashMap<>();
1101             this.ids = new HashMap<>();
1102             this.blocks = new HashMap<>();
1103             this.labels = new HashMap<>();
1104         }
1105 
1106         public void putId(String name, SPIRVId id) {
1107             ids.put(name, id);
1108         }
1109 
1110         public SPIRVId getId(String name) {
1111             return ids.get(name);
1112         }
1113 
1114         public void putBlock(Block block, SPIRVBlock spirvBlock) {
1115             blocks.put(block, spirvBlock);
1116         }
1117 
1118         public SPIRVBlock getBlock(Block block) {
1119             return blocks.get(block);
1120         }
1121 
1122         public void putLabel(Block block, SPIRVOpLabel spirvLabel) {
1123             labels.put(block, spirvLabel);
1124         }
1125 
1126         public SPIRVOpLabel getLabel(Block block) {
1127             return labels.get(block);
1128         }
1129 
1130         public void putResult(Value value, SpirvResult result) {
1131             results.put(value, result);
1132         }
1133 
1134         public SpirvResult getResult(Value value) {
1135             return results.get(value);
1136         }
1137 
1138         public String toString() {
1139             return String.format("results %s\n\nids %s\n\nblocks %s\nlabels %s\n", results.keySet(), ids.keySet(), blocks.keySet(), labels.keySet());
1140         }
1141     }
1142 
1143     public static void debug(String message, Object... args) {
1144         System.out.println(String.format(message, args));
1145     }
1146 }