1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.backend.ffi;
26
27 import hat.dialect.HATOp;
28 import hat.ifacemapper.BoundSchema;
29 import hat.optools.*;
30 import hat.codebuilders.CodeBuilder;
31
32 import jdk.incubator.code.*;
33 import jdk.incubator.code.dialect.core.CoreOp;
34 import jdk.incubator.code.dialect.java.JavaOp;
35 import jdk.incubator.code.dialect.java.JavaType;
36
37 import java.lang.foreign.MemoryLayout;
38 import java.lang.invoke.MethodHandles;
39 import java.util.ArrayList;
40 import java.util.HashMap;
41 import java.util.List;
42 import java.util.Map;
43 import java.util.stream.Stream;
44
45 public class PTXHATKernelBuilder extends CodeBuilder<PTXHATKernelBuilder> {
46
// Registers already assigned to code-model values, keyed by the value they hold.
Map<Value, PTXRegister> varToRegMap;
// Kernel parameter names, appended in the order parameters(...) emits them.
List<String> paramNames;
// Entry-block parameters, appended in the order blockBody(...) loads them.
List<Block.Parameter> paramObjects;
// Registers assigned to the special kernel-context fields (see Field).
Map<Field, PTXRegister> fieldToRegMap;

// Next free register ordinal per register type; getOrdinal/incrOrdinal start each type at 1.
HashMap<PTXRegister.Type, Integer> ordinalMap;

// Register named "%retReg", created by functionHeader(...) when the yield type is not void.
PTXRegister returnReg;
// Address width in bits: 32 selects u32 addresses, any other value selects u64.
private int addressSize;
56
/**
 * Kernel-context fields and PTX special registers that are cached in a
 * dedicated register by the builder.
 *
 * <p>Each constant carries the textual name used in the emitted PTX (or in the
 * Java field being accessed) and a flag telling whether its register is
 * allocated with the address type instead of {@code u32}.
 */
public enum Field {
    NTID_X("ntid.x", false),
    CTAID_X("ctaid.x", false),
    TID_X("tid.x", false),
    KC_X("x", false),
    KC_ADDR("kc", true),
    KC_MAXX("maxX", false);

    /** Text returned by {@link #toString()}. */
    private final String text;
    /** {@code true} when the field's register uses the address type. */
    private final boolean addressSized;

    Field(String text, boolean addressSized) {
        this.text = text;
        this.addressSized = addressSized;
    }

    /** @return {@code true} if this field is flagged as a destination (address-typed) field */
    public boolean isDestination() {
        return addressSized;
    }

    /** @return the textual name associated with this field */
    @Override
    public String toString() {
        return text;
    }
}
77
/**
 * Creates a builder with empty register/parameter bookkeeping.
 *
 * @param addressSize address width in bits; 32 selects u32 addressing,
 *                    any other value selects u64 (see {@code addressType()})
 */
public PTXHATKernelBuilder(int addressSize) {
    this.addressSize = addressSize;
    this.varToRegMap = new HashMap<>();
    this.fieldToRegMap = new HashMap<>();
    this.ordinalMap = new HashMap<>();
    this.paramNames = new ArrayList<>();
    this.paramObjects = new ArrayList<>();
}

/** Creates a builder that uses 32-bit addressing. */
public PTXHATKernelBuilder() {
    this(32);
}
90
91 public void ptxHeader(int major, int minor, String target, int addressSize) {
92 this.addressSize = addressSize;
93 version().space().major(major).dot().minor(minor).nl();
94 target().space().target(target).nl();
95 addressSize().space().size(addressSize);
96 }
97
98 public void functionHeader(String funcName, boolean entry, TypeElement yieldType) {
99 if (entry) {
100 visible().space().entry().space();
101 } else {
102 func().space();
103 }
104 if (!yieldType.toString().equals("void")) {
105 returnReg = new PTXRegister(getOrdinal(getResultType(yieldType)), getResultType(yieldType));
106 returnReg.name("%retReg");
107 oparen().dot().param().space().paramType(yieldType);
108 space().regName(returnReg).cparen().space();
109 }
110 funcName(funcName);
111 }
112
/**
 * Emits the parenthesised {@code .param} list, one parameter per line, and
 * records each parameter name in {@code paramNames} in emission order.
 *
 * @param infoList function parameters to declare
 * @return this builder
 */
public PTXHATKernelBuilder parameters(List<FuncOpParams.Info> infoList) {
    paren(_ ->
            nl().separated(infoList, (sep) -> sep.comma().nl(), (p) -> {
                String name = p.varOp.varName();
                ptxIndent().dot().param().space().paramType(p.javaType).space().regName(name);
                paramNames.add(name);
            }).nl()
    ).nl();
    return this;
}
122
123 public void blockBody(Block block, Stream<Op> ops) {
124 if (block.index() == 0) {
125 for (Block.Parameter p : block.parameters()) {
126 ptxIndent().ld().dot().param();
127 resultType(p.type(), false).ptxIndent().space();
128 reg(p, getResultType(p.type())).commaSpace().osbrace().regName(paramNames.get(p.index())).csbrace().semicolon().nl();
129 paramObjects.add(p);
130 }
131 }
132 nl();
133 block(block);
134 colon().nl();
135 ops.forEach(op -> {
136 if (op instanceof JavaOp.InvokeOp invoke && !OpTk.isIfaceBufferMethod(MethodHandles.lookup(),invoke)) { // We should pass lookup down
137 ptxIndent().convert(op).nl();
138 } else {
139 ptxIndent().convert(op).semicolon().nl();
140 }
141 });
142 }
143
144 public void ptxRegisterDecl() {
145 for (PTXRegister.Type t : ordinalMap.keySet()) {
146 ptxIndent().reg().space();
147 if (t.equals(PTXRegister.Type.U32)) {
148 b32();
149 } else if (t.equals(PTXRegister.Type.U64)) {
150 b64();
151 } else {
152 dot().regType(t);
153 }
154 ptxIndent().regTypePrefix(t).oabrace().intVal(ordinalMap.get(t)).cabrace().semicolon().nl();
155 }
156 nl();
157 }
158
159 public void functionPrologue() {
160 obrace().nl();
161 }
162
163 public void functionEpilogue() {
164 cbrace();
165 }
166
167 public static class PTXPtrOp extends HATOp {
168 public String fieldName;
169 public static final String NAME = "ptxPtr";
170 final TypeElement resultType;
171 public BoundSchema<?> boundSchema;
172
173 PTXPtrOp(TypeElement resultType, String fieldName, List<Value> operands, BoundSchema<?> boundSchema) {
174 super(operands);
175 this.resultType = resultType;
176 this.fieldName = fieldName;
177 this.boundSchema = boundSchema;
178 }
179
180 PTXPtrOp(PTXPtrOp that, CopyContext cc) {
181 super(that, cc);
182 this.resultType = that.resultType;
183 this.fieldName = that.fieldName;
184 this.boundSchema = that.boundSchema;
185 }
186
187 @Override
188 public PTXPtrOp transform(CopyContext cc, OpTransformer ot) {
189 return new PTXPtrOp(this, cc);
190 }
191
192 @Override
193 public TypeElement resultType() {
194 return resultType;
195 }
196
197 @Override
198 public String externalizeOpName() {
199 return NAME;
200 }
201 }
202
203
204 public PTXHATKernelBuilder convert(Op op) {
205 switch (op) {
206 case JavaOp.FieldAccessOp.FieldLoadOp $ -> fieldLoad($);
207 case JavaOp.FieldAccessOp.FieldStoreOp $ -> fieldStore($);
208 case JavaOp.BinaryOp $ -> binaryOperation($);
209 case JavaOp.BinaryTestOp $ -> binaryTest($);
210 case JavaOp.ConvOp $ -> conv($);
211 case CoreOp.ConstantOp $ -> constant($);
212 case CoreOp.YieldOp $ -> javaYield($);
213 case JavaOp.InvokeOp $ -> methodCall($);
214 case CoreOp.VarOp $ when OpTk.paramVar($) != null -> varFuncDeclaration($);
215 case CoreOp.VarOp $ -> varDeclaration($);
216 case CoreOp.ReturnOp $ -> ret($);
217 case JavaOp.BreakOp $ -> javaBreak($);
218 default -> { // Why are these switch ops not just inlined above?
219 switch (op){
220 case CoreOp.BranchOp $ -> branch($);
221 case CoreOp.ConditionalBranchOp $ -> condBranch($);
222 case JavaOp.NegOp $ -> neg($);
223 case PTXPtrOp $ -> ptxPtr($);
224 default -> throw new IllegalStateException("op translation doesn't exist");
225 }
226 }
227 }
228 return this;
229 }
230
231 public void ptxPtr(PTXPtrOp op) {
232 PTXRegister source;
233 int offset = (int) op.boundSchema.groupLayout().byteOffset(MemoryLayout.PathElement.groupElement(op.fieldName));
234
235 if (op.fieldName.equals("array")) {
236 source = new PTXRegister(incrOrdinal(addressType()), addressType());
237 add().s64().space().regName(source).commaSpace().reg(op.operands().get(0)).commaSpace().reg(op.operands().get(1)).ptxNl();
238 } else {
239 source = getReg(op.operands().getFirst());
240 }
241
242 if (op.resultType.toString().equals("void")) {
243 st().global().dot().regType(op.operands().getLast()).space().address(source.name(), offset).commaSpace().reg(op.operands().getLast());
244 } else {
245 ld().global().resultType(op.resultType(), true).space().reg(op.result(), getResultType(op.resultType())).commaSpace().address(source.name(), offset);
246 }
247 }
248
249 public void fieldLoad(JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
250 if (OpTk.fieldName(fieldLoadOp).equals(Field.KC_X.toString())) {
251 if (!fieldToRegMap.containsKey(Field.KC_X)) {
252 loadKcX(fieldLoadOp.result());
253 } else {
254 mov().u32().space().resultReg(fieldLoadOp, PTXRegister.Type.U32).commaSpace().fieldReg(Field.KC_X);
255 }
256 } else if (OpTk.fieldName(fieldLoadOp).equals(Field.KC_MAXX.toString())) {
257 if (!fieldToRegMap.containsKey(Field.KC_X)) {
258 loadKcX(fieldLoadOp.operands().getFirst());
259 }
260 ld().global().u32().space().fieldReg(Field.KC_MAXX, fieldLoadOp.result()).commaSpace()
261 .address(fieldToRegMap.get(Field.KC_ADDR).name(), 4);
262 } else {
263 ld().global().u32().space().resultReg(fieldLoadOp, PTXRegister.Type.U64).commaSpace().reg(fieldLoadOp.operands().getFirst());
264 }
265 }
266
267 public void loadKcX(Value value) {
268 cvta().to().global().size().space().fieldReg(Field.KC_ADDR).commaSpace()
269 .reg(paramObjects.get(paramNames.indexOf(Field.KC_ADDR.toString())), addressType()).ptxNl();
270 mov().u32().space().fieldReg(Field.NTID_X).commaSpace().percent().regName(Field.NTID_X.toString()).ptxNl();
271 mov().u32().space().fieldReg(Field.CTAID_X).commaSpace().percent().regName(Field.CTAID_X.toString()).ptxNl();
272 mov().u32().space().fieldReg(Field.TID_X).commaSpace().percent().regName(Field.TID_X.toString()).ptxNl();
273 mad().lo().s32().space().fieldReg(Field.KC_X, value).commaSpace().fieldReg(Field.CTAID_X)
274 .commaSpace().fieldReg(Field.NTID_X).commaSpace().fieldReg(Field.TID_X).ptxNl();
275 st().global().u32().space().address(fieldToRegMap.get(Field.KC_ADDR).name()).commaSpace().fieldReg(Field.KC_X);
276 }
277
278 public void fieldStore(JavaOp.FieldAccessOp.FieldStoreOp op) {
279 // TODO: fix
280 st().global().u64().space().resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operands().getFirst());
281 }
282
283 PTXHATKernelBuilder symbol(Op op) {
284 return switch (op) {
285 case JavaOp.ModOp _ -> rem();
286 case JavaOp.MulOp _ -> mul();
287 case JavaOp.DivOp _ -> div();
288 case JavaOp.AddOp _ -> add();
289 case JavaOp.SubOp _ -> sub();
290 case JavaOp.LtOp _ -> lt();
291 case JavaOp.GtOp _ -> gt();
292 case JavaOp.LeOp _ -> le();
293 case JavaOp.GeOp _ -> ge();
294 case JavaOp.NeqOp _ -> ne();
295 case JavaOp.EqOp _ -> eq();
296 case JavaOp.OrOp _ -> or();
297 case JavaOp.AndOp _ -> and();
298 case JavaOp.XorOp _ -> xor();
299 case JavaOp.LshlOp _ -> shl();
300 case JavaOp.AshrOp _, JavaOp.LshrOp _ -> shr();
301 default -> throw new IllegalStateException("Unexpected value");
302 };
303 }
304
305 public void binaryOperation(JavaOp.BinaryOp op) {
306 symbol(op);
307 if (getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
308 && (op instanceof JavaOp.DivOp || op instanceof JavaOp.MulOp)) {
309 rn();
310 } else if (!getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
311 && op instanceof JavaOp.MulOp) {
312 lo();
313 }
314 resultType(op.resultType(), true).space();
315 resultReg(op, getResultType(op.resultType()));
316 commaSpace();
317 reg(op.operands().getFirst());
318 commaSpace();
319 reg(op.operands().get(1));
320 }
321
322 public void binaryTest(JavaOp.BinaryTestOp op) {
323 setp().dot();
324 symbol(op).resultType(op.operands().getFirst().type(), true).space();
325 resultReg(op, PTXRegister.Type.PREDICATE);
326 commaSpace();
327 reg(op.operands().getFirst());
328 commaSpace();
329 reg(op.operands().get(1));
330 }
331
332 public void conv(JavaOp.ConvOp op) {
333 if (op.resultType().equals(JavaType.LONG)) {
334 if (isIndex(op)) {
335 mul().wide().s32().space().resultReg(op, PTXRegister.Type.U64).commaSpace()
336 .reg(op.operands().getFirst()).commaSpace().intVal(4);
337 } else {
338 cvt().u64().dot().regType(op.operands().getFirst()).space()
339 .resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operands().getFirst()).ptxNl();
340 }
341 } else if (op.resultType().equals(JavaType.FLOAT)) {
342 cvt().rn().f32().dot().regType(op.operands().getFirst()).space()
343 .resultReg(op, PTXRegister.Type.F32).commaSpace().reg(op.operands().getFirst());
344 } else if (op.resultType().equals(JavaType.DOUBLE)) {
345 cvt();
346 if (op.operands().getFirst().type().equals(JavaType.INT)) {
347 rn();
348 }
349 f64().dot().regType(op.operands().getFirst()).space()
350 .resultReg(op, PTXRegister.Type.F64).commaSpace().reg(op.operands().getFirst());
351 } else if (op.resultType().equals(JavaType.INT)) {
352 cvt();
353 if (op.operands().getFirst().type().equals(JavaType.DOUBLE) || op.operands().getFirst().type().equals(JavaType.FLOAT)) {
354 rzi();
355 } else {
356 rn();
357 }
358 s32().dot().regType(op.operands().getFirst()).space()
359 .resultReg(op, PTXRegister.Type.S32).commaSpace().reg(op.operands().getFirst());
360 } else {
361 cvt().rn().s32().dot().regType(op.operands().getFirst()).space()
362 .resultReg(op, PTXRegister.Type.S32).commaSpace().reg(op.operands().getFirst());
363 }
364 }
365
366
367
368
369
370 public static class PTXRegister {
371 private String name;
372 private final Type type;
373
374 public enum Type {
375 S8 (8, BasicType.SIGNED, "s8", "%s"),
376 S16 (16, BasicType.SIGNED, "s16", "%s"),
377 S32 (32, BasicType.SIGNED, "s32", "%s"),
378 S64 (64, BasicType.SIGNED, "s64", "%sd"),
379 U8 (8, BasicType.UNSIGNED, "u8", "%r"),
380 U16 (16, BasicType.UNSIGNED, "u16", "%r"),
381 U32 (32, BasicType.UNSIGNED, "u32", "%r"),
382 U64 (64, BasicType.UNSIGNED, "u64", "%rd"),
383 F16 (16, BasicType.FLOATING, "f16", "%f"),
384 F16X2 (16, BasicType.FLOATING, "f16", "%f"),
385 F32 (32, BasicType.FLOATING, "f32", "%f"),
386 F64 (64, BasicType.FLOATING, "f64", "%fd"),
387 B8 (8, BasicType.BIT, "b8", "%b"),
388 B16 (16, BasicType.BIT, "b16", "%b"),
389 B32 (32, BasicType.BIT, "b32", "%b"),
390 B64 (64, BasicType.BIT, "b64", "%bd"),
391 B128 (128, BasicType.BIT, "b128", "%b"),
392 PREDICATE (1, BasicType.PREDICATE, "pred", "%p");
393
394 public enum BasicType {
395 SIGNED,
396 UNSIGNED,
397 FLOATING,
398 BIT,
399 PREDICATE
400 }
401
402 private final int size;
403 private final BasicType basicType;
404 private final String name;
405 private final String regPrefix;
406
407 Type(int size, BasicType type, String name, String regPrefix) {
408 this.size = size;
409 this.basicType = type;
410 this.name = name;
411 this.regPrefix = regPrefix;
412 }
413
414 public int getSize() {
415 return this.size;
416 }
417
418 public BasicType getBasicType() {
419 return this.basicType;
420 }
421
422 public String getName() {
423 return this.name;
424 }
425
426 public String getRegPrefix() {
427 return this.regPrefix;
428 }
429 }
430
431 public PTXRegister(int num, Type type) {
432 this.type = type;
433 this.name = type.regPrefix + num;
434 }
435
436 public String name() {
437 return this.name;
438 }
439
440 public void name(String name) {
441 this.name = name;
442 }
443
444 public Type type() {
445 return this.type;
446 }
447 }
448
449
450 private boolean isIndex(JavaOp.ConvOp op) {
451 for (Op.Result r : op.result().uses()) {
452 if (r.op() instanceof PTXPtrOp) return true;
453 }
454 return false;
455 }
456
457 public void constant(CoreOp.ConstantOp op) {
458 mov().resultType(op.resultType(), false).space().resultReg(op, getResultType(op.resultType())).commaSpace();
459 if (op.resultType().toString().equals("float")) {
460 if (op.value().toString().equals("0.0")) {
461 floatVal("00000000");
462 } else {
463 floatVal(Integer.toHexString(Float.floatToIntBits(Float.parseFloat(op.value().toString()))).toUpperCase());
464 }
465 } else {
466 constant(op.value().toString());
467 }
468 }
469
470 public void javaYield(CoreOp.YieldOp op) {
471 exit();
472 }
473
474 // S32Array and S32Array2D functions can be deleted after schema is done
/**
 * Emits PTX for a method invocation, selected by the textual invoke descriptor.
 *
 * <p>The {@code S32Array}/{@code S32Array2D} accessors are lowered to global
 * loads and stores at fixed offsets, {@code Math.sqrt(double)} to
 * {@code sqrt.rn.f64} (which writes its own semicolon), and any other method
 * to a braced call sequence: declare and store one {@code .param} per operand,
 * declare {@code retval}, {@code call.uni}, then load {@code retval} into the
 * result register.
 *
 * @param op the invocation to translate
 */
public void methodCall(JavaOp.InvokeOp op) {
    String descriptor = op.invokeDescriptor().toString();
    switch (descriptor) {
        // S32Array functions
        case "hat.buffer.S32Array::array(long)int" -> {
            PTXRegister element = ptxArrayElementAddress(op);
            ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace()
                    .address(element.name(), 4);
        }
        case "hat.buffer.S32Array::array(long, int)void" -> {
            PTXRegister element = ptxArrayElementAddress(op);
            st().global().u32().space().address(element.name(), 4).commaSpace()
                    .reg(op.operands().get(2));
        }
        case "hat.buffer.S32Array::length()int" -> {
            ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace()
                    .address(getReg(op.operands().getFirst()).name());
        }
        // S32Array2D functions
        case "hat.buffer.S32Array2D::array(long, int)void" -> {
            PTXRegister element = ptxArrayElementAddress(op);
            st().global().u32().space().address(element.name(), 8).commaSpace()
                    .reg(op.operands().get(2));
        }
        case "hat.buffer.S32Array2D::width()int" -> {
            ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace()
                    .address(getReg(op.operands().getFirst()).name());
        }
        case "hat.buffer.S32Array2D::height()int" -> {
            ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace()
                    .address(getReg(op.operands().getFirst()).name(), 4);
        }
        // Java Math function
        case "java.lang.Math::sqrt(double)double" -> {
            sqrt().rn().f64().space().resultReg(op, PTXRegister.Type.F64).commaSpace()
                    .reg(op.operands().getFirst()).semicolon();
        }
        default -> {
            List<Value> args = op.operands();
            obrace().nl().ptxIndent();
            for (int i = 0; i < args.size(); i++) {
                Value arg = args.get(i);
                // .param <type> param<i>;  st.param<type> [param<i>], <reg>;
                dot().param().space().paramType(arg.type()).space().param().intVal(i).ptxNl();
                st().dot().param().paramType(arg.type()).space()
                        .osbrace().param().intVal(i).csbrace().commaSpace().reg(arg).ptxNl();
            }
            dot().param().space().paramType(op.resultType()).space().retVal().ptxNl();
            call().uni().space().oparen().retVal().cparen().commaSpace()
                    .identifier(OpTk.methodOrThrow(MethodHandles.lookup(), op).getName()).commaSpace();
            // Mutable counter so the lambda can number the param<i> arguments.
            final int[] nextParam = {0};
            paren(_ ->
                    separated(args, (_) -> commaSpace(),
                            _ -> param().intVal(nextParam[0]++))).ptxNl();
            ld().dot().param().paramType(op.resultType()).space()
                    .resultReg(op, getResultType(op.resultType())).commaSpace()
                    .osbrace().retVal().csbrace();
            ptxNl().cbrace();
        }
    }
}

/** Allocates an address register and emits {@code add.s64 <reg>, <operand0>, <operand1>;}. */
private PTXRegister ptxArrayElementAddress(JavaOp.InvokeOp op) {
    PTXRegister element = new PTXRegister(incrOrdinal(addressType()), addressType());
    add().s64().space().regName(element).commaSpace()
            .reg(op.operands().getFirst()).commaSpace()
            .reg(op.operands().get(1)).ptxNl();
    return element;
}
525
526 public void varDeclaration(CoreOp.VarOp op) {
527 ld().dot().param().resultType(op.resultType(), false).space().resultReg(op, addressType()).commaSpace().reg(op.operands().getFirst());
528 }
529
530 public void varFuncDeclaration(CoreOp.VarOp op) {
531 ld().dot().param().resultType(op.resultType(), false).space().resultReg(op, addressType()).commaSpace().reg(op.operands().getFirst());
532 }
533
534 public void ret(CoreOp.ReturnOp op) {
535 if (!op.operands().isEmpty()) {
536 st().dot().param();
537 if (returnReg.type().equals(PTXRegister.Type.U32)) {
538 b32();
539 } else if (returnReg.type().equals(PTXRegister.Type.U64)) {
540 b64();
541 } else {
542 dot().regType(returnReg.type());
543 }
544 space().osbrace().regName(returnReg).csbrace().commaSpace().reg(op.operands().getFirst()).ptxNl();
545 }
546 ret();
547 }
548
549 public void javaBreak(JavaOp.BreakOp op) {
550 brkpt();
551 }
552
553 public void branch(CoreOp.BranchOp op) {
554 loadBlockParams(op.successors().getFirst());
555 bra().space().block(op.successors().getFirst().targetBlock());
556 }
557
558 public void condBranch(CoreOp.ConditionalBranchOp op) {
559 loadBlockParams(op.successors().getFirst());
560 loadBlockParams(op.successors().getLast());
561 at().reg(op.operands().getFirst()).space()
562 .bra().space().block(op.successors().getFirst().targetBlock()).ptxNl();
563 bra().space().block(op.successors().getLast().targetBlock());
564 }
565
566 public void neg(JavaOp.NegOp op) {
567 neg().resultType(op.resultType(), true).space().reg(op.result(), getResultType(op.resultType())).commaSpace().reg(op.operands().getFirst());
568 }
569
570 /*
571 * Helper functions for printing blocks and variables
572 */
573
574 public void loadBlockParams(Block.Reference block) {
575 for (int i = 0; i < block.arguments().size(); i++) {
576 Block.Parameter p = block.targetBlock().parameters().get(i);
577 mov().resultType(p.type(), false).space().reg(p, getResultType(p.type()))
578 .commaSpace().reg(block.arguments().get(i)).ptxNl();
579 }
580 }
581
582 public PTXHATKernelBuilder block(Block block) {
583 return typeName("block_").intVal(block.index());
584 }
585
586 public PTXHATKernelBuilder fieldReg(Field ref) {
587 if (fieldToRegMap.containsKey(ref)) {
588 return regName(fieldToRegMap.get(ref));
589 }
590 if (ref.isDestination()) {
591 fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(addressType()), addressType()));
592 } else {
593 fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
594 }
595 return regName(fieldToRegMap.get(ref));
596 }
597
598 public PTXHATKernelBuilder fieldReg(Field ref, Value value) {
599 if (fieldToRegMap.containsKey(ref)) {
600 return regName(fieldToRegMap.get(ref));
601 }
602 if (ref.isDestination()) {
603 fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(addressType()), addressType()));
604 return reg(value, addressType());
605 } else {
606 fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
607 return reg(value, PTXRegister.Type.U32);
608 }
609 }
610
611 public Field getFieldObj(String fieldName) {
612 for (Field f : fieldToRegMap.keySet()) {
613 if (f.toString().equals(fieldName)) return f;
614 }
615 throw new IllegalStateException("no existing field");
616 }
617
618 public PTXHATKernelBuilder resultReg(Op op, PTXRegister.Type type) {
619 return identifier(addReg(op.result(), type));
620 }
621
622 public PTXHATKernelBuilder reg(Value val, PTXRegister.Type type) {
623 if (varToRegMap.containsKey(val)) {
624 return regName(getReg(val));
625 } else {
626 return identifier(addReg(val, type));
627 }
628 }
629
630 public PTXHATKernelBuilder reg(Value val) {
631 return regName(getReg(val));
632 }
633
634 public PTXRegister getReg(Value val) {
635 if (varToRegMap.get(val) == null && val instanceof Op.Result result && result.op() instanceof JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
636 return fieldToRegMap.get(getFieldObj(fieldLoadOp.fieldDescriptor().name()));
637 }
638 if (varToRegMap.containsKey(val)) {
639 return varToRegMap.get(val);
640 } else {
641 throw new IllegalStateException("var to reg mapping doesn't exist");
642 }
643 }
644
645 public String addReg(Value val, PTXRegister.Type type) {
646 if (varToRegMap.containsKey(val)) {
647 return varToRegMap.get(val).name();
648 }
649 varToRegMap.put(val, new PTXRegister(incrOrdinal(type), type));
650 return varToRegMap.get(val).name();
651 }
652
653 public Integer getOrdinal(PTXRegister.Type type) {
654 ordinalMap.putIfAbsent(type, 1);
655 return ordinalMap.get(type);
656 }
657
658 public Integer incrOrdinal(PTXRegister.Type type) {
659 ordinalMap.putIfAbsent(type, 1);
660 int out = ordinalMap.get(type);
661 ordinalMap.put(type, out + 1);
662 return out;
663 }
664
665 public PTXHATKernelBuilder size() {
666 return (addressSize == 32) ? u32() : u64();
667 }
668
669 public PTXRegister.Type addressType() {
670 return (addressSize == 32) ? PTXRegister.Type.U32 : PTXRegister.Type.U64;
671 }
672
673 public PTXHATKernelBuilder resultType(TypeElement type, boolean signedResult) {
674 PTXRegister.Type res = getResultType(type);
675 if (signedResult && (res == PTXRegister.Type.U32)) return s32();
676 return dot().typeName(getResultType(type).getName());
677 }
678
679 public PTXHATKernelBuilder paramType(TypeElement type) {
680 PTXRegister.Type res = getResultType(type);
681 if (res == PTXRegister.Type.U32) return b32();
682 if (res == PTXRegister.Type.U64) return b64();
683 return dot().typeName(getResultType(type).getName());
684 }
685
686 public PTXRegister.Type getResultType(TypeElement type) {
687 switch (type.toString()) {
688 case "float" -> {
689 return PTXRegister.Type.F32;
690 }
691 case "double" -> {
692 return PTXRegister.Type.F64;
693 }
694 case "int" -> {
695 return PTXRegister.Type.U32;
696 }
697 case "boolean" -> {
698 return PTXRegister.Type.PREDICATE;
699 }
700 default -> {
701 return PTXRegister.Type.U64;
702 }
703 }
704 }
705
706 /*
707 * Basic CodeBuilder functions
708 */
709
710 // used for parameter list
711 // prints out items separated by a comma then new line
712 // Don't know why this was overriding with the same code grf.
713 /* @Override
714 public <I> PTXHATKernelBuilder commaNlSeparated(Iterable<I> iterable, Consumer<I> c) {
715 StreamCounter.of(iterable, (counter, t) -> {
716 if (counter.isNotFirst()) {
717 comma().nl();
718 }
719 c.accept(t);
720 });
721 return self();
722 }
723 */
724 public PTXHATKernelBuilder address(String address) {
725 return osbrace().constant(address).csbrace();
726 }
727
728 public PTXHATKernelBuilder address(String address, int offset) {
729 osbrace().constant(address);
730 if (offset == 0) {
731 return csbrace();
732 } else if (offset > 0) {
733 plus();
734 }
735 return intVal(offset).csbrace();
736 }
737
738 public PTXHATKernelBuilder ptxNl() {
739 return semicolon().nl().ptxIndent();
740 }
741
742
743 public PTXHATKernelBuilder param() {
744 return keyword("param");
745 }
746
747 public PTXHATKernelBuilder global() {
748 return dot().keyword("global");
749 }
750
751 public PTXHATKernelBuilder rn() {
752 return dot().keyword("rn");
753 }
754
755 public PTXHATKernelBuilder rm() {
756 return dot().keyword("rm");
757 }
758
759 public PTXHATKernelBuilder rzi() {
760 return dot().keyword("rzi");
761 }
762
763 public PTXHATKernelBuilder to() {
764 return dot().keyword("to");
765 }
766
767 public PTXHATKernelBuilder lo() {
768 return dot().keyword("lo");
769 }
770
771 public PTXHATKernelBuilder wide() {
772 return dot().keyword("wide");
773 }
774
775 public PTXHATKernelBuilder uni() {
776 return dot().keyword("uni");
777 }
778
779 public PTXHATKernelBuilder sat() {
780 return dot().keyword("sat");
781 }
782
783 public PTXHATKernelBuilder ftz() {
784 return dot().keyword("ftz");
785 }
786
787 public PTXHATKernelBuilder approx() {
788 return dot().keyword("approx");
789 }
790
791 public PTXHATKernelBuilder mov() {
792 return keyword("mov");
793 }
794
795 public PTXHATKernelBuilder setp() {
796 return keyword("setp");
797 }
798
799 public PTXHATKernelBuilder selp() {
800 return keyword("selp");
801 }
802
803 public PTXHATKernelBuilder ld() {
804 return keyword("ld");
805 }
806
807 public PTXHATKernelBuilder st() {
808 return keyword("st");
809 }
810
811 public PTXHATKernelBuilder cvt() {
812 return keyword("cvt");
813 }
814
815 public PTXHATKernelBuilder bra() {
816 return keyword("bra");
817 }
818
819 public PTXHATKernelBuilder ret() {
820 return keyword("ret");
821 }
822
823 public PTXHATKernelBuilder rem() {
824 return keyword("rem");
825 }
826
827 public PTXHATKernelBuilder mul() {
828 return keyword("mul");
829 }
830
831 public PTXHATKernelBuilder div() {
832 return keyword("div");
833 }
834
835 public PTXHATKernelBuilder rcp() {
836 return keyword("rcp");
837 }
838
839 public PTXHATKernelBuilder add() {
840 return keyword("add");
841 }
842
843 public PTXHATKernelBuilder sub() {
844 return keyword("sub");
845 }
846
847 public PTXHATKernelBuilder lt() {
848 return keyword("lt");
849 }
850
851 public PTXHATKernelBuilder gt() {
852 return keyword("gt");
853 }
854
855 public PTXHATKernelBuilder le() {
856 return keyword("le");
857 }
858
859 public PTXHATKernelBuilder ge() {
860 return keyword("ge");
861 }
862
863 public PTXHATKernelBuilder geu() {
864 return keyword("geu");
865 }
866
867 public PTXHATKernelBuilder ne() {
868 return keyword("ne");
869 }
870
871 public PTXHATKernelBuilder eq() {
872 return keyword("eq");
873 }
874
875 public PTXHATKernelBuilder xor() {
876 return keyword("xor");
877 }
878
879 public PTXHATKernelBuilder or() {
880 return keyword("or");
881 }
882
883 public PTXHATKernelBuilder and() {
884 return keyword("and");
885 }
886
887 public PTXHATKernelBuilder cvta() {
888 return keyword("cvta");
889 }
890
891 public PTXHATKernelBuilder mad() {
892 return keyword("mad");
893 }
894
895 public PTXHATKernelBuilder fma() {
896 return keyword("fma");
897 }
898
899 public PTXHATKernelBuilder sqrt() {
900 return keyword("sqrt");
901 }
902
903 public PTXHATKernelBuilder abs() {
904 return keyword("abs");
905 }
906
907 public PTXHATKernelBuilder ex2() {
908 return keyword("ex2");
909 }
910
911 public PTXHATKernelBuilder shl() {
912 return keyword("shl");
913 }
914
915 public PTXHATKernelBuilder shr() {
916 return keyword("shr");
917 }
918
919 public PTXHATKernelBuilder neg() {
920 return keyword("neg");
921 }
922
923 public PTXHATKernelBuilder call() {
924 return keyword("call");
925 }
926
927 public PTXHATKernelBuilder exit() {
928 return keyword("exit");
929 }
930
931 public PTXHATKernelBuilder brkpt() {
932 return keyword("brkpt");
933 }
934
935 public PTXHATKernelBuilder ptxIndent() {
936 return space().space().space().space();
937 }
938
939 public PTXHATKernelBuilder u32() {
940 return dot().typeName(PTXRegister.Type.U32.getName());
941 }
942
943 public PTXHATKernelBuilder s32() {
944 return dot().typeName(PTXRegister.Type.S32.getName());
945 }
946
947 public PTXHATKernelBuilder f32() {
948 return dot().typeName(PTXRegister.Type.F32.getName());
949 }
950
951 public PTXHATKernelBuilder b32() {
952 return dot().typeName(PTXRegister.Type.B32.getName());
953 }
954
955 public PTXHATKernelBuilder u64() {
956 return dot().typeName(PTXRegister.Type.U64.getName());
957 }
958
959 public PTXHATKernelBuilder s64() {
960 return dot().typeName(PTXRegister.Type.S64.getName());
961 }
962
963 public PTXHATKernelBuilder f64() {
964 return dot().typeName(PTXRegister.Type.F64.getName());
965 }
966
967 public PTXHATKernelBuilder b64() {
968 return dot().typeName(PTXRegister.Type.B64.getName());
969 }
970
971 public PTXHATKernelBuilder version() {
972 return dot().keyword("version");
973 }
974
975 public PTXHATKernelBuilder target() {
976 return dot().keyword("target");
977 }
978
979 public PTXHATKernelBuilder addressSize() {
980 return dot().keyword("address_size");
981 }
982
983 public PTXHATKernelBuilder major(int major) {
984 return intVal(major);
985 }
986
987 public PTXHATKernelBuilder minor(int minor) {
988 return intVal(minor);
989 }
990
991 public PTXHATKernelBuilder target(String target) {
992 return keyword(target);
993 }
994
995 public PTXHATKernelBuilder size(int addressSize) {
996 return intVal(addressSize);
997 }
998
999 public PTXHATKernelBuilder funcName(String funcName) {
1000 return identifier(funcName);
1001 }
1002
1003 public PTXHATKernelBuilder visible() {
1004 return dot().keyword("visible");
1005 }
1006
1007 public PTXHATKernelBuilder entry() {
1008 return dot().keyword("entry");
1009 }
1010
1011 public PTXHATKernelBuilder func() {
1012 return dot().keyword("func");
1013 }
1014
1015 public PTXHATKernelBuilder oabrace() {
1016 return symbol("<");
1017 }
1018
1019 public PTXHATKernelBuilder cabrace() {
1020 return symbol(">");
1021 }
1022
1023 public PTXHATKernelBuilder regName(PTXRegister reg) {
1024 return identifier(reg.name());
1025 }
1026
1027 public PTXHATKernelBuilder regName(String regName) {
1028 return identifier(regName);
1029 }
1030
1031 public PTXHATKernelBuilder regType(Value val) {
1032 return keyword(getReg(val).type().getName());
1033 }
1034
1035 public PTXHATKernelBuilder regType(PTXRegister.Type t) {
1036 return keyword(t.getName());
1037 }
1038
1039 public PTXHATKernelBuilder regTypePrefix(PTXRegister.Type t) {
1040 return keyword(t.getRegPrefix());
1041 }
1042
1043 public PTXHATKernelBuilder reg() {
1044 return dot().keyword("reg");
1045 }
1046
1047 public PTXHATKernelBuilder retVal() {
1048 return keyword("retval");
1049 }
1050
1051 public PTXHATKernelBuilder intVal(int i) {
1052 return constant(String.valueOf(i));
1053 }
1054
1055 public PTXHATKernelBuilder floatVal(String s) {
1056 return constant("0f").constant(s);
1057 }
1058
1059 public PTXHATKernelBuilder doubleVal(String s) {
1060 return constant("0d").constant(s);
1061 }
1062 }