1 /*
2 * Copyright (c) 2024 Intel Corporation. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package intel.code.spirv;
27
28 import java.util.List;
29 import java.util.Arrays;
30 import java.util.HashMap;
31 import java.util.Set;
32 import java.util.HashSet;
33 import java.util.function.Function;
34 import java.io.IOException;
35 import java.io.File;
36 import java.io.FileOutputStream;
37 import java.io.ByteArrayInputStream;
38 import java.io.ByteArrayOutputStream;
39 import java.io.PrintStream;
40 import java.lang.constant.ClassDesc;
41 import java.nio.ByteBuffer;
42 import java.nio.ByteOrder;
43 import java.nio.channels.FileChannel;
44 import java.math.BigInteger;
45
46 import jdk.incubator.code.dialect.core.CoreOp;
47 import jdk.incubator.code.dialect.java.ClassType;
48 import jdk.incubator.code.dialect.java.JavaType;
49 import jdk.incubator.code.dialect.java.MethodRef;
50 import jdk.incubator.vector.VectorSpecies;
51 import jdk.incubator.vector.VectorOperators;
52 import jdk.incubator.vector.Vector;
53 import jdk.incubator.vector.IntVector;
54 import jdk.incubator.vector.FloatVector;
55 import java.lang.foreign.MemorySegment;
56 import java.lang.foreign.ValueLayout;
57 import jdk.incubator.code.Block;
58 import jdk.incubator.code.Body;
59 import jdk.incubator.code.Op;
60 import jdk.incubator.code.Value;
61 import jdk.incubator.code.TypeElement;
62 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVHeader;
63 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVModule;
64 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVFunction;
65 import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVBlock;
66 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.*;
67 import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.*;
68 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler;
69 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions;
70 import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVByteStreamReader;
71
72 public class SpirvModuleGenerator {
73 public static MemorySegment generateModule(String moduleName, CoreOp.FuncOp func) {
74 SpirvOps.FuncOp spirvFunc = TranslateToSpirvModel.translateFunction(func);
75 MemorySegment module = SpirvModuleGenerator.generateModule(moduleName, spirvFunc);
76 return module;
77 }
78
79 public static MemorySegment generateModule(String moduleName, SpirvOps.FuncOp func) {
80 return new SpirvModuleGenerator().generateModuleInternal(moduleName, func);
81 }
82
83 public static void writeModuleToFile(MemorySegment module, String filepath) {
84 ByteBuffer buffer = module.asByteBuffer();
85 File out = new File(filepath);
86 try (FileChannel channel = new FileOutputStream(out, false).getChannel()) {
87 channel.write(buffer);
88 }
89 catch (IOException e) {
90 throw new RuntimeException(e);
91 }
92 }
93
94 public static String disassembleModule(MemorySegment module) {
95 SPVByteStreamReader input = new SPVByteStreamReader(new ByteArrayInputStream(module.toArray(ValueLayout.JAVA_BYTE)));
96 ByteArrayOutputStream out = new ByteArrayOutputStream();
97 try (PrintStream ps = new PrintStream(out)) {
98 SPIRVDisassemblerOptions options = new SPIRVDisassemblerOptions(false, false, false, false, true);
99 Disassembler dis = new Disassembler(input, ps, options);
100 dis.run();
101 }
102 catch (Exception e) {
103 throw new RuntimeException(e);
104 }
105 return new String(out.toByteArray());
106 }
107
/** Instruction container for the module being built by this generator instance. */
private final SPIRVModule module;

/** Symbol table for this module; holds at least the block and label mappings used by branches. */
private final Symbols symbols;

/** Creates a generator with an empty module and an empty symbol table. */
private SpirvModuleGenerator() {
    // NOTE(review): presumably SPIR-V version 1.2 followed by generator/bound/schema words — confirm against SPIRVHeader.
    SPIRVHeader header = new SPIRVHeader(1, 2, 32, 0, 0);
    module = new SPIRVModule(header);
    symbols = new Symbols();
}

/**
 * SPIR-V view of a code-model value.
 *
 * <p>{@code type} is the SPIR-V type id (for variables, the variable type derived
 * from the declared type). {@code address} is the id of the storage that holds
 * the value, or {@code null} when there is none. {@code value} is the id of a
 * loaded or computed value, or {@code null} for address-only results.
 */
private record SpirvResult(
        SPIRVId type,
        SPIRVId address,
        SPIRVId value) {}
117
118 private MemorySegment generateModuleInternal(String moduleName, SpirvOps.FuncOp func) {
119 initModule();
120 generateFunction(moduleName, moduleName, func);
121 ByteBuffer buffer = ByteBuffer.allocateDirect(module.getByteCount());
122 buffer.order(ByteOrder.LITTLE_ENDIAN);
123 module.close().write(buffer);
124 buffer.flip();
125 return MemorySegment.ofBuffer(buffer);
126 }
127
128 private void generateFunction(String moduleName, String fnName, SpirvOps.FuncOp func) {
129 TypeElement returnType = func.invokableType().returnType();
130 SPIRVId functionID = nextId(fnName);
131 String signature = func.invokableType().returnType().toString();
132 List<TypeElement> paramTypes = func.invokableType().parameterTypes();
133 // build signature string
134 for (int i = 0; i < paramTypes.size(); i++) {
135 signature += "_" + paramTypes.get(i).toString();
136 }
137 // declare function type if not already present
138 SPIRVId functionSig = getIdOrNull(signature);
139 if (functionSig == null) {
140 SPIRVId[] typeIdsArray = new SPIRVId[paramTypes.size()];
141 for (int i = 0; i < paramTypes.size(); i++) {
142 typeIdsArray[i] = spirvType(paramTypes.get(i).toString());
143 }
144 functionSig = nextId(fnName + "Signature");
145 module.add(new SPIRVOpTypeFunction(functionSig, spirvType(returnType.toString()), new SPIRVMultipleOperands<>(typeIdsArray)));
146 addId(signature, functionSig);
147 }
148 // declare function as modeule entry point
149 SPIRVId spirvReturnType = spirvType(returnType.toString());
150 SPIRVFunction function = (SPIRVFunction)module.add(new SPIRVOpFunction(spirvReturnType, functionID, SPIRVFunctionControl.DontInline(), functionSig));
151 SPIRVOpLabel entryPoint = new SPIRVOpLabel(nextId());
152 SPIRVBlock entryBlock = (SPIRVBlock)function.add(entryPoint);
153 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(getId("globalInvocationId"), getId("globalSize"), getId("subgroupSize"), getId("subgroupId"));
154 module.add(new SPIRVOpEntryPoint(SPIRVExecutionModel.Kernel(), functionID, new SPIRVLiteralString(fnName), operands));
155
156 translateBody(func.body(), function, entryBlock);
157 function.add(new SPIRVOpFunctionEnd());
158 }
159
160 private void translateBody(Body body, SPIRVFunction function, SPIRVBlock entryBlock) {
161 int labelNumber = 0;
162 SPIRVBlock spirvBlock = entryBlock;
163 for (int bi = 1; bi < body.blocks().size(); bi++) {
164 Block block = body.blocks().get(bi);
165 String blockName = String.valueOf(block.hashCode());
166 SPIRVOpLabel blockLabel = new SPIRVOpLabel(nextId());
167 SPIRVBlock newBlock = (SPIRVBlock)function.add(blockLabel);
168 symbols.putBlock(block, newBlock);
169 symbols.putLabel(block, blockLabel);
170 }
171 for (Value param : body.entryBlock().parameters()) {
172 SPIRVId paramId = nextId();
173 addResult(param, new SpirvResult(spirvType(param.type().toString()), null, paramId));
174 }
175 for (int bi = 0; bi < body.blocks().size(); bi++) {
176 Block block = body.blocks().get(bi);
177 if (bi > 0) {
178 spirvBlock = symbols.getBlock(block);
179 }
180 List<Op> ops = block.ops();
181 for (Op op : block.ops()) {
182 // debug("---------- spirv op = %s", op.toText());
183 switch (op) {
184 case SpirvOps.VariableOp vop -> {
185 String typeName = vop.varType().toString();
186 SPIRVId type = spirvType(typeName);
187 SPIRVId varType = spirvVariableType(type);
188 SPIRVId var = nextId(vop.varName());
189 spirvBlock.add(new SPIRVOpVariable(varType, var, SPIRVStorageClass.Function(), new SPIRVOptionalOperand<>()));
190 addResult(vop.result(), new SpirvResult(varType, var, null));
191 }
192 case SpirvOps.FunctionParameterOp fpo -> {
193 SPIRVId result = nextId();
194 SPIRVId type = spirvType(fpo.resultType().toString());
195 function.add(new SPIRVOpFunctionParameter(type, result));
196 addResult(fpo.result(), new SpirvResult(type, null, result));
197 }
198 case SpirvOps.LoadOp lo -> {
199 if (((JavaType)lo.resultType()).equals(JavaType.type(VectorSpecies.class))) {
200 addResult(lo.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT")));
201 }
202 else {
203 SPIRVId type = spirvType(lo.resultType().toString());
204 SpirvResult toLoad = getResult(lo.operands().get(0));
205 SPIRVId varAddr = toLoad.address();
206 SPIRVId result = nextId();
207 spirvBlock.add(new SPIRVOpLoad(type, result, varAddr, align(type.getName())));
208 addResult(lo.result(), new SpirvResult(type, varAddr, result));
209 }
210 }
211 case SpirvOps.StoreOp so -> {
212 SpirvResult var = getResult(so.operands().get(0));
213 SPIRVId varAddr = var.address();
214 SPIRVId value = getResult(so.operands().get(1)).value();
215 spirvBlock.add(new SPIRVOpStore(varAddr, value, align(var.type().getName())));
216 }
217 case SpirvOps.IAddOp _, SpirvOps.FAddOp _ -> {
218 SPIRVId intType = getType("int");
219 SPIRVId longType = getType("long");
220 SPIRVId floatType = getType("float");
221 SPIRVId lhs = getResult(op.operands().get(0)).value();
222 SPIRVId rhs = getResult(op.operands().get(1)).value();
223 SPIRVId lhsType = spirvType(op.resultType().toString());
224 SPIRVId ans = nextId();
225 if (lhsType == intType) spirvBlock.add(new SPIRVOpIAdd(intType, ans, lhs, rhs));
226 else if (lhsType == longType) spirvBlock.add(new SPIRVOpIAdd(longType, ans, lhs, rhs));
227 else if (lhsType == floatType) spirvBlock.add(new SPIRVOpFAdd(floatType, ans, lhs, rhs));
228 else unsupported("type", lhsType.getName());
229 addResult(op.result(), new SpirvResult(lhsType, null, ans));
230 }
231 case SpirvOps.IMulOp _, SpirvOps.FMulOp _, SpirvOps.IDivOp _, SpirvOps.FDivOp _ -> {
232 SPIRVId intType = getType("int");
233 SPIRVId longType = getType("long");
234 SPIRVId floatType = getType("float");
235 SPIRVId lhs = getResult(op.operands().get(0)).value();
236 SPIRVId rhs = getResult(op.operands().get(1)).value();
237 SPIRVId lhsType = spirvType(op.resultType().toString());
238 SPIRVId rhsType = getResult(op.operands().get(1)).type();
239 SPIRVId ans = nextId();
240 if (lhsType == intType) {
241 if (op instanceof SpirvOps.IMulOp) spirvBlock.add(new SPIRVOpIMul(intType, ans, lhs, rhs));
242 else if (op instanceof SpirvOps.IDivOp) spirvBlock.add(new SPIRVOpSDiv(intType, ans, lhs, rhs));
243 }
244 else if (lhsType == longType) {
245 SPIRVId rhsId = rhsType == intType ? nextId() : rhs;
246 if (rhsType == intType) spirvBlock.add(new SPIRVOpSConvert(longType, rhsId, rhs));
247 if (op instanceof SpirvOps.IMulOp) spirvBlock.add(new SPIRVOpIMul(longType, ans, lhs, rhsId));
248 else if (op instanceof SpirvOps.IDivOp) spirvBlock.add(new SPIRVOpSDiv(longType, ans, lhs, rhs));
249 }
250 else if (lhsType == floatType) {
251 if (op instanceof SpirvOps.FMulOp) spirvBlock.add(new SPIRVOpFMul(floatType, ans, lhs, rhs));
252 else if (op instanceof SpirvOps.FDivOp) spirvBlock.add(new SPIRVOpFDiv(floatType, ans, lhs, rhs));
253 }
254 else unsupported("type", lhsType);
255 addResult(op.result(), new SpirvResult(lhsType, null, ans));
256 }
257 case SpirvOps.ModOp mop -> {
258 SPIRVId type = getType(mop.operands().get(0).type().toString());
259 SPIRVId lhs = getResult(mop.operands().get(0)).value();
260 SPIRVId rhs = getResult(mop.operands().get(1)).value();
261 SPIRVId result = nextId();
262 spirvBlock.add(new SPIRVOpUMod(type, result, lhs, rhs));
263 addResult(mop.result(), new SpirvResult(type, null, result));
264 }
265 case SpirvOps.IEqualOp eqop -> {
266 SPIRVId boolType = getType("bool");
267 SPIRVId intType = getType("int");
268 SPIRVId longType = getType("long");
269 SPIRVId floatType = getType("float");
270 SPIRVId lhs = getResult(op.operands().get(0)).value();
271 SPIRVId rhs = getResult(op.operands().get(1)).value();
272 SPIRVId lhsType = spirvType(op.resultType().toString());
273 SPIRVId ans = nextId();
274 if (lhsType == intType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
275 else if (lhsType == longType) spirvBlock.add(new SPIRVOpIEqual(boolType, ans, lhs, rhs));
276 else unsupported("type", lhsType.getName());
277 addResult(op.result(), new SpirvResult(lhsType, null, ans));
278 }
279 case SpirvOps.CallOp call -> {
280 if (call.callDescriptor().equals(MethodRef.method(JavaType.type(ClassDesc.of("spirvdemo.IntArray")), "get", JavaType.INT, JavaType.LONG)) ||
281 call.callDescriptor().equals(MethodRef.method(JavaType.type(ClassDesc.of("spirvdemo.FloatArray")), "get", JavaType.FLOAT, JavaType.LONG))) {
282 SPIRVId longType = getType("long");
283 String arrayTypeName = call.operands().get(0).type().toString();
284 SpirvResult arrayResult = getResult(call.operands().get(0));
285 SPIRVId arrayAddr = arrayResult.address();
286 SPIRVId arrayType = spirvType(arrayTypeName);
287 SPIRVId elementType = spirvElementType(arrayTypeName);
288 int nIndexes = call.operands().size() - 1;
289 SPIRVId index = getResult(call.operands().get(1)).value();
290 SPIRVId array = nextId();
291 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
292 SPIRVId resultAddr = nextId();
293 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, resultAddr, array, index, new SPIRVMultipleOperands<>()));
294 SPIRVId result = nextId();
295 spirvBlock.add(new SPIRVOpLoad(elementType, result, resultAddr, align(elementType.getName())));
296 addResult(call.result(), new SpirvResult(elementType, resultAddr, result));
297 }
298 else if (call.callDescriptor().equals(MethodRef.method(JavaType.type(ClassDesc.of("spirvdemo.IntArray")), "set", JavaType.VOID, JavaType.LONG, JavaType.INT)) ||
299 call.callDescriptor().equals(MethodRef.method(JavaType.type(ClassDesc.of("spirvdemo.FloatArray")), "set", JavaType.VOID, JavaType.LONG, JavaType.FLOAT))) {
300 SPIRVId longType = getType("long");
301 String arrayTypeName = call.operands().get(0).type().toString();
302 SpirvResult arrayResult = getResult(call.operands().get(0));
303 SPIRVId arrayAddr = arrayResult.address();
304 SPIRVId arrayType = spirvType(arrayTypeName);
305 SPIRVId elementType = spirvElementType(arrayTypeName);
306 int nIndexes = call.operands().size() - 2;
307 int valueIndex = nIndexes + 1;
308 SPIRVId index = getResult(call.operands().get(1)).value();
309 SPIRVId array = nextId();
310 spirvBlock.add(new SPIRVOpLoad(arrayType, array, arrayAddr, align(arrayType.getName())));
311 SPIRVId dest = nextId();
312 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(arrayType, dest, array, index, new SPIRVMultipleOperands<>()));
313 SPIRVId value = getResult(call.operands().get(valueIndex)).value();
314 spirvBlock.add(new SPIRVOpStore(dest, value, align(elementType.getName())));
315 }
316 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "fromArray", IntVector.class, VectorSpecies.class, int[].class, int.class))
317 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fromArray", FloatVector.class, VectorSpecies.class, float[].class, int.class))) {
318 SPIRVId oclExtension = getId("oclExtension");
319 SpirvResult speciesResult = getResult(call.operands().get(0));
320 SpirvResult arrayResult = getResult(call.operands().get(1));
321 String arrayType = arrayResult.type().getName();
322 int laneCount = 8; //TODO: remove hard code, instruction below needs a literal
323 String vTypeName = ((ClassType)call.callDescriptor().refType()).toClassName();
324 SPIRVId vType = spirvVectorType(vTypeName, laneCount);
325 SPIRVId array = arrayResult.value();
326 SPIRVId index = getResult(call.operands().get(2)).value();
327 SPIRVId vectorIndex = nextId();
328 spirvBlock.add(new SPIRVOpSDiv(getType("int"), vectorIndex, index, speciesResult.value()));
329 SPIRVId longIndex = nextId();
330 spirvBlock.add(new SPIRVOpSConvert(getType("long"), longIndex, vectorIndex));
331 SPIRVId vector = nextId();
332 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(longIndex, array, new SPIRVId(laneCount)); // TODO: lanes must be a literal
333 spirvBlock.add(new SPIRVOpExtInst(vType, vector, oclExtension, new SPIRVLiteralExtInstInteger(171, "vloadn"), operands));
334 addResult(call.result(), new SpirvResult(vType, null, vector));
335 }
336 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "fromMemorySegment", IntVector.class, VectorSpecies.class, MemorySegment.class, long.class, ByteOrder.class))
337 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fromMemorySegment", FloatVector.class, VectorSpecies.class, MemorySegment.class, long.class, ByteOrder.class))) {
338 SPIRVId oclExtension = getId("oclExtension");
339 SPIRVId species = getResult(call.operands().get(0)).value();
340 SPIRVId lanesLong = nextId();
341 spirvBlock.add(new SPIRVOpSConvert(getType("long"), lanesLong, species));
342 int laneCount = 8; //TODO: remove hard code, vloadn instruction below needs a literal lane count, get value from env
343 SPIRVId segment = getResult(call.operands().get(1)).value();
344 String vTypeName = ((ClassType)call.callDescriptor().refType()).toClassName();
345 SPIRVId vType = spirvVectorType(vTypeName, laneCount);
346 SPIRVId temp = nextId();
347 spirvBlock.add(new SPIRVOpConvertPtrToU(getType("long"), temp, segment));
348 SPIRVId typedSegment = nextId();
349 SPIRVId pointerType = (SPIRVId)map(x -> x.equals(vTypeName), "jdk.incubator.vector.IntVector", "jdk.incubator.vector.FloatVector", getType("ptrInt"), getType("ptrFloat"));
350 spirvBlock.add(new SPIRVOpConvertUToPtr(pointerType, typedSegment, temp));
351 SPIRVId offset = getResult(call.operands().get(2)).value();
352 SPIRVId vectorIndex = nextId();
353 spirvBlock.add(new SPIRVOpSDiv(getType("long"), vectorIndex, offset, lanesLong)); // divide by lane count
354 SPIRVId finalIndex = nextId();
355 SPIRVId vector = nextId();
356 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(vectorIndex, typedSegment, new SPIRVId(laneCount)); // TODO: lanes must be a literal
357 spirvBlock.add(new SPIRVOpExtInst(vType, vector, oclExtension, new SPIRVLiteralExtInstInteger(171, "vloadn"), operands));
358 addResult(call.result(), new SpirvResult(vType, null, vector));
359 }
360 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "intoArray", void.class, int[].class, int.class))
361 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "intoArray", void.class, float[].class, int.class))) {
362 SPIRVId oclExtension = getId("oclExtension");
363 SpirvResult vectorResult = getResult(call.operands().get(0));
364 SPIRVId vector = vectorResult.value();
365 SPIRVId vectorType = vectorResult.type();
366 SpirvResult arrayResult = getResult(call.operands().get(1));
367 SPIRVId array = arrayResult.value();
368 SPIRVId index = getResult(call.operands().get(2)).value();
369 SPIRVId vectorIndex = nextId();
370 spirvBlock.add(new SPIRVOpShiftRightArithmetic(getType("int"), vectorIndex, index, vectorExponent(vectorType.getName())));
371 SPIRVId longIndex = nextId();
372 spirvBlock.add(new SPIRVOpSConvert(getType("long"), longIndex, vectorIndex));
373 SPIRVMultipleOperands<SPIRVId> operandsR = new SPIRVMultipleOperands<>(vector, longIndex, array);
374 spirvBlock.add(new SPIRVOpExtInst(getType("void"), nextId(), oclExtension, new SPIRVLiteralExtInstInteger(172, "vstoren"), operandsR));
375 }
376 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "intoMemorySegment", void.class, MemorySegment.class, long.class, ByteOrder.class))
377 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "intoMemorySegment", void.class, MemorySegment.class, long.class, ByteOrder.class))) {
378 SPIRVId oclExtension = getId("oclExtension");
379 SpirvResult vectorResult = getResult(call.operands().get(0));
380 SPIRVId vector = vectorResult.value();
381 SPIRVId vectorType = vectorResult.type();
382 SpirvResult segmentResult = getResult(call.operands().get(1));;
383 SPIRVId segment = segmentResult.value();
384 SPIRVId temp = nextId();
385 spirvBlock.add(new SPIRVOpConvertPtrToU(getType("long"), temp, segment));
386 SPIRVId typedSegment = nextId();
387 String vectorElementType = vectorElementType(vectorType).getName();
388 SPIRVId pointerType = (SPIRVId)map(x -> x.equals(vectorElementType), "int", "float", getType("ptrInt"), getType("ptrFloat"));
389 spirvBlock.add(new SPIRVOpConvertUToPtr(pointerType, typedSegment, temp));
390 SPIRVId offset = getResult(call.operands().get(2)).value();
391 SPIRVId vectorIndex = nextId();
392 int laneCount = laneCount(vectorType.getName());
393 spirvBlock.add(new SPIRVOpShiftRightArithmetic(getType("long"), vectorIndex, offset, vectorExponent(vectorType.getName())));
394 SPIRVMultipleOperands<SPIRVId> operandsR = new SPIRVMultipleOperands<>(vector, vectorIndex, typedSegment);
395 spirvBlock.add(new SPIRVOpExtInst(getId("void"), nextId(), oclExtension, new SPIRVLiteralExtInstInteger(172, "vstoren"), operandsR));
396 }
397 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "reduceLanes", int.class, VectorOperators.Associative.class))
398 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "reduceLanes", float.class, VectorOperators.Associative.class))) {
399 SpirvResult vectorResult = getResult(call.operands().get(0));
400 SPIRVId vectorType = vectorResult.type();
401 SPIRVId vector = vectorResult.value();
402 String vTypeName = vectorType.getName();
403 SPIRVId elementType = vectorElementType(vectorType);
404 Op reduceOp = ((Op.Result)call.operands().get(1)).op();
405 if (reduceOp instanceof SpirvOps.FieldLoadOp flo) {
406 assert flo.fieldDescriptor().refType().equals(JavaType.type(VectorOperators.class));
407 assert flo.fieldDescriptor().name().equals("ADD");
408 String operation = flo.fieldDescriptor().name();
409 }
410 else unsupported("operation expression", reduceOp.toText());
411 String tempTag = nextTempTag();
412 SPIRVId temp_0 = nextId(tempTag + 0);
413 spirvBlock.add(new SPIRVOpCompositeExtract(elementType, temp_0, vector, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(0))));
414 for (int lane = 1; lane < laneCount(vectorType.getName()); lane++) {
415 SPIRVId temp = nextId(tempTag + lane);
416 SPIRVId element = nextId();
417 spirvBlock.add(new SPIRVOpCompositeExtract(elementType, element, vector, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(lane))));
418 if (elementType == getType("int")) {
419 spirvBlock.add(new SPIRVOpIAdd(elementType, temp, getId(tempTag + (lane - 1)), element));
420 }
421 else if (elementType == getType("float")) {
422 spirvBlock.add(new SPIRVOpFAdd(elementType, temp, getId(tempTag + (lane - 1)), element));
423 }
424 else unsupported("type", elementType.getName());
425 }
426 addResult(call.result(), new SpirvResult(elementType, null, getId(tempTag + (laneCount(vectorType.getName()) - 1))));
427 }
428 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "add", IntVector.class, Vector.class))
429 || call.callDescriptor().equals(MethodRef.method(IntVector.class, "mul", IntVector.class, Vector.class))
430 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "add", FloatVector.class, Vector.class))
431 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "mul", FloatVector.class, Vector.class))) {
432 SPIRVId oclExtension = getId("oclExtension");
433 SpirvResult lhsResult = getResult(call.operands().get(0));
434 SPIRVId lhsType = lhsResult.type();
435 SPIRVId lhs = lhsResult.value();
436 SPIRVId rhs = getResult(call.operands().get(1)).value();
437 SPIRVId add = nextId();
438 if (call.callDescriptor().name().equals("add")) {
439 spirvBlock.add(lhsType.getName().endsWith("int") ? new SPIRVOpIAdd(lhsType, add, lhs, rhs) : new SPIRVOpFAdd(lhsType, add, lhs, rhs));
440 }
441 else if (call.callDescriptor().name().equals("mul")) {
442 spirvBlock.add(lhsType.getName().endsWith("int") ? new SPIRVOpIMul(lhsType, add, lhs, rhs) : new SPIRVOpFMul(lhsType, add, lhs, rhs));
443 }
444 addResult(call.result(), new SpirvResult(lhsType, null, add));
445 }
446 else if (call.callDescriptor().equals(MethodRef.method(FloatVector.class, "fma", FloatVector.class, Vector.class, Vector.class))) {
447 SPIRVId oclExtension = getId("oclExtension");
448 SpirvResult aResult = getResult(call.operands().get(0));
449 SPIRVId vType = aResult.type();
450 SPIRVId a = aResult.value();
451 SPIRVId b = getResult(call.operands().get(1)).value();
452 SPIRVId c = getResult(call.operands().get(2)).value();
453 String vTypeStr = vType.getName();
454 assert vTypeStr.endsWith("float");
455 SPIRVId result = nextId();
456 SPIRVMultipleOperands<SPIRVId> operands = new SPIRVMultipleOperands<>(a, b, c);
457 spirvBlock.add(new SPIRVOpExtInst(vType, result, oclExtension, new SPIRVLiteralExtInstInteger(26, "fma"), operands));
458 addResult(call.result(), new SpirvResult(vType, null, result));
459 }
460 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "zero", IntVector.class, VectorSpecies.class))
461 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "zero", FloatVector.class, VectorSpecies.class))) {
462 SpirvResult speciesResult = getResult(call.operands().get(0));
463 SPIRVId vType = spirvType(((ClassType)call.callDescriptor().refType()).toClassName());
464 String elementType = vectorElementType(vType).getName();
465 SPIRVId value = getId(elementType + "_ZERO");
466 int laneCount = laneCount(vType.getName());
467 assert laneCount == 8 || laneCount == 16;
468 SPIRVId vector = nextId();
469 SPIRVMultipleOperands<SPIRVId> operands = spirvOperands(value, laneCount);
470 spirvBlock.add(new SPIRVOpCompositeConstruct(vType, vector, operands));
471 addResult(call.result(), new SpirvResult(vType, null, vector));
472 }
473 else if (call.callDescriptor().equals(MethodRef.method(IntVector.class, "lane", int.class, int.class))
474 || call.callDescriptor().equals(MethodRef.method(FloatVector.class, "lane", float.class, int.class))) {
475 SpirvResult lhsResult = getResult(call.operands().get(0));
476 SPIRVId lhsType = lhsResult.type();
477 SPIRVId lhs = lhsResult.value();
478 String vTypeStr = lhsType.getName();
479 SPIRVId vType = lhsResult.type();
480 SPIRVId elementType = vectorElementType(vType);
481 SPIRVId result = nextId();
482 Op laneOp = ((Op.Result)call.operands().get(1)).op();
483 assert laneOp instanceof SpirvOps.ConstantOp;
484 int lane = (int)((SpirvOps.ConstantOp)laneOp).value();
485 spirvBlock.add(new SPIRVOpCompositeExtract(elementType, result, lhsResult.value(), new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(lane))));
486 addResult(call.result(), new SpirvResult(elementType, null, result));
487 }
488 else if (call.callDescriptor().equals(MethodRef.method(VectorSpecies.class, "length", int.class))) {
489 addResult(call.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT"))); // TODO: remove hardcode
490 }
491 else unsupported("method", call.callDescriptor());
492 }
493 case SpirvOps.ConstantOp cop -> {
494 SPIRVId type = spirvType(cop.resultType().toString());
495 SPIRVId result = nextId();
496 Object value = cop.value();
497 if (type == getType("int")) {
498 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentInt(new BigInteger(String.valueOf(value)))));
499 }
500 else if (type == getType("long")) {
501 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentLong(new BigInteger(String.valueOf(value)))));
502 }
503 else if (type == getType("float")) {
504 module.add(new SPIRVOpConstant(type, result, new SPIRVContextDependentFloat((float)value)));
505 }
506 else unsupported("type", cop.resultType());
507 addResult(cop.result(), new SpirvResult(type, null, result));
508 }
509 case SpirvOps.ConvertOp scop -> {
510 SPIRVId toType = spirvType(scop.resultType().toString());
511 SPIRVId to = nextId();
512 SpirvResult valueResult = getResult(scop.operands().get(0));
513 SPIRVId from = valueResult.value();
514 SPIRVId fromType = valueResult.type();
515 if (isIntegerType(fromType)) {
516 if (isIntegerType(toType)) {
517 spirvBlock.add(new SPIRVOpSConvert(toType, to, from));
518 }
519 else if (isFloatType(toType)) {
520 spirvBlock.add(new SPIRVOpConvertSToF(toType, to, from));
521 }
522 else unsupported("conversion type", scop.resultType());
523 }
524 else unsupported("conversion type", scop.operands().get(0));
525 addResult(scop.result(), new SpirvResult(toType, null, to));
526 }
527 case SpirvOps.InBoundAccessChainOp iacop -> {
528 SPIRVId type = spirvType(iacop.resultType().toString());
529 SPIRVId result = nextId();
530 SPIRVId object = getResult(iacop.operands().get(0)).value();
531 SPIRVId index = getResult(iacop.operands().get(1)).value();
532 spirvBlock.add(new SPIRVOpInBoundsPtrAccessChain(type, result, object, index, new SPIRVMultipleOperands<>()));
533 addResult(iacop.result(), new SpirvResult(type, result, null));
534 }
535 case SpirvOps.FieldLoadOp flo -> {
536 if (flo.operands().size() > 0 && flo.operands().get(0).type().equals(JavaType.type(ClassDesc.of("spirvdemo.GPU$Index")))) {
537 SpirvResult result;
538 int group = -1;
539 int index = -1;
540 String fieldName = flo.fieldDescriptor().name();
541 switch(fieldName) {
542 case "x": group = 0; index = 0; break;
543 case "y": group = 0; index = 1; break;
544 case "z": group = 0; index = 2; break;
545 case "w": group = 1; index = 0; break;
546 case "h": group = 1; index = 1; break;
547 case "d": group = 1; index = 2; break;
548 }
549 switch (group) {
550 case 0: result = globalId(index, spirvBlock); break;
551 case 1: result = globalSize(index, spirvBlock); break;
552 default: throw new RuntimeException("Unknown Index field: " + fieldName);
553 }
554 addResult(flo.result(), result);
555 }
556 else if (((JavaType)flo.resultType()).equals(JavaType.type(VectorSpecies.class))) {
557 addResult(flo.result(), new SpirvResult(getType("int"), null, getConst("int_EIGHT")));
558 }
559 else if (flo.fieldDescriptor().refType().equals(JavaType.type(VectorOperators.class))) {
560 // currently ignored
561 }
562 else if (flo.fieldDescriptor().refType().equals(JavaType.type(ByteOrder.class))) {
563 // currently ignored
564 }
565 else unsupported("field load", ((ClassType)flo.fieldDescriptor().refType()).toClassName() + "." + flo.fieldDescriptor().name());
566 }
567 case SpirvOps.BranchOp bop -> {
568 SPIRVId trueLabel = symbols.getLabel(bop.branch()).getResultId();
569 spirvBlock.add(new SPIRVOpBranch(trueLabel));
570 }
571 case SpirvOps.ConditionalBranchOp cbop -> {
572 SPIRVId test = getResult(cbop.operands().get(0)).value();
573 SPIRVId trueLabel = symbols.getLabel(cbop.trueBranch()).getResultId();
574 SPIRVId falseLabel = symbols.getLabel(cbop.falseBranch()).getResultId();
575 spirvBlock.add(new SPIRVOpBranchConditional(test, trueLabel, falseLabel, new SPIRVMultipleOperands<SPIRVLiteralInteger>()));
576 }
577 case SpirvOps.LtOp ltop -> {
578 SPIRVId lhs = getResult(ltop.operands().get(0)).value();
579 SPIRVId rhs = getResult(ltop.operands().get(1)).value();
580 SPIRVId boolType = getType("bool");
581 SPIRVId result = nextId();
582 spirvBlock.add(new SPIRVOpSLessThan(boolType, result, lhs, rhs));
583 addResult(ltop.result(), new SpirvResult(boolType, null, result));
584 }
585 case SpirvOps.ReturnOp rop -> {
586 if (rop.operands().size() == 0) {
587 spirvBlock.add(new SPIRVOpReturn());
588 }
589 else {
590 SPIRVId returnValue = getResult(rop.operands().get(0)).value();
591 spirvBlock.add(new SPIRVOpReturnValue(returnValue));
592 }
593 }
594 default -> unsupported("op", op.getClass());
595 }
596 }
597 }
598 }
599
/**
 * Emits the module preamble: capabilities, the addressing/memory model, the
 * OpenCL.std extended instruction set import, and four built-in input variables
 * (globalInvocationId, globalSize, subgroupSize, subgroupId).
 *
 * The ids are created with nextId(String), which names them in the module and
 * records them in the symbol table, so later code retrieves them with getId(name).
 */
private void initModule() {
// Capabilities used by the generated code, then Physical64 addressing with the OpenCL memory model.
// NOTE(review): getType("double") declares a 64-bit OpTypeFloat, but no Float64 capability is
// declared here; confirm whether kernels that use double require it.
module.add(new SPIRVOpCapability(SPIRVCapability.Addresses()));
module.add(new SPIRVOpCapability(SPIRVCapability.Linkage()));
module.add(new SPIRVOpCapability(SPIRVCapability.Kernel()));
module.add(new SPIRVOpCapability(SPIRVCapability.Int8()));
module.add(new SPIRVOpCapability(SPIRVCapability.Int16()));
module.add(new SPIRVOpCapability(SPIRVCapability.Int64()));
module.add(new SPIRVOpCapability(SPIRVCapability.Vector16()));
module.add(new SPIRVOpCapability(SPIRVCapability.Float16()));
module.add(new SPIRVOpMemoryModel(SPIRVAddressingModel.Physical64(), SPIRVMemoryModel.OpenCL()));

// Import the OpenCL.std extended instruction set, then declare four built-in variables.
SPIRVId oclExtension = nextId("oclExtension");
module.add(new SPIRVOpExtInstImport(oclExtension, new SPIRVLiteralString("OpenCL.std")));

SPIRVId globalInvocationId = nextId("globalInvocationId");
SPIRVId globalSize = nextId("globalSize");
SPIRVId subgroupSize = nextId("subgroupSize");
SPIRVId subgroupId = nextId("subgroupId");

// Each built-in is decorated with its BuiltIn kind, marked Constant, and given
// Import linkage under its "spirv_BuiltIn..." name.
module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalInvocationId())));
module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.Constant()));
module.add(new SPIRVOpDecorate(globalInvocationId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalInvocationId"), SPIRVLinkageType.Import())));
module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.GlobalSize())));
module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.Constant()));
module.add(new SPIRVOpDecorate(globalSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInGlobalSize"), SPIRVLinkageType.Import())));
module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupSize())));
module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.Constant()));
module.add(new SPIRVOpDecorate(subgroupSize, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupSize"), SPIRVLinkageType.Import())));
module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.BuiltIn(SPIRVBuiltIn.SubgroupId())));
module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.Constant()));
module.add(new SPIRVOpDecorate(subgroupId, SPIRVDecoration.LinkageAttributes(new SPIRVLiteralString("spirv_BuiltInSubgroupId"), SPIRVLinkageType.Import())));

// All four variables are declared in the Input storage class with the ptrV3long type
// (getType declares it as an Input pointer to a 3-component vector of 64-bit ints).
// NOTE(review): the SPIR-V spec describes SubgroupSize and SubgroupId as scalar 32-bit
// integers rather than 3-component long vectors; confirm these two declarations.
module.add(new SPIRVOpVariable(getType("ptrV3long"), globalInvocationId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
module.add(new SPIRVOpVariable(getType("ptrV3long"), globalSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupSize, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
module.add(new SPIRVOpVariable(getType("ptrV3long"), subgroupId, SPIRVStorageClass.Input(), new SPIRVOptionalOperand<>()));
}
639
640 private SPIRVId spirvType(String javaType) {
641 SPIRVId ans = switch(javaType) {
642 case "byte" -> getType("byte");
643 case "short" -> getType("short");
644 case "int" -> getType("int");
645 case "long" -> getType("long");
646 case "float" -> getType("float");
647 case "double" -> getType("double");
648 case "int[]" -> getType("int[]");
649 case "float[]" -> getType("float[]");
650 case "double[]" -> getType("double[]");
651 case "long[]" -> getType("long[]");
652 case "bool" -> getType("bool");
653 case "spirvdemo.IntArray" -> getType("int[]");
654 case "spirvdemo.FloatArray" -> getType("float[]");
655 case "jdk.incubator.vector.IntVector" -> spirvVectorType("IntVector", 8);
656 case "jdk.incubator.vector.FloatVector" -> spirvVectorType("FloatVector", 8);
657 case "jdk.incubator.vector.VectorSpecies<java.lang.Integer>" -> getType("int");
658 case "jdk.incubator.vector.VectorSpecies<java.lang.Long>" -> getType("long");
659 case "jdk.incubator.vector.VectorSpecies<java.lang.Float>" -> getType("int");
660 case "VectorSpecies" -> getType("int");
661 case "void" -> getType("void");
662 case "spirvdemo.GPU$Index" -> getType("ptrGPUIndex");
663 case "java.lang.foreign.MemorySegment" -> getType("ptrByte");
664 default -> null;
665 };
666 if (ans == null) unsupported("type", javaType);
667 return ans;
668 }
669
670 private SPIRVId spirvElementType(String javaType) {
671 SPIRVId ans = switch(javaType) {
672 case "byte[]" -> getType("byte");
673 case "short[]" -> getType("short");
674 case "int[]" -> getType("int");
675 case "long[]" -> getType("long");
676 case "float[]" -> getType("float");
677 case "double[]" -> getType("double");
678 case "boolean[]" -> getType("bool");
679 case "spirvdemo.IntArray" -> getType("int");
680 case "spirvdemo.FloatArray" -> getType("float");
681 case "jdk.incubator.vector.LongVector" -> getType("long");
682 case "jdk.incubator.vector.FloatVector" -> getType("float");
683 case "IntVector" -> getType("int");
684 case "LongVector" -> getType("long");
685 case "FloatVector" -> getType("float");
686 case "java.lang.foreign.MemorySegment" -> getType("byte");
687 default -> null;
688 };
689 if (ans == null) unsupported("type", javaType);
690 return ans;
691 }
692
693 private SPIRVId vectorElementType(SPIRVId type) {
694 SPIRVId ans = switch(type.getName()) {
695 case "v8int" -> getType("int");
696 case "v16int" -> getType("int");
697 case "v8long" -> getType("long");
698 case "v8float" -> getType("float");
699 case "v16float" -> getType("float");
700 default -> null;
701 };
702 if (ans == null) unsupported("type", type.getName());
703 return ans;
704 }
705
706 private SPIRVId spirvVariableType(SPIRVId spirvType) {
707 SPIRVId ans = switch(spirvType.getName()) {
708 case "byte" -> getType("ptrByte");
709 case "short" -> getType("ptrShort");
710 case "int" -> getType("ptrInt");
711 case "long" -> getType("ptrLong");
712 case "float" -> getType("ptrFloat");
713 case "double" -> getType("ptrDouble");
714 case "boolean" -> getType("ptrBool");
715 case "int[]" -> getType("ptrInt[]");
716 case "long[]" -> getType("ptrLong[]");
717 case "float[]" -> getType("ptrFloat[]");
718 case "double[]" -> getType("ptrDouble[]");
719 case "v8int" -> getType("ptrV8int");
720 case "v16int" -> getType("ptrV16int");
721 case "v8long" -> getType("ptrV8long");
722 case "v8float" -> getType("ptrV8float");
723 case "v16float" -> getType("ptrV16float");
724 case "ptrGPUIndex" -> getType("ptrPtrGPUIndex");
725 case "ptrByte" -> getType("ptrPtrByte");
726 default -> null;
727 };
728 if (ans == null) unsupported("type", spirvType.getName());
729 return ans;
730 }
731
732 private SPIRVId spirvVectorType(String javaVectorType, int vectorLength) {
733 String prefix = "v" + vectorLength;
734 String elementType = spirvElementType(javaVectorType).getName();
735 return getType(prefix + elementType);
736 }
737
738 private int alignment(String spirvType) {
739 int ans = switch(spirvType) {
740 case "byte" -> 1;
741 case "short" -> 2;
742 case "int" -> 4;
743 case "long" -> 8;
744 case "float" -> 4;
745 case "double" -> 8;
746 case "boolean" -> 1;
747 case "v8int" -> 32;
748 case "v16int" -> 64;
749 case "v8long" -> 64;
750 case "v8float" -> 32;
751 case "v16float" -> 64;
752 case "ptrGPUIndex" -> 32;
753 case "int[]" -> 8;
754 case "long[]" -> 8;
755 case "float[]" -> 8;
756 case "double[]" -> 8;
757 case "ptrByte" -> 8;
758 case "ptrInt" -> 8;
759 case "ptrInt[]" -> 8;
760 case "ptrLong" -> 8;
761 case "ptrLong[]" -> 8;
762 case "ptrFloat" -> 8;
763 case "ptrFloat[]" -> 8;
764 case "ptrV8int" -> 8;
765 case "ptrV8float" -> 8;
766 case "ptrPtrGPUIndex" -> 8;
767 default -> 0;
768 };
769 if (ans == 0) unsupported("type", spirvType);
770 return ans;
771 }
772
773 private int laneCount(String vectorType) {
774 int ans = switch(vectorType) {
775 case "v8int" -> 8;
776 case "v8long" -> 8;
777 case "v8float" -> 8;
778 case "v16int" -> 16;
779 case "v16float" -> 16;
780 default -> 0;
781 };
782 if (ans == 0) unsupported("type", vectorType);
783 return ans;
784 }
785
786 private SPIRVId vectorExponent(String vectorType) {
787 SPIRVId ans = null;
788 switch(vectorType) {
789 case "v8int" -> ans = getId("int_THREE");
790 case "v8long" -> ans = getId("int_THREE");
791 case "v8float" -> ans = getId("int_THREE");
792 case "v16int" -> ans = getId("int_FOUR");
793 case "v16float" -> ans = getId("int_FOUR");
794 default -> unsupported("type", vectorType);
795 };
796 return ans;
797 }
798
799 private Set<String> moduleTypes = new HashSet<>();
800
801 private SPIRVId getType(String name) {
802 if (!moduleTypes.contains(name)) {
803 switch (name) {
804 case "void" -> module.add(new SPIRVOpTypeVoid(nextId(name)));
805 case "bool" -> module.add(new SPIRVOpTypeBool(nextId(name)));
806 case "byte" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(8), new SPIRVLiteralInteger(0)));
807 case "short" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(16), new SPIRVLiteralInteger(0)));
808 case "int" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(32), new SPIRVLiteralInteger(0)));
809 case "long" -> module.add(new SPIRVOpTypeInt(nextId(name), new SPIRVLiteralInteger(64), new SPIRVLiteralInteger(0)));
810 case "float" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(32)));
811 case "double" -> module.add(new SPIRVOpTypeFloat(nextId(name), new SPIRVLiteralInteger(64)));
812 case "ptrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
813 case "ptrInt" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int")));
814 case "ptrLong" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long")));
815 case "ptrFloat" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float")));
816 case "short[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("short")));
817 case "int[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("int")));
818 case "long[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("long")));
819 case "float[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("float")));
820 case "double[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("double")));
821 case "boolean[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("boolean")));
822 case "ptrInt[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("int[]")));
823 case "ptrLong[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("long[]")));
824 case "ptrFloat[]" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("float[]")));
825 case "spirvdemo.GPUIndex" -> module.add(new SPIRVOpTypeStruct(nextId(name), new SPIRVMultipleOperands<>(getType("long"), getType("long"), getType("long"))));
826 case "ptrGPUIndex" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("spirvdemo.GPUIndex")));
827 case "ptrCrossGroupByte"-> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.CrossWorkgroup(), getType("byte")));
828 case "ptrPtrGPUIndex" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrGPUIndex")));
829 case "ptrPtrByte" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrByte")));
830 case "v3long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(3)));
831 case "v8int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(8)));
832 case "v8long" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("long"), new SPIRVLiteralInteger(8)));
833 case "v16int" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("int"), new SPIRVLiteralInteger(16)));
834 case "v8float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(8)));
835 case "v16float" -> module.add(new SPIRVOpTypeVector(nextId(name), getType("float"), new SPIRVLiteralInteger(16)));
836 case "ptrV3long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Input(), getType("v3long")));
837 case "ptrV8long" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8long")));
838 case "ptrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8int")));
839 case "ptrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16int")));
840 case "ptrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v8float")));
841 case "ptrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("v16float")));
842 case "ptrPtrV8int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8int")));
843 case "ptrPtrV16int" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16int")));
844 case "ptrPtrV8float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV8float")));
845 case "ptrPtrV16float" -> module.add(new SPIRVOpTypePointer(nextId(name), SPIRVStorageClass.Function(), getType("ptrV16float")));
846 default -> unsupported("type", name);
847 }
848 moduleTypes.add(name);
849 }
850 return getId(name);
851 }
852
853 private Set<String> moduleConstants = new HashSet<>();
854
855 private SPIRVId getConst(String name) {
856 if (!moduleConstants.contains(name)) {
857 switch (name) {
858 case "int_ZERO" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_ZERO"), new SPIRVContextDependentInt(new BigInteger("0"))));
859 case "int_ONE" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_ONE"), new SPIRVContextDependentInt(new BigInteger("1"))));
860 case "int_TWO" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_TWO"), new SPIRVContextDependentInt(new BigInteger("2"))));
861 case "int_EIGHT" -> module.add(new SPIRVOpConstant(getType("int"), nextId("int_EIGHT"), new SPIRVContextDependentInt(new BigInteger("8"))));
862 default -> unsupported("constant", name);
863 }
864 moduleConstants.add(name);
865 }
866 return getId(name);
867 }
868
869 private SPIRVOptionalOperand<SPIRVMemoryAccess> align(int align) {
870 return new SPIRVOptionalOperand<>(SPIRVMemoryAccess.Aligned(new SPIRVLiteralInteger(align)));
871 }
872
873 private SPIRVOptionalOperand<SPIRVMemoryAccess> align(String type) {
874 return align(alignment(type));
875 }
876
877 private SPIRVMultipleOperands<SPIRVId> spirvOperands(SPIRVId value, int count) {
878 SPIRVId[] values = new SPIRVId[count];
879 Arrays.fill(values, value);
880 return new SPIRVMultipleOperands<>(values);
881 }
882
883 private SPIRVOptionalOperand<SPIRVMemoryAccess> none() {
884 return new SPIRVOptionalOperand<>();
885 }
886
887 private SpirvResult globalSize(int index, SPIRVBlock spirvBlock) {
888 SPIRVId longType = getType("long");
889 SPIRVId v3long = getId("v3long");
890 SPIRVId globalSizeId = getId("globalSize");
891 SPIRVId globalSizes = nextId();
892 spirvBlock.add(new SPIRVOpLoad(v3long, globalSizes, globalSizeId, align(32)));
893 SPIRVId globalSize = nextId();
894 spirvBlock.add(new SPIRVOpCompositeExtract(longType, globalSize, globalSizes, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
895 return new SpirvResult(longType, null, globalSize);
896 }
897
898 private SpirvResult globalId(int index, SPIRVBlock spirvBlock) {
899 SPIRVId longType = getType("long");
900 SPIRVId v3long = getId("v3long");
901 SPIRVId globalInvocationId = getId("globalInvocationId");
902 SPIRVId globalIds = nextId();
903 spirvBlock.add(new SPIRVOpLoad(v3long, globalIds, globalInvocationId, align(32)));
904 SPIRVId globalIndex = nextId();
905 spirvBlock.add(new SPIRVOpCompositeExtract(longType, globalIndex, globalIds, new SPIRVMultipleOperands<>(new SPIRVLiteralInteger(index))));
906 return new SpirvResult(longType, null, globalIndex);
907 }
908
909 private SPIRVId nextId() {
910 return module.getNextId();
911 }
912
913 private SPIRVId nextId(String name) {
914 SPIRVId ans = nextId();
915 ans.setName(name);
916 symbols.putId(name, ans);
917 module.add(new SPIRVOpName(ans, new SPIRVLiteralString(name)));
918 return ans;
919 }
920
921 private static int counter = 0;
922
923 private String nextTempTag() {
924 counter++;
925 return "temp_" + counter + "_";
926 }
927
928 private boolean isIntegerType(SPIRVId type) {
929 String name = type.getName();
930 return name.equals("short") || name.equals("int") || name.equals("long");
931 }
932
933 private boolean isFloatType(SPIRVId type) {
934 String name = type.getName();
935 return name.equals("float") || name.equals("double");
936 }
937
938 private boolean isVectorSpecies(String javaType) {
939 return javaType.equals("VectorSpecies");
940 }
941
942 private boolean isVectorType(String javaType) {
943 return javaType.equals("IntVector") || javaType.equals("FloatVector");
944 }
945
946 private void addId(String name, SPIRVId id) {
947 symbols.putId(name, id);
948 }
949
950 private SPIRVId getId(String name) {
951 SPIRVId ans = symbols.getId(name);
952 assert ans != null : name + " not found";
953 return ans;
954 }
955
956 private SPIRVId getIdOrNull(String name) {
957 return symbols.getId(name);
958 }
959
960 private static Object map(Function<Object, Boolean> test, Object... args) {
961 int len = args.length;
962 assert len >= 2 && len % 2 == 0;
963 int pairs = len / 2;
964 for (int i = 0; i < pairs; i++) {
965 if (test.apply(args[i])) return args[i + pairs];
966 }
967 throw new RuntimeException("No match: " + args[0]);
968 }
969
970 private void unsupported(String message, Object value) {
971 throw new RuntimeException("Unsupported " + message + ": " + value);
972 }
973
974 private void addResult(Value value, SpirvResult result) {
975 assert symbols.getResult(value) == null : "result already present";
976 symbols.putResult(value, result);
977 }
978
979 private SpirvResult getResult(Value value) {
980 return symbols.getResult(value);
981 }
982
983 private static class Symbols {
984 private final HashMap<Value, SpirvResult> results;
985 private final HashMap<String, SPIRVId> ids;
986 private final HashMap<Block, SPIRVBlock> blocks;
987 private final HashMap<Block, SPIRVOpLabel> labels;
988
989 public Symbols() {
990 this.results = new HashMap<>();
991 this.ids = new HashMap<>();
992 this.blocks = new HashMap<>();
993 this.labels = new HashMap<>();
994 }
995
996 public void putId(String name, SPIRVId id) {
997 ids.put(name, id);
998 }
999
1000 public SPIRVId getId(String name) {
1001 return ids.get(name);
1002 }
1003
1004 public void putBlock(Block block, SPIRVBlock spirvBlock) {
1005 blocks.put(block, spirvBlock);
1006 }
1007
1008 public SPIRVBlock getBlock(Block block) {
1009 return blocks.get(block);
1010 }
1011
1012 public void putLabel(Block block, SPIRVOpLabel spirvLabel) {
1013 labels.put(block, spirvLabel);
1014 }
1015
1016 public SPIRVOpLabel getLabel(Block block) {
1017 return labels.get(block);
1018 }
1019
1020 public void putResult(Value value, SpirvResult result) {
1021 results.put(value, result);
1022 }
1023
1024 public SpirvResult getResult(Value value) {
1025 return results.get(value);
1026 }
1027
1028 public String toString() {
1029 return String.format("results %s\n\nids %s\n\nblocks %s\nlabels %s\n", results.keySet(), ids.keySet(), blocks.keySet(), labels.keySet());
1030 }
1031 }
1032 }