1 /*
2 * Copyright (c) 2024 Intel Corporation. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package intel.code.spirv;
27
28 import java.util.List;
29 import java.util.ArrayList;
30 import java.util.Arrays;
31 import java.util.Map;
32 import java.util.HashMap;
33 import java.util.Set;
34 import java.util.HashSet;
35 import java.util.Optional;
36 import java.util.function.Function;
37 import java.io.IOException;
38 import java.io.File;
39 import java.io.FileOutputStream;
40 import java.io.ByteArrayInputStream;
41 import java.io.ByteArrayOutputStream;
42 import java.io.PrintStream;
43 import java.nio.ByteBuffer;
44 import java.nio.ByteOrder;
45 import java.nio.channels.FileChannel;
46 import java.math.BigInteger;
47 import java.lang.constant.ClassDesc;
48 import java.lang.invoke.MethodHandles;
49 import java.lang.foreign.MemorySegment;
50 import java.lang.foreign.ValueLayout;
51 import jdk.incubator.code.Block;
52 import jdk.incubator.code.Body;
53 import jdk.incubator.code.Op;
54 import jdk.incubator.code.Value;
55 import jdk.incubator.code.TypeElement;
56 import jdk.incubator.code.type.MethodRef;
57 import jdk.incubator.code.type.ClassType;
58 import jdk.incubator.code.dialect.core.FunctionType;
59 import jdk.incubator.code.dialect.java.JavaType;
60 import jdk.incubator.code.dialect.java.CoreOp;
61 import hat.buffer.*;
62 import hat.callgraph.CallGraph;
63 import hat.callgraph.KernelCallGraph;
64 import hat.callgraph.KernelEntrypoint;
65 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVHeader;
66 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVModule;
67 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVFunction;
68 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVBlock;
69 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.*;
70 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.*;
71 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler;
72 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions;
73 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVByteStreamReader;
74 import intel.code.spirv.SpirvOp.PhiOp;
75
76 public class SpirvModuleGenerator {
77 private final String moduleName;
78 private final SPIRVModule module;
79 private final Symbols symbols;
80
81 public static SpirvModuleGenerator create(String moduleName) {
82 return new SpirvModuleGenerator(moduleName);
83 }
84
85 public static MemorySegment generateModule(String moduleName, KernelCallGraph callGraph) {
86 SpirvModuleGenerator generator = SpirvModuleGenerator.create(moduleName);
87 for (CallGraph.MethodCall call : callGraph.calls) {
88 if (call.targetMethodRef != null) {
89 try {
90 Optional<CoreOp.FuncOp> ofo = call.targetMethodRef.codeModel(MethodHandles.lookup());
91 if (ofo.isPresent()) {
92 CoreOp.FuncOp fo = ofo.get();
93 SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(fo);
94 }
95 } catch (Exception e) {
96 throw new RuntimeException(e);
97 }
98 }
99 }
100 KernelEntrypoint kernelEntrypoint = callGraph.entrypoint;
101 CoreOp.FuncOp funcOp = kernelEntrypoint.funcOpWrapper().op();
102 String kernelName = funcOp.funcName();
103 SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(funcOp);
104 generator.generateFunction(funcOp.funcName(), spirvFunc, true);
105 return generator.finalizeModule();
106 }
107
108 public static MemorySegment generateModule(String moduleName, CoreOp.FuncOp func) {
109 SpirvOp.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(func);
110 MemorySegment moduleSegment = SpirvModuleGenerator.generateModule(moduleName, spirvFunc);
111 return moduleSegment;
112 }
113
114 public static MemorySegment generateModule(String moduleName, SpirvOp.FuncOp func) {
115 SpirvModuleGenerator generator = new SpirvModuleGenerator(moduleName);
116 MemorySegment moduleSegment = generator.generateModuleInternal(func);
117 return moduleSegment;
118 }
119
120 public static void writeModuleToFile(MemorySegment module, String filepath) {
121 ByteBuffer buffer = module.asByteBuffer();
122 File out = new File(filepath);
123 try (FileChannel channel = new FileOutputStream(out, false).getChannel()) {
124 channel.write(buffer);
125 channel.close();
126 }
127 catch (IOException e) {
128 throw new RuntimeException(e);
129 }
130 }
131
132 private static void writeModuleToFile(SPIRVModule module, String filepath)
133 {
134 ByteBuffer buffer = ByteBuffer.allocate(module.getByteCount());
135 buffer.order(ByteOrder.LITTLE_ENDIAN);
136 module.close().write(buffer);
137 buffer.flip();
138 File out = new File(filepath);
139 try (FileChannel channel = new FileOutputStream(out, false).getChannel())
140 {
141 channel.write(buffer);
142 }
143 catch (IOException e)
144 {
145 throw new RuntimeException(e);
146 }
147 }
148
149 public static String disassembleModule(MemorySegment module) {
150 SPVByteStreamReader input = new SPVByteStreamReader(new ByteArrayInputStream(module.toArray(ValueLayout.JAVA_BYTE)));
151 ByteArrayOutputStream out = new ByteArrayOutputStream();
152 try (PrintStream ps = new PrintStream(out)) {
153 SPIRVDisassemblerOptions options = new SPIRVDisassemblerOptions(false, false, false, false, true);
154 Disassembler dis = new Disassembler(input, ps, options);
155 dis.run();
156 }
157 catch (Exception e) {
158 throw new RuntimeException(e);
159 }
160 return new String(out.toByteArray());
161 }
162
163 public MemorySegment finalizeModule() {
164 ByteBuffer buffer = ByteBuffer.allocateDirect(module.getByteCount());
165 buffer.order(ByteOrder.LITTLE_ENDIAN);
166 module.close().write(buffer);
167 buffer.flip();
168 return MemorySegment.ofBuffer(buffer);
169 }
170
171 private record SpirvResult(SPIRVId type, SPIRVId address, SPIRVId value) {}
172
/**
 * Creates the module and its symbol table, then emits the module prologue.
 *
 * @param moduleName name stored by the generator
 */
private SpirvModuleGenerator(String moduleName) {
    // Header arguments (1, 2, 32, 0, 0): presumably SPIR-V version 1.2 plus generator/bound
    // fields. TODO confirm against SPIRVHeader.
    SPIRVHeader header = new SPIRVHeader(1, 2, 32, 0, 0);
    this.module = new SPIRVModule(header);
    this.moduleName = moduleName;
    this.symbols = new Symbols();
    // Must run last: it adds instructions to the module created above.
    initModule();
}
179
180 private MemorySegment generateModuleInternal(SpirvOp.FuncOp func) {
181 generateFunction(moduleName, func, true);
182 return finalizeModule();
183 }
184
185 private SPIRVId generateFunction(String fnName, SpirvOp.FuncOp func, boolean isEntryPoint) {
186 TypeElement returnType = func.invokableType().returnType();
187 SPIRVId functionId = nextId(fnName);
188 String signature = func.invokableType().returnType().toString();
189 List<TypeElement> paramTypes = func.invokableType().parameterTypes();
190 // build signature string
191 for (int i = 0; i < paramTypes.size(); i++) {
192 signature += "_" + paramTypes.get(i).toString();
193 }
194 // declare function type if not already present
195 SPIRVId functionSig = getIdOrNull(signature);
196 if (functionSig == null) {
197 SPIRVId[] typeIdsArray = new SPIRVId[paramTypes.size()];
198 for (int i = 0; i < paramTypes.size(); i++) {
199 typeIdsArray[i] = spirvType(paramTypes.get(i).toString());
200 }
201 functionSig = nextId(fnName + "Signature");
202 module.add(new SPIRVOpTypeFunction(functionSig, spirvType(returnType.toString()), new SPIRVMultipleOperands<>(typeIdsArray)));
203 addId(signature, functionSig);
204 }
205 SPIRVId spirvReturnType = spirvType(returnType.toString());
206 SPIRVFunction function = (SPIRVFunction)module.add(new SPIRVOpFunction(spirvReturnType, functionId, SPIRVFunctionControl.DontInline(), functionSig));
207 SPIRVOpLabel entryLabel = new SPIRVOpLabel(nextId());
208 SPIRVBlock entryBlock = (SPIRVBlock)function.add(entryLabel);
209 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(getId("globalInvocationId"), getId("globalSize"), getId("subgroupSize"), getId("subgroupId"));
210 if (isEntryPoint) {
211 module.add(new SPIRVOpEntryPoint(SPIRVExecutionModel.Kernel(), functionId, new SPIRVLiteralString(fnName), operands));
212 }
213 translateBody(func.body(), function, entryBlock);
214 function.add(new SPIRVOpFunctionEnd());
215 return functionId;
216 }
217
218 private void translateBody(Body body, SPIRVFunction function, SPIRVBlock entryBlock) {
219 int labelNumber = 0;
220 SPIRVBlock spirvBlock = entryBlock;
221 for (int bi = 1; bi < body.blocks().size(); bi++) {
222 Block block = body.blocks().get(bi);
223 SPIRVOpLabel blockLabel = new SPIRVOpLabel(nextId());
224 SPIRVBlock newBlock = (SPIRVBlock)function.add(blockLabel);
225 symbols.putBlock(block, newBlock);
226 symbols.putLabel(block, blockLabel);
227 for (int i = 0; i < block.parameters().size(); i++) {
228 Value param = block.parameters().get(i);
229 SPIRVId paramId = nextId();
230 addResult(param, new SpirvResult(spirvType(param.type().toString()), null, paramId));
231 }
232 }
233 for (int bi = 0; bi < body.blocks().size(); bi++) {
234 Block block = body.blocks().get(bi);
235 if (bi > 0) {
236 spirvBlock = symbols.getBlock(block);
237 }
238 for (Op op : block.ops()) {
239 switch (op) {
240 case SpirvOp.PhiOp phop -> {
241 List<PhiOp.Predecessor> inPredecessors = phop.predecessors();
242 SPIRVPairIdRefIdRef[] outPredecessors = new SPIRVPairIdRefIdRef[inPredecessors.size()];
243 for (int i = 0; i < inPredecessors.size(); i++) {
244 PhiOp.Predecessor predecessor = inPredecessors.get(i);
245 SPIRVId label = symbols.getLabel(predecessor.block().targetBlock()).getResultId();
246 SPIRVId value = getResult(predecessor.value()).value();
247 outPredecessors[i] = new SPIRVPairIdRefIdRef(value, label);
248 }
249 SPIRVId result = nextId();
250 SPIRVId type = spirvType(phop.resultType().toString());
251 SPIRVOpPhi phiOp = new SPIRVOpPhi(spirvType(phop.resultType().toString()), result, new SPIRVMultipleOperands<>(outPredecessors));
252 spirvBlock.add(phiOp);
253 addResult(phop.result(), new SpirvResult(type, null, result));
254 }
255 case SpirvOp.VariableOp vop -> {
256 String typeName = vop.varType().toString();
257 SPIRVId type = spirvType(typeName);
258 SPIRVId varType = spirvVariableType(type);
259 SPIRVId var = nextId(vop.varName());
260 spirvBlock.add(new SPIRVOpVariable(varType, var, SPIRVStorageClass.Function(), new SPIRVOptionalOperand<>()));
261 addResult(vop.result(), new SpirvResult(varType, var, null));
262 }
263 case SpirvOp.FunctionParameterOp fpo -> {
264 SPIRVId result = nextId();
265 SPIRVId type = spirvType(fpo.resultType().toString());
266 function.add(new SPIRVOpFunctionParameter(type, result));
267 module.add(new SPIRVOpDecorate(result, SPIRVDecoration.Alignment(new SPIRVLiteralInteger(8))));
268 addResult(fpo.result(), new SpirvResult(type, null, result));
269 }
270 case SpirvOp.LoadOp lo -> {
271 SPIRVId type = spirvType(lo.resultType().toString());
272 SpirvResult toLoad = getResult(lo.operands().get(0));
273 SPIRVId varAddr = toLoad.address();
274 SPIRVId result = nextId();
275 spirvBlock.add(new SPIRVOpLoad(type, result, varAddr, align(type.getName())));
276 addResult(lo.result(), new SpirvResult(type, varAddr, result));
277 }
278 case SpirvOp.StoreOp so -> {
279 SpirvResult var = getResult(so.operands().get(0));
280 SPIRVId varAddr = var.address();
281 SPIRVId value = getResult(so.operands().get(1)).value();
282 spirvBlock.add(new SPIRVOpStore(varAddr, value, align(var.type().getName())));
283 }
284 case SpirvOp.IAddOp _, SpirvOp.FAddOp _ -> {
285 SPIRVId intType = getType("int");
286 SPIRVId longType = getType("long");
287 SPIRVId floatType = getType("float");
288 SPIRVId lhs = getResult(op.operands().get(0)).value();
289 SPIRVId rhs = getResult(op.operands().get(1)).value();
290 SPIRVId lhsType = spirvType(op.resultType().toString());
291 SPIRVId ans = nextId();
292 if (lhsType == intType) spirvBlock.add(new SPIRVOpIAdd(intType, ans, lhs, rhs));
293 else if (lhsType == longType) spirvBlock.add(new SPIRVOpIAdd(longType, ans, lhs, rhs));
294 else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFAdd(floatType, ans, lhs, rhs));
295 else unsupported("type", lhsType.getName());
296 addResult(op.result(), new SpirvResult(lhsType, null, ans));
297 }
298 case SpirvOp.ISubOp _, SpirvOp.FSubOp _ -> {
299 SPIRVId intType = getType("int");
300 SPIRVId longType = getType("long");
301 SPIRVId floatType = getType("float");
302 SPIRVId lhs = getResult(op.operands().get(0)).value();
303 SPIRVId rhs = getResult(op.operands().get(1)).value();
304 SPIRVId lhsType = spirvType(op.resultType().toString());
305 SPIRVId ans = nextId();
306 if (lhsType == intType) spirvBlock.add(new SPIRVOpISub(intType, ans, lhs, rhs));
307 else if (lhsType == longType) spirvBlock.add(new SPIRVOpISub(longType, ans, lhs, rhs));
308 else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFSub(floatType, ans, lhs, rhs));
309 else unsupported("type", lhsType.getName());
310 addResult(op.result(), new SpirvResult(lhsType, null, ans));
311 }
312 case SpirvOp.IMulOp _, SpirvOp.FMulOp _, SpirvOp.IDivOp _, SpirvOp.FDivOp _ -> {
313 SPIRVId intType = getType("int");
314 SPIRVId longType = getType("long");
315 SPIRVId floatType = getType("float");
316 SPIRVId lhs = getResult(op.operands().get(0)).value();
317 SPIRVId rhs = getResult(op.operands().get(1)).value();
318 SPIRVId lhsType = spirvType(op.resultType().toString());
319 SPIRVId rhsType = getResult(op.operands().get(1)).type();
320 SPIRVId ans = nextId();
321 if (lhsType == intType) {
322 if (op instanceof SpirvOp.IMulOp) spirvBlock.add(new SPIRVOpIMul(intType, ans, lhs, rhs));
323 else if (op instanceof SpirvOp.IDivOp) spirvBlock.add(new SPIRVOpSDiv(intType, ans, lhs, rhs));
324 }
325 else if (lhsType == longType) {
326 SPIRVId rhsId = rhsType == intType ? nextId() : rhs;
327 if (rhsType == intType) spirvBlock.add(new SPIRVOpSConvert(longType, rhsId, rhs));
328 if (op instanceof SpirvOp.IMulOp) spirvBlock.add(new SPIRVOpIMul(longType, ans, lhs, rhsId));
329 else if (op instanceof SpirvOp.IDivOp) spirvBlock.add(new SPIRVOpSDiv(longType, ans, lhs, rhs));
330 }
331 else if (lhsType == floatType) {
332 if (op instanceof SpirvOp.FMulOp) spirvBlock.add(new SPIRVOpFMul(floatType, ans, lhs, rhs));
333 else if (op instanceof SpirvOp.FDivOp) spirvBlock.add(new SPIRVOpFDiv(floatType, ans, lhs, rhs));
334 }
335 else unsupported("type", lhsType);
336 addResult(op.result(), new SpirvResult(lhsType, null, ans));
337 }
338 case SpirvOp.ModOp mop -> {
339 SPIRVId type = getType(mop.operands().get(0).type().toString());
340 SPIRVId lhs = getResult(mop.operands().get(0)).value();
341 SPIRVId rhs = getResult(mop.operands().get(1)).value();
342 SPIRVId result = nextId();
343 spirvBlock.add(new SPIRVOpUMod(type, result, lhs, rhs));
344 addResult(mop.result(), new SpirvResult(type, null, result));
345 }
346 case SpirvOp.IEqualOp eqop -> {
347 SPIRVId boolType = getType("bool");
348 SPIRVId intType = getType("int");
349 SPIRVId longType = getType("long");
350 SPIRVId floatType = getType("float");
351 SPIRVId lhs = getResult(op.operands().get(0)).value();
352 SPIRVId rhs = getResult(op.operands().get(1)).value();
353 SPIRVId lhsType = spirvType(op.resultType().toString());
354 SPIRVId ans = nextId();
355 if (lhsType == intType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
356 else if (lhsType == longType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
357 else unsupported("type", lhsType.getName());
358 addResult(op.result(), new SpirvResult(lhsType, null, ans));
359 }
360 case SpirvOp.GeOp eqop -> {
361 SPIRVId boolType = getType("bool");
362 SPIRVId lhs = getResult(op.operands().get(0)).value();
363 SPIRVId rhs = getResult(op.operands().get(1)).value();
364 SPIRVId lhsType = spirvType(op.resultType().toString());
365 SPIRVId ans = nextId();
366 spirvBlock.add(new SPIRVOpSGreaterThanEqual(boolType, ans, lhs, rhs));
367 addResult(op.result(), new SpirvResult(lhsType, null, ans));
368 }
369 case SpirvOp.PtrNotEqualOp neqop -> {
370 SPIRVId boolType = getType("bool");
371 SPIRVId longType = getType("long");
372 SPIRVId lhs = getResult(neqop.operands().get(0)).value();
373 SPIRVId rhs = getResult(neqop.operands().get(1)).value();
374 SPIRVId ans = nextId();
375 SPIRVId lhsLong = nextId();
376 SPIRVId rhsLong = nextId();
377 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, lhsLong, lhs));
378 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, rhsLong, rhs));
379 spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhsLong, rhsLong));
380 addResult(op.result(), new SpirvResult(boolType, null, ans));
381 }
382 case SpirvOp.CallOp call -> {
383 MethodRef methodRef = call.callDescriptor();
384 if (methodRef.equals(MethodRef.method(S32Array.class, "array", int.class, long.class)))
385 {
386 SPIRVId longType = getType("long");
387 String arrayTypeName = call.operands().get(0).type().toString();
388 SpirvResult arrayResult = getResult(call.operands().get(0));
389 SPIRVId arrayAddr = arrayResult.address();
390 SPIRVId arrayType = spirvType(arrayTypeName);
391 SPIRVId elementType = spirvElementType(arrayTypeName);
392 int nIndexes = call.operands().size() - 1;
393 SPIRVId indexX = getResult(call.operands().get(1)).value();
394 SPIRVId array = nextId();
395 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
396 SPIRVId temp1 = nextId();
397 SPIRVId temp2 = nextId();
398 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
399 spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4)));
400 SPIRVId elementBase = nextId();
401 spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
402 SPIRVId resultAddr = nextId();
403 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, elementBase, indexX, new SPIRVMultipleOperands<>()));
404 SPIRVId result = nextId();
405 spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName())));
406 addResult(call.result(), new SpirvResult(elementType, resultAddr, result));
407 }
408 else if (methodRef.equals(MethodRef.method(S32Array2D.class, "array", int.class, long.class) ||
409 methodRef.equals(MethodRef.method(F32Array2D.class, "array", float.class, long.class))
410 {
411 SPIRVId longType = getType("long");
412 String arrayTypeName = call.operands().get(0).type().toString();
413 SpirvResult arrayResult = getResult(call.operands().get(0));
414 SPIRVId arrayAddr = arrayResult.address();
415 SPIRVId arrayType = spirvType(arrayTypeName);
416 SPIRVId elementType = spirvElementType(arrayTypeName);
417 int nIndexes = call.operands().size() - 1;
418 SPIRVId indexX = getResult(call.operands().get(1)).value();
419 SPIRVId array = nextId();
420 SPIRVId temp1 = nextId();
421 SPIRVId temp2 = nextId();
422 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
423 spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4)));
424 SPIRVId elementBase = nextId();
425 spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
426 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
427 SPIRVId resultAddr = nextId();
428 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, elementBase, indexX, new SPIRVMultipleOperands<>()));
429 SPIRVId result = nextId();
430 spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName())));
431 addResult(call.result(), new SpirvResult(elementType, resultAddr, result));
432 }
433 else if (methodRef.equals(MethodRef.method(S32Array.class, "array", void.class, long.class, int.class))) {
434 SPIRVId longType = getType("long");
435 String arrayTypeName = call.operands().get(0).type().toString();
436 SpirvResult arrayResult = getResult(call.operands().get(0));
437 SPIRVId arrayAddr = arrayResult.address();
438 SPIRVId arrayType = spirvType(arrayTypeName);
439 SPIRVId elementType = spirvElementType(arrayTypeName);
440 int nIndexes = call.operands().size() - 2;
441 int valueIndex = nIndexes + 1;
442 SPIRVId indexX = getResult(call.operands().get(1)).value();
443 SPIRVId array = nextId();
444 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
445 SPIRVId temp1 = nextId();
446 SPIRVId temp2 = nextId();
447 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
448 spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 4)));
449 SPIRVId elementBase = nextId();
450 spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
451 SPIRVId dest = nextId();
452 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, elementBase, indexX, new SPIRVMultipleOperands<>()));
453 SPIRVId value = getResult(call.operands().get(valueIndex)).value();
454 spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName())));
455 }
456 else if (methodRef.equals(MethodRef.method(S32Array2D.class, "array", void.class, long.class, int.class)) ||
457 methodRef.equals(MethodRef.method(F32Array2D.class, "array", void.class, long.class, float.class))) {
458 SPIRVId longType = getType("long");
459 String arrayTypeName = call.operands().get(0).type().toString();
460 SpirvResult arrayResult = getResult(call.operands().get(0));
461 SPIRVId arrayAddr = arrayResult.address();
462 SPIRVId arrayType = spirvType(arrayTypeName);
463 SPIRVId elementType = spirvElementType(arrayTypeName);
464 int nIndexes = call.operands().size() - 2;
465 int valueIndex = nIndexes + 1;
466 SPIRVId indexX = getResult(call.operands().get(1)).value();
467 SPIRVId array = nextId();
468 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
469 SPIRVId temp1 = nextId();
470 SPIRVId temp2 = nextId();
471 spirvBlock.add(new SPIRVOpConvertPtrToU(longType, temp1, array));
472 spirvBlock.add(new SPIRVOpIAdd(longType, temp2, temp1, getConst("long", 8)));
473 SPIRVId elementBase = nextId();
474 spirvBlock.add(new SPIRVOpConvertUToPtr(arrayType, elementBase, temp2));
475 SPIRVId dest = nextId();
476 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, elementBase, indexX, new SPIRVMultipleOperands<>()));
477 SPIRVId value = getResult(call.operands().get(valueIndex)).value();
478 spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName())));
479 }
480 else if (methodRef.equals(MethodRef.method(S32Array.class, "length", int.class)) ||
481 methodRef.equals(MethodRef.method(S32Array2D.class, "width", int.class)) ||
482 methodRef.equals(MethodRef.method(F32Array2D.class, "width", int.class))) {
483 SPIRVId intType = getType("int");
484 String arrayTypeName = call.operands().get(0).type().toString();
485 SpirvResult arrayResult = getResult(call.operands().get(0));
486 SPIRVId arrayAddr = arrayResult.address();
487 SPIRVId arrayType = spirvType(arrayTypeName);
488 SPIRVId array = nextId();
489 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
490 SPIRVId resultAddr = nextId();
491 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, getConst("int", 0), new SPIRVMultipleOperands<>()));
492 SPIRVId result = nextId();
493 spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align(arrayType.getName())));
494 addResult(call.result(), new SpirvResult(intType, resultAddr, result));
495 }
496 else if (methodRef.equals(MethodRef.method(S32Array2D.class, "height", int.class)) ||
497 methodRef.equals(MethodRef.method(F32Array2D.class, "height", int.class))) {
498 SPIRVId intType = getType("int");
499 String arrayTypeName = call.operands().get(0).type().toString();
500 SpirvResult arrayResult = getResult(call.operands().get(0));
501 SPIRVId arrayAddr = arrayResult.address();
502 SPIRVId arrayType = spirvType(arrayTypeName);
503 SPIRVId array = nextId();
504 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
505 SPIRVId resultAddr = nextId();
506 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, getConst("int", 1), new SPIRVMultipleOperands<>()));
507 SPIRVId result = nextId();
508 spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align(arrayType.getName())));
509 addResult(call.result(), new SpirvResult(intType, resultAddr, result));
510 }
511 else if (methodRef.equals(MethodRef.method(Math.class, "sqrt", double.class, double.class))) {
512 SPIRVId floatType = getType("float");
513 SPIRVId result = nextId();
514 SPIRVId operand = getResult(call.operands().get(0)).value();
515 spirvBlock.add(new SPIRVOpExtInst(floatType, result, getId("oclExtension"), new SPIRVLiteralExtInstInteger(61, "sqrt"), new SPIRVMultipleOperands<>(operand)));
516 addResult(call.result(), new SpirvResult(floatType, null, result));
517 }
518 else {
519 SPIRVId fnId = getFunctionId(methodRef);
520 if (fnId == null) {
521 unsupported("method", methodRef);
522 }
523 else {
524 FunctionType fnType = methodRef.type();
525 SPIRVId[] args = new SPIRVId[call.operands().size()];
526 for (int i = 0; i < args.length; i++) {
527 SPIRVId argId = getResult(call.operands().get(i)).value();
528 args[i] = argId;
529 }
530 SPIRVId returnType = spirvType(fnType.returnType().toString());
531 SPIRVId callResult = nextId();
532 spirvBlock.add(new SPIRVOpFunctionCall(returnType, callResult, fnId, new SPIRVMultipleOperands<>(args)));
533 addResult(call.result(), new SpirvResult(returnType, null, callResult));
534 }
535 }
536 }
537 case SpirvOp.ConstantOp cop -> {
538 SPIRVId type = spirvType(cop.resultType().toString());
539 SPIRVId result = nextId();
540 Object value = cop.value();
541 if (type == getType("int")) {
542 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentInt(new BigInteger(String.valueOf(value)))));
543 }
544 else if (type == getType("long")) {
545 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentLong(new BigInteger(String.valueOf(value)))));
546 }
547 else if (type == getType("float")) {
548 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentFloat((float)value)));
549 }
550 else if (type == getType("bool")) {
551 module.add(((boolean)value) ? new SPIRVOpConstantTrue(type, result) : new SPIRVOpConstantFalse(type, result));
552 }
553 else if (type == getType("java.lang.Object")) {
554 module.add(new SPIRVOpConstantNull(type, result));
555 }
556 else if (type == getType("int[]")) {
557 module.add(new SPIRVOpConstantNull(type, result));
558 }
559 else unsupported("type", cop.resultType());
560 addResult(cop.result(), new SpirvResult(type, null, result));
561 }
562 case SpirvOp.ConvertOp scop -> {
563 SPIRVId toType = spirvType(scop.resultType().toString());
564 SPIRVId to = nextId();
565 SpirvResult valueResult = getResult(scop.operands().get(0));
566 SPIRVId from = valueResult.value();
567 SPIRVId fromType = valueResult.type();
568 if (isIntegerType(fromType)) {
569 if (isIntegerType(toType)) {
570 spirvBlock.add(new SPIRVOpSConvert(toType, to, from));
571 }
572 else if (isFloatType(toType)) {
573 spirvBlock.add(new SPIRVOpConvertSToF(toType, to, from));
574 }
575 else unsupported("conversion type", scop.resultType());
576 }
577 else if (isFloatType(fromType)) {
578 if (isIntegerType(toType)) {
579 spirvBlock.add(new SPIRVOpConvertFToS(toType, to, from));
580 }
581 else if (isFloatType(toType)) {
582 spirvBlock.add(new SPIRVOpFConvert(toType, to, from));
583 }
584 else unsupported("conversion type", scop.resultType());
585 }
586 else unsupported("conversion type", scop.operands().get(0));
587 addResult(scop.result(), new SpirvResult(toType, null, to));
588 }
589 case SpirvOp.InBoundsAccessChainOp iacop -> {
590 SPIRVId type = spirvType(iacop.resultType().toString());
591 SPIRVId result = nextId();
592 SPIRVId object = getResult(iacop.operands().get(0)).value();
593 SPIRVId index = getResult(iacop.operands().get(1)).value();
594 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(type, result, object, index, new SPIRVMultipleOperands<>()));
595 addResult(iacop.result(), new SpirvResult(type, result, null));
596 }
597 case SpirvOp.FieldLoadOp flo -> {
598 if (flo.operands().size() > 0 && (flo.operands().get(0).type().equals(KernelContext.class))) {
599 SpirvResult result;
600 int group = -1;
601 int index = -1;
602 String fieldName = flo.fieldDescriptor().name();
603 switch (fieldName) {
604 case "x": group = 0; index = 0; break;
605 case "y": group = 0; index = 1; break;
606 case "z": group = 0; index = 2; break;
607 case "maxX": group = 1; index = 0; break;
608 case "maxY": group = 1; index = 1; break;
609 case "maxZ": group = 1; index = 2; break;
610 }
611 switch (group) {
612 case 0: result = globalId(index, spirvBlock); break;
613 case 1: result = globalSize(index, spirvBlock); break;
614 default: throw new RuntimeException("Unknown Index field: " + fieldName);
615 }
616 addResult(flo.result(), result);
617 }
618 else if (flo.operands().get(0).type().equals(JavaType.type(KernelContext.class))) {
619 String fieldName = flo.fieldDescriptor().name();
620 SPIRVId fieldIndex = switch (fieldName) {
621 case "x" -> getConst("long", 0);
622 case "maxX" -> getConst("long", 1);
623 default -> throw new RuntimeException("Unknown field: " + fieldName);
624 };
625 SPIRVId intType = getType("int");
626 String contextTypeName = flo.operands().get(0).type().toString();
627 SpirvResult kernalContext = getResult(flo.operands().get(0));
628 SPIRVId contextAddr = kernalContext.address();
629 SPIRVId contextType = spirvType(contextTypeName);
630 SPIRVId context = nextId();
631 spirvBlock.add(new SPIRVOpLoad(contextType, context, contextAddr, align(contextType.getName())));
632 SPIRVId fieldType = intType;
633 SPIRVId resultAddr = nextId();
634 spirvBlock.add(new SPIRVOpInBoundsAccessChain(getType("ptrInt"), resultAddr, context, new SPIRVMultipleOperands<>(fieldIndex)));
635 SPIRVId result = nextId();
636 spirvBlock.add(new SPIRVOpLoad(intType, result, resultAddr, align("int")));
637 addResult(flo.result(), new SpirvResult(intType, resultAddr, result));
638 }
639 else if (flo.fieldDescriptor().refType().equals(JavaType.type(ByteOrder.class))) {
640 // currently ignored
641 }
642 else unsupported("field load", ((ClassType)flo.fieldDescriptor().refType()).toClassName() + "." + flo.fieldDescriptor().name());
643 }
644 case SpirvOp.BranchOp bop -> {
645 SPIRVId label = symbols.getLabel(bop.branch().targetBlock()).getResultId();
646 Block.Reference target = bop.branch();
647 spirvBlock.add(new SPIRVOpBranch(label));
648 }
649 case SpirvOp.ConditionalBranchOp cbop -> {
650 SPIRVId test = getResult(cbop.operands().get(0)).value();
651 SPIRVId trueLabel = symbols.getLabel(cbop.trueBranch().targetBlock()).getResultId();
652 SPIRVId falseLabel = symbols.getLabel(cbop.falseBranch().targetBlock()).getResultId();
653 spirvBlock.add(new SPIRVOpBranchConditional(test, trueLabel, falseLabel, new SPIRVMultipleOperands<SPIRVLiteralInteger>()));
654 }
655 case SpirvOp.LtOp ltop -> {
656 SpirvResult lhs = getResult(ltop.operands().get(0));
657 SpirvResult rhs = getResult(ltop.operands().get(1));
658 SPIRVId boolType = getType("bool");
659 SPIRVId result = nextId();
660 String operandType = lhs.type().getName();
661 SPIRVInstruction sop = switch (operandType) {
662 case "float" -> new SPIRVOpFUnordLessThan(boolType, result, lhs.value(), rhs.value());
663 case "int" -> new SPIRVOpSLessThan(boolType, result, lhs.value(), rhs.value());
664 case "long" -> new SPIRVOpSLessThan(boolType, result, lhs.value(), rhs.value());
665 default -> throw new RuntimeException("Unsupported type: " + lhs.type().getName());
666 };
667 spirvBlock.add(sop);
668 addResult(ltop.result(), new SpirvResult(boolType, null, result));
669 }
670 case SpirvOp.ReturnOp rop -> {
671 spirvBlock.add(new SPIRVOpReturn());
672 }
673 case SpirvOp.ReturnValueOp rop -> {
674 SPIRVId returnValue = getResult(rop.operands().get(0)).value();
675 spirvBlock.add(new SPIRVOpReturnValue(returnValue));
676 }
677 default -> unsupported("op", op.getClass());
678 }
679 }
680 } // end bi
681 }
682
683 private SPIRVId getFunctionId(MethodRef methodRef) {
684 SPIRVId fnId = symbols.getId(methodRef.toString());
685 if (fnId == null) {
686 try {
687 Optional<CoreOp.FuncOp> optJFuncOp = methodRef.codeModel(MethodHandles.lookup());
688 if (optJFuncOp.isPresent()) {
689 CoreOp.FuncOp jFuncOp = optJFuncOp.get();
690 SpirvOp.FuncOp sFuncOp = TranslateToSpirvModel.translateFunction(jFuncOp);
691 fnId = generateFunction(jFuncOp.funcName(), sFuncOp, false);
692 symbols.putId(methodRef.toString(), fnId);
693 }
694 }
695 catch (ReflectiveOperationException e) {
696 throw new RuntimeException(e);
697 }
698 }
699 return fnId;
700 }
701
702 private void initModule() {
703 module.add(new SPIRVOpCapability(SPIRVCapability.Addresses()));
704 module.add(new SPIRVOpCapability(SPIRVCapability.Linkage()));
705 module.add(new SPIRVOpCapability(SPIRVCapability.Kernel()));
706 module.add(new SPIRVOpCapability(SPIRVCapability.Int8()));
707 module.add(new SPIRVOpCapability(SPIRVCapability.Int64()));
708 module.add(new SPIRVOpMemoryModel(SPIRVAddressingModel.Physical64(), SPIRVMemoryModel.OpenCL()));
709
710 // OpenCL extension provides built-in variables suitable for kernel programming
711 // Import extension and declare four variables
712 SPIRVId oclExtension = nextId("oclExtension");
713 module.add(new SPIRVOpExtInstImport(oclExtension, new SPIRVLiteralString("OpenCL.std")));
714 symbols.putId("oclExtension", oclExtension);
715
716 SPIRVId globalInvocationId = nextId("globalInvocationId");
717 SPIRVId globalSize = nextId("globalSize");
718 SPIRVId subgroupSize = nextId("subgroupSize");
719 SPIRVId subgroupId = nextId("subgroupId");
720
721 module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalInvocationId())));
722 module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.Constant()));
723 module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalInvocationId"), SPIRVLinkageType.Import())));
724 module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalSize())));
725 module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.Constant()));
726 module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalSize"), SPIRVLinkageType.Import())));
727 module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupSize())));
728 module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.Constant()));
729 module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupSize"), SPIRVLinkageType.Import())));
730 module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupId())));
731 module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.Constant()));
732 module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupId"), SPIRVLinkageType.Import())));
733
734 module.add(new SPIRVOpVariable(getType("ptrV3long"), globalInvocationId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
735 module.add(new SPIRVOpVariable(getType("ptrV3long"), globalSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
736 module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
737 module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
738 }
739
740 private SPIRVId spirvType(String inputType) {
741 String javaType = inputType.replaceAll("\\$", ".");
742 SPIRVId ans = switch(javaType) {
743 case "byte" -> getType("byte");
744 case "short" -> getType("short");
745 case "int" -> getType("int");
746 case "long" -> getType("long");
747 case "float" -> getType("float");
748 case "double" -> getType("double");
749 case "byte[]" -> getType("byte[]");
750 case "int[]" -> getType("int[]");
751 case "float[]" -> getType("float[]");
752 case "double[]" -> getType("double[]");
753 case "long[]" -> getType("long[]");
754 case "bool" -> getType("bool");
755 case "boolean" -> getType("bool");
756 case "java.lang.Object" -> getType("java.lang.Object");
757 case "hat.buffer.S32Array" -> getType("int[]");
758 case "hat.buffer.S32Array2D" -> getType("int[]");
759 case "hat.buffer.F32Array2D" -> getType("float[]");
760 case "void" -> getType("void");
761 case "hat.KernelContext" -> getType("ptrKernelContext");
762 case "java.lang.foreign.MemorySegment" -> getType("ptrByte");
763 default -> null;
764 };
765 if (ans == null) unsupported("type", javaType);
766 return ans;
767 }
768
769 private SPIRVId spirvElementType(String inputType) {
770 String javaType = inputType.replaceAll("\\$", ".");
771 SPIRVId ans = switch(javaType) {
772 case "byte[]" -> getType("byte");
773 case "short[]" -> getType("short");
774 case "int[]" -> getType("int");
775 case "long[]" -> getType("long");
776 case "float[]" -> getType("float");
777 case "double[]" -> getType("double");
778 case "boolean[]" -> getType("bool");
779 case "hat.buffer.S32Array" -> getType("int");
780 case "hat.buffer.S32Array2D" -> getType("int");
781 case "hat.buffer.F32Array2D" -> getType("float");
782 case "java.lang.foreign.MemorySegment" -> getType("byte");
783 default -> null;
784 };
785 if (ans == null) unsupported("type", javaType);
786 return ans;
787 }
788
789 private SPIRVId vectorElementType(SPIRVId type) {
790 SPIRVId ans = switch(type.getName()) {
791 case "v8int" -> getType("int");
792 case "v16int" -> getType("int");
793 case "v8long" -> getType("long");
794 case "v8float" -> getType("float");
795 case "v16float" -> getType("float");
796 default -> null;
797 };
798 if (ans == null) unsupported("type", type.getName());
799 return ans;
800 }
801
802 private SPIRVId spirvVariableType(SPIRVId spirvType) {
803 SPIRVId ans = switch(spirvType.getName()) {
804 case "bool" -> getType("ptrBool");
805 case "byte" -> getType("ptrByte");
806 case "short" -> getType("ptrShort");
807 case "int" -> getType("ptrInt");
808 case "long" -> getType("ptrLong");
809 case "float" -> getType("ptrFloat");
810 case "double" -> getType("ptrDouble");
811 case "boolean" -> getType("ptrBool");
812 case "byte[]" -> getType("ptrByte[]");
813 case "int[]" -> getType("ptrInt[]");
814 case "long[]" -> getType("ptrLong[]");
815 case "float[]" -> getType("ptrFloat[]");
816 case "double[]" -> getType("ptrDouble[]");
817 case "v8int" -> getType("ptrV8int");
818 case "v16int" -> getType("ptrV16int");
819 case "v8long" -> getType("ptrV8long");
820 case "v8float" -> getType("ptrV8float");
821 case "v16float" -> getType("ptrV16float");
822 case "ptrKernelContext" -> getType("ptrPtrKernelContext");
823 case "hat.KernelContext" -> getType("ptrKernelContext");
824 case "ptrByte" -> getType("ptrPtrByte");
825 default -> null;
826 };
827 if (ans == null) unsupported("type", spirvType.getName());
828 return ans;
829 }
830
831 private SPIRVId spirvVectorType(String javaVectorType, int vectorLength) {
832 String prefix = "v" + vectorLength;
833 String elementType = spirvElementType(javaVectorType).getName();
834 return getType(prefix + elementType);
835 }
836
837 private int alignment(String inputType) {
838 String spirvType = inputType.replaceAll("\\$", ".");
839 int ans = switch(spirvType) {
840 case "bool" -> 1;
841 case "byte" -> 1;
842 case "short" -> 2;
843 case "int" -> 4;
844 case "long" -> 8;
845 case "float" -> 4;
846 case "double" -> 8;
847 case "boolean" -> 1;
848 case "v8int" -> 32;
849 case "v16int" -> 64;
850 case "v8long" -> 64;
851 case "v8float" -> 32;
852 case "v16float" -> 64;
853 case "hat.KernelContext" -> 32;
854 case "ptrKernelContext" -> 32;
855 case "byte[]" -> 8;
856 case "int[]" -> 8;
857 case "long[]" -> 8;
858 case "float[]" -> 8;
859 case "double[]" -> 8;
860 case "ptrBool" -> 8;
861 case "ptrByte" -> 8;
862 case "ptrInt" -> 8;
863 case "ptrByte[]" -> 8;
864 case "ptrInt[]" -> 8;
865 case "ptrLong" -> 8;
866 case "ptrLong[]" -> 8;
867 case "ptrFloat" -> 8;
868 case "ptrFloat[]" -> 8;
869 case "ptrV8int" -> 8;
870 case "ptrV8float" -> 8;
871 case "ptrPtrKernelContext" -> 8;
872 default -> 0;
873 };
874 if (ans == 0) unsupported("type", spirvType);
875 return ans;
876 }
877
878 private Set<String> moduleTypes = new HashSet<>();
879
880 private SPIRVId getType(String inputName) {
881 String name = inputName.replaceAll("\\$", ".");
882 if (!moduleTypes.contains(name)) {
883 switch (name) {
884 case "void" -> module.add(new SPIRVOpTypeVoid(nextId(name)));
885 case "bool" -> module.add(new SPIRVOpTypeBool(nextId(name)));
886 case "boolean" -> module.add(new SPIRVOpTypeBool(nextId(name)));
887 case "byte" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(8), new SPIRVLiteralInteger(0)));
888 case "short" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(16), new SPIRVLiteralInteger(0)));
889 case "int" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(32), new SPIRVLiteralInteger(0)));
890 case "long" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(64), new SPIRVLiteralInteger(0)));
891 case "float" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(32)));
892 case "double" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(64)));
893 case "ptrBool" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("bool")));
894 case "ptrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
895 case "ptrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int")));
896 case "ptrLong" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long")));
897 case "ptrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float")));
898 case "byte[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
899 case "short[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("short")));
900 case "int[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("int")));
901 case "long[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("long")));
902 case "float[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("float")));
903 case "double[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("double")));
904 case "boolean[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("boolean")));
905 case "ptrByte[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("byte[]")));
906 case "ptrInt[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int[]")));
907 case "ptrLong[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long[]")));
908 case "ptrFloat[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float[]")));
909 case "java.lang.Object" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("void")));
910 case "hat.KernelContext" -> module.add(new SPIRVOpTypeStruct(nextId(name), new SPIRVMultipleOperands<>(getType("int"), getType("int"))));
911 case "ptrKernelContext" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("hat.KernelContext")));
912 case "ptrCrossGroupByte"-> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
913 case "ptrPtrKernelContext" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrKernelContext")));
914 case "ptrPtrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrByte")));
915 case "ptrPtrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrInt")));
916 case "ptrPtrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrFloat")));
917 case "v3long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(3)));
918 case "v8int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(8)));
919 case "v8long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(8)));
920 case "v16int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(16)));
921 case "v8float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(8)));
922 case "v16float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(16)));
923 case "ptrV3long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Input(), getType("v3long")));
924 case "ptrV8long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8long")));
925 case "ptrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8int")));
926 case "ptrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16int")));
927 case "ptrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8float")));
928 case "ptrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16float")));
929 case "ptrPtrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8int")));
930 case "ptrPtrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16int")));
931 case "ptrPtrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8float")));
932 case "ptrPtrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16float")));
933 default -> unsupported("type", name);
934 }
935 moduleTypes.add(name);
936 }
937 return getId(name);
938 }
939
940 private Set<String> moduleConstants = new HashSet<>();
941
942 private SPIRVId getConst(String typeName, long value) {
943 String name = typeName + "_" + value;
944 if (!moduleConstants.contains(name)) {
945 String valueStr = String.valueOf(value);
946 switch (typeName) {
947 case "int" -> module.add(new SPIRVOpConstant(getType(typeName), nextId(name), new SPIRVContextDependentInt(new BigInteger(valueStr))));
948 case "long" -> module.add(new SPIRVOpConstant(getType(typeName), nextId(name), new SPIRVContextDependentLong(new BigInteger(valueStr))));
949 case "boolean" -> module.add(value == 0 ? new SPIRVOpConstantFalse(getType(typeName), nextId(name)) : new SPIRVOpConstantTrue(getType(typeName), nextId(name)));
950 default -> unsupported("constant", name);
951 };
952 moduleConstants.add(name);
953 }
954 return getId(name);
955 }
956
957 private SPIRVOptionalOperand<SPIRVMemoryAccess> align(int align) {
958 return new SPIRVOptionalOperand<>(SPIRVMemoryAccess.Aligned(new SPIRVLiteralInteger(align)));
959 }
960
961 private SPIRVOptionalOperand<SPIRVMemoryAccess> align(String type) {
962 return align(alignment(type));
963 }
964
965 private SPIRVMultipleOperands<SPIRVId> spirvOperands(SPIRVId value, int count) {
966 SPIRVId[] values = new SPIRVId[count];
967 Arrays.fill(values, value);
968 return new SPIRVMultipleOperands<>(values);
969 }
970
971 private SPIRVOptionalOperand<SPIRVMemoryAccess> none() {
972 return new SPIRVOptionalOperand<>();
973 }
974
975 private SpirvResult globalSize(int index, SPIRVBlock spirvBlock) {
976 SPIRVId intType = getType("int");
977 SPIRVId longType = getType("long");
978 SPIRVId v3long = getId("v3long");
979 SPIRVId globalSizeId = getId("globalSize");
980 SPIRVId globalSizes = nextId();
981 spirvBlock.add(new SPIRVOpLoad(v3long, globalSizes, globalSizeId, align(32)));
982 SPIRVId longSize = nextId();
983 SPIRVId globalSize = nextId();
984 spirvBlock.add(new SPIRVOpCompositeExtract(longType, longSize, globalSizes, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
985 spirvBlock.add(new SPIRVOpSConvert(intType, globalSize, longSize));
986 return new SpirvResult(intType, null, globalSize);
987 }
988
989 private SpirvResult globalId(int index, SPIRVBlock spirvBlock) {
990 SPIRVId intType = getType("int");
991 SPIRVId longType = getType("long");
992 SPIRVId v3long = getId("v3long");
993 SPIRVId globalInvocationId = getId("globalInvocationId");
994 SPIRVId globalIds = nextId();
995 spirvBlock.add(new SPIRVOpLoad(v3long, globalIds, globalInvocationId, align(32)));
996 SPIRVId longIndex = nextId();
997 SPIRVId globalIndex = nextId();
998 spirvBlock.add(new SPIRVOpCompositeExtract(longType, longIndex, globalIds, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
999 spirvBlock.add(new SPIRVOpSConvert(intType, globalIndex, longIndex));
1000 return new SpirvResult(intType, null, globalIndex);
1001 }
1002
1003 private SpirvResult flatIndex(SPIRVId sizeX, SPIRVId sizeY, SPIRVId sizeZ, SPIRVId indexX, SPIRVId indexY, SPIRVId indexZ, SPIRVBlock spirvBlock)
1004 {
1005 SPIRVId longType = getType("long");
1006 SPIRVId xTerm0 = nextId();
1007 SPIRVId xTerm1 = nextId();
1008 SPIRVId yTerm = nextId();
1009 SPIRVId flat0 = nextId();
1010 SPIRVId flat1 = nextId();
1011 spirvBlock.add(new SPIRVOpIMul(longType, xTerm0, sizeY, sizeZ));
1012 spirvBlock.add(new SPIRVOpIMul(longType, xTerm1, xTerm0, indexX));
1013 spirvBlock.add(new SPIRVOpIMul(longType, yTerm, sizeZ, indexY));
1014 spirvBlock.add(new SPIRVOpIAdd(longType, flat0, xTerm1, yTerm));
1015 spirvBlock.add(new SPIRVOpIAdd(longType, flat1, flat0, indexZ));
1016 return new SpirvResult(longType, null, flat1);
1017 }
1018
1019 private SPIRVId nextId() {
1020 return module.getNextId();
1021 }
1022
1023 private SPIRVId nextId(String name) {
1024 SPIRVId ans = nextId();
1025 ans.setName(name);
1026 symbols.putId(name, ans);
1027 module.add(new SPIRVOpName(ans, new SPIRVLiteralString(name)));
1028 return ans;
1029 }
1030
1031 private static int counter = 0;
1032
1033 private String nextTempTag() {
1034 counter++;
1035 return "temp_" + counter + "_";
1036 }
1037
1038 private boolean isIntegerType(SPIRVId type) {
1039 String name = type.getName();
1040 return name.equals("byte") || name.equals("short") || name.equals("int") || name.equals("long");
1041 }
1042
1043 private boolean isFloatType(SPIRVId type) {
1044 String name = type.getName();
1045 return name.equals("float") || name.equals("double");
1046 }
1047
1048 private boolean isVectorSpecies(String javaType) {
1049 return javaType.equals("VectorSpecies");
1050 }
1051
1052 private boolean isVectorType(String javaType) {
1053 return javaType.equals("IntVector") || javaType.equals("FloatVector");
1054 }
1055
1056 private void addId(String name, SPIRVId id) {
1057 symbols.putId(name, id);
1058 }
1059
1060 private SPIRVId getId(String name) {
1061 SPIRVId ans = symbols.getId(name);
1062 assert ans != null : name + " not found";
1063 return ans;
1064 }
1065
1066 private SPIRVId getIdOrNull(String name) {
1067 return symbols.getId(name);
1068 }
1069
1070 private static Object map(Function<Object, Boolean> test, Object... args) {
1071 int len = args.length;
1072 assert len >= 2 && len % 2 == 0;
1073 int pairs = len / 2;
1074 for (int i = 0; i < pairs; i++) {
1075 if (test.apply(args[i])) return args[i + pairs];
1076 }
1077 throw new RuntimeException("No match: " + args[0]);
1078 }
1079
1080 private void unsupported(String message, Object value) {
1081 throw new RuntimeException("Unsupported " + message + ": " + value);
1082 }
1083
1084 private void addResult(Value value, SpirvResult result) {
1085 assert symbols.getResult(value) == null : "result already present";
1086 symbols.putResult(value, result);
1087 }
1088
1089 private SpirvResult getResult(Value value) {
1090 return symbols.getResult(value);
1091 }
1092
1093 private static class Symbols {
1094 private final HashMap<Value, SpirvResult> results;
1095 private final HashMap<String, SPIRVId> ids;
1096 private final HashMap<Block, SPIRVBlock> blocks;
1097 private final HashMap<Block, SPIRVOpLabel> labels;
1098
1099 public Symbols() {
1100 this.results = new HashMap<>();
1101 this.ids = new HashMap<>();
1102 this.blocks = new HashMap<>();
1103 this.labels = new HashMap<>();
1104 }
1105
1106 public void putId(String name, SPIRVId id) {
1107 ids.put(name, id);
1108 }
1109
1110 public SPIRVId getId(String name) {
1111 return ids.get(name);
1112 }
1113
1114 public void putBlock(Block block, SPIRVBlock spirvBlock) {
1115 blocks.put(block, spirvBlock);
1116 }
1117
1118 public SPIRVBlock getBlock(Block block) {
1119 return blocks.get(block);
1120 }
1121
1122 public void putLabel(Block block, SPIRVOpLabel spirvLabel) {
1123 labels.put(block, spirvLabel);
1124 }
1125
1126 public SPIRVOpLabel getLabel(Block block) {
1127 return labels.get(block);
1128 }
1129
1130 public void putResult(Value value, SpirvResult result) {
1131 results.put(value, result);
1132 }
1133
1134 public SpirvResult getResult(Value value) {
1135 return results.get(value);
1136 }
1137
1138 public String toString() {
1139 return String.format("results %s\n\nids %s\n\nblocks %s\nlabels %s\n", results.keySet(), ids.keySet(), blocks.keySet(), labels.keySet());
1140 }
1141 }
1142
1143 public static void debug(String message, Object... args) {
1144 System.out.println(String.format(message, args));
1145 }
1146 }