1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.backend.ffi;
26
27 import optkl.FuncOpParams;
28 import optkl.ParamVar;
29 import optkl.codebuilders.CodeBuilder;
30
31 import jdk.incubator.code.*;
32 import jdk.incubator.code.dialect.core.CoreOp;
33 import jdk.incubator.code.dialect.java.JavaOp;
34 import jdk.incubator.code.dialect.java.JavaType;
35
36 import java.lang.foreign.MemoryLayout;
37 import java.lang.invoke.MethodHandles;
38 import java.util.ArrayList;
39 import java.util.HashMap;
40 import java.util.List;
41 import java.util.Map;
42 import java.util.stream.Stream;
43
44 import static optkl.OpHelper.FieldAccess.fieldAccess;
45 import static optkl.OpHelper.Invoke;
46
47 import static optkl.OpHelper.Invoke.invoke;
48
49
50 public class PTXHATKernelBuilder extends CodeBuilder<PTXHATKernelBuilder> {
51
// Register assigned to each IR value (block parameter or op result).
Map<Value, PTXRegister> varToRegMap;
// Names of the function's .param entries in declaration order (filled by parameters()).
List<String> paramNames;
// Entry-block parameters, in the order blockBody() emitted their ld.param loads.
List<Block.Parameter> paramObjects;
// Registers allocated for the special/KernelContext fields enumerated in Field.
Map<Field, PTXRegister> fieldToRegMap;

// Per register type, the next ordinal to hand out; also the count used by ptxRegisterDecl().
HashMap<PTXRegister.Type, Integer> ordinalMap;

// Return-value register (renamed "%retReg"); only assigned by functionHeader() for non-void functions.
PTXRegister returnReg;
// Address width in bits: 32 selects u32 address registers, any other value u64.
private int addressSize;
61
/**
 * Special registers and KernelContext fields that the builder maps onto PTX registers.
 * Each constant carries the name used in generated code / field lookups and whether its
 * register holds an address (allocated with the address register type) rather than a u32.
 */
public enum Field {
    NTID_X("ntid.x", false),
    CTAID_X("ctaid.x", false),
    TID_X("tid.x", false),
    KC_X("x", false),
    KC_ADDR("kc", true),
    KC_MAXX("maxX", false);

    private final String text;
    private final boolean holdsAddress;

    Field(String text, boolean holdsAddress) {
        this.text = text;
        this.holdsAddress = holdsAddress;
    }

    /** Returns the field's name as used in generated code, e.g. {@code "ntid.x"}. */
    @Override
    public String toString() {
        return text;
    }

    /** Returns {@code true} when the register for this field holds an address. */
    public boolean isDestination() {
        return holdsAddress;
    }
}
82
/**
 * Creates a builder with empty register/parameter bookkeeping.
 *
 * @param addressSize address width in bits; 32 selects u32 address registers, any other value u64
 */
public PTXHATKernelBuilder(int addressSize) {
    this.addressSize = addressSize;
    this.ordinalMap = new HashMap<>();
    this.paramObjects = new ArrayList<>();
    this.fieldToRegMap = new HashMap<>();
    this.paramNames = new ArrayList<>();
    this.varToRegMap = new HashMap<>();
}

/** Creates a builder that uses 32-bit addresses. */
public PTXHATKernelBuilder() {
    this(32);
}
95
96 public void ptxHeader(int major, int minor, String target, int addressSize) {
97 this.addressSize = addressSize;
98 version().sp().major(major).dot().minor(minor).nl();
99 target().sp().target(target).nl();
100 addressSize().sp().size(addressSize);
101 }
102
103 public void functionHeader(String funcName, boolean entry, CodeType yieldType) {
104 if (entry) {
105 visible().sp().entry().sp();
106 } else {
107 func().sp();
108 }
109 if (!yieldType.toString().equals("void")) {
110 returnReg = new PTXRegister(getOrdinal(getResultType(yieldType)), getResultType(yieldType));
111 returnReg.name("%retReg");
112 oparen().dot().param().sp().paramType(yieldType);
113 sp().regName(returnReg).cparen().sp();
114 }
115 funcName(funcName);
116 }
117
/**
 * Emits the parenthesised {@code .param} list, one entry per line, and records each
 * parameter name in {@code paramNames} (in order) for later {@code ld.param} loads.
 */
public PTXHATKernelBuilder parameters(List<FuncOpParams.Info> infoList) {
    paren(_ -> nl().commaNlSeparated(infoList, this::paramDeclaration).nl()).nl();
    return this;
}

/** Emits one indented {@code .param <type> <name>} entry and remembers its name. */
private void paramDeclaration(FuncOpParams.Info info) {
    String paramName = info.varOp.varName();
    ptxIndent().dot().param().sp().paramType(info.javaType);
    sp().regName(paramName);
    paramNames.add(paramName);
}
131
132 public void blockBody(MethodHandles.Lookup lookup,Block block, Stream<Op> ops) {
133 if (block.index() == 0) {
134 for (Block.Parameter p : block.parameters()) {
135 ptxIndent().ld().dot().param();
136 resultType(p.type(), false).ptxIndent().sp();
137 reg(p, getResultType(p.type())).csp().osbrace().regName(paramNames.get(p.index())).csbrace().semicolon().nl();
138 paramObjects.add(p);
139 }
140 }
141 nl();
142 block(block);
143 colon().nl();
144 ops.forEach(op -> {
145 if (invoke(lookup,op) instanceof Invoke invoke && !invoke.isMappableIface()) {
146 ptxIndent().convert(lookup,op).nl();
147 } else {
148 ptxIndent().convert(lookup,op).semicolon().nl();
149 }
150 });
151 }
152
153 public void ptxRegisterDecl() {
154 for (PTXRegister.Type t : ordinalMap.keySet()) {
155 ptxIndent().reg().sp();
156 if (t.equals(PTXRegister.Type.U32)) {
157 b32();
158 } else if (t.equals(PTXRegister.Type.U64)) {
159 b64();
160 } else {
161 dot().regType(t);
162 }
163 ptxIndent().regTypePrefix(t).oabrace().intVal(ordinalMap.get(t)).cabrace().semicolon().nl();
164 }
165 nl();
166 }
167
168 public void functionPrologue() {
169 obrace().nl();
170 }
171
172 public void functionEpilogue() {
173 cbrace();
174 }
175
176
177 public PTXHATKernelBuilder convert(MethodHandles.Lookup lookup,Op op) {
178 switch (op) {
179 case JavaOp.FieldAccessOp.FieldLoadOp $ -> fieldLoad(lookup,$);
180 case JavaOp.FieldAccessOp.FieldStoreOp $ -> fieldStore($);
181 case JavaOp.BinaryOp $ -> binaryOperation($);
182 case JavaOp.CompareOp $ -> compareOperation($);
183 case JavaOp.ConvOp $ -> conv($);
184 case CoreOp.ConstantOp $ -> constant($);
185 case CoreOp.YieldOp $ -> javaYield($);
186 case JavaOp.InvokeOp $ -> methodCall(invoke(lookup,$));
187 case CoreOp.VarOp $ when ParamVar.of($) != null -> varFuncDeclaration($);
188 case CoreOp.VarOp $ -> varDeclaration($);
189 case CoreOp.ReturnOp $ -> ret($);
190 case JavaOp.BreakOp $ -> javaBreak($);
191 default -> { // Why are these switch ops not just inlined above?
192 switch (op){
193 case CoreOp.BranchOp $ -> branch($);
194 case CoreOp.ConditionalBranchOp $ -> condBranch($);
195 case JavaOp.NegOp $ -> neg($);
196 case PTXPtrOp $ -> ptxPtr($);
197 default -> throw new IllegalStateException("op translation doesn't exist");
198 }
199 }
200 }
201 return this;
202 }
203
204 public void ptxPtr(PTXPtrOp op) {
205 PTXRegister source;
206 int offset = (int) op.boundSchema.groupLayout().byteOffset(MemoryLayout.PathElement.groupElement(op.fieldName));
207
208 if (op.fieldName.equals("array")) {
209 source = new PTXRegister(incrOrdinal(addressType()), addressType());
210 addKeyword().s64().sp().regName(source).csp().reg(op.operands().get(0)).csp().reg(op.operands().get(1)).ptxNl();
211 } else {
212 source = getReg(op.operands().getFirst());
213 }
214
215 if (op.resultType.toString().equals("void")) {
216 st().global().dot().regType(op.operands().getLast()).sp().address(source.name(), offset).csp().reg(op.operands().getLast());
217 } else {
218 ld().global().resultType(op.resultType(), true).sp().reg(op.result(), getResultType(op.resultType())).csp().address(source.name(), offset);
219 }
220 }
221
222 public void fieldLoad(MethodHandles.Lookup lookup,JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
223
224 var fieldAccess = fieldAccess(lookup,fieldLoadOp);
225 if (fieldAccess.named(Field.KC_X.toString())) {
226 if (!fieldToRegMap.containsKey(Field.KC_X)) {
227 loadKcX(fieldLoadOp.result());
228 } else {
229 mov().u32().sp().resultReg(fieldLoadOp, PTXRegister.Type.U32).csp().fieldReg(Field.KC_X);
230 }
231 } else if (fieldAccess.named(Field.KC_MAXX.toString())) {
232 if (!fieldToRegMap.containsKey(Field.KC_X)) {
233 loadKcX(fieldLoadOp.operands().getFirst());
234 }
235 ld().global().u32().sp().fieldReg(Field.KC_MAXX, fieldLoadOp.result()).csp()
236 .address(fieldToRegMap.get(Field.KC_ADDR).name(), 4);
237 } else {
238 ld().global().u32().sp().resultReg(fieldLoadOp, PTXRegister.Type.U64).csp().reg(fieldLoadOp.operands().getFirst());
239 }
240 }
241
/**
 * Emits the KernelContext x computation:
 * <pre>
 *   cvta.to.global.&lt;u32|u64&gt; kcAddr, &lt;register of the "kc" parameter&gt;;
 *   mov.u32 ntid, %ntid.x;
 *   mov.u32 ctaid, %ctaid.x;
 *   mov.u32 tid, %tid.x;
 *   mad.lo.s32 x, ctaid, ntid, tid;
 *   st.global.u32 [kcAddr], x
 * </pre>
 * The x register is created through fieldReg(Field.KC_X, value), so normally Field.KC_X and
 * {@code value} share one register name. The final st is left without a terminator; the
 * caller must append it.
 *
 * NOTE(review): requires a parameter literally named "kc". Otherwise paramNames.indexOf
 * returns -1 and paramObjects.get(-1) throws.
 */
public void loadKcX(Value value) {
    cvta().to().global().size().sp().fieldReg(Field.KC_ADDR).csp()
            .reg(paramObjects.get(paramNames.indexOf(Field.KC_ADDR.toString())), addressType()).ptxNl();
    mov().u32().sp().fieldReg(Field.NTID_X).csp().percent().regName(Field.NTID_X.toString()).ptxNl();
    mov().u32().sp().fieldReg(Field.CTAID_X).csp().percent().regName(Field.CTAID_X.toString()).ptxNl();
    mov().u32().sp().fieldReg(Field.TID_X).csp().percent().regName(Field.TID_X.toString()).ptxNl();
    // x = ctaid.x * ntid.x + tid.x
    mad().lo().s32().sp().fieldReg(Field.KC_X, value).csp().fieldReg(Field.CTAID_X)
            .csp().fieldReg(Field.NTID_X).csp().fieldReg(Field.TID_X).ptxNl();
    st().global().u32().sp().address(fieldToRegMap.get(Field.KC_ADDR).name()).csp().fieldReg(Field.KC_X);
}
252
/**
 * Emits {@code st.global.u64 <reg>, <first operand>} for a field store.
 *
 * NOTE(review): marked TODO upstream. The "destination" is a register newly allocated for the
 * store op's own result, not a bracketed field address, so this does not look like a
 * well-formed PTX store. The field's address needs to be computed instead; confirm and fix.
 */
public void fieldStore(JavaOp.FieldAccessOp.FieldStoreOp op) {
    // TODO: fix
    st().global().u64().sp().resultReg(op, PTXRegister.Type.U64).csp().reg(op.operands().getFirst());
}
257 // this might be duplication of CodeBuilder symbol....
258 @Override public
259 PTXHATKernelBuilder symbol(Op op) {
260 return switch (op) {
261 case JavaOp.ModOp _ -> remKw();
262 case JavaOp.MulOp _ -> mulKeword();
263 case JavaOp.DivOp _ -> divKeyword();
264 case JavaOp.AddOp _ -> addKeyword();
265 case JavaOp.SubOp _ -> subKeyword();
266 case JavaOp.LtOp _ -> ltKeyword();
267 case JavaOp.GtOp _ -> gtKeyword();
268 case JavaOp.LeOp _ -> le();
269 case JavaOp.GeOp _ -> ge();
270 case JavaOp.NeqOp _ -> neKeyword();
271 case JavaOp.EqOp _ -> eqKeyword();
272 case JavaOp.OrOp _ -> or();
273 case JavaOp.AndOp _ -> and();
274 case JavaOp.XorOp _ -> xor();
275 case JavaOp.LshlOp _ -> shl();
276 case JavaOp.AshrOp _, JavaOp.LshrOp _ -> shr();
277 default -> throw new IllegalStateException("Unexpected value");
278 };
279 }
280
281 public void binaryOperation(JavaOp.BinaryOp op) {
282 symbol(op);
283 if (getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
284 && (op instanceof JavaOp.DivOp || op instanceof JavaOp.MulOp)) {
285 rn();
286 } else if (!getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
287 && op instanceof JavaOp.MulOp) {
288 lo();
289 }
290 resultType(op.resultType(), true).sp();
291 resultReg(op, getResultType(op.resultType()));
292 csp();
293 reg(op.operands().getFirst());
294 csp();
295 reg(op.operands().get(1));
296 }
297
298 public void compareOperation(JavaOp.CompareOp op) {
299 setp().dot();
300 symbol(op).resultType(op.operands().getFirst().type(), true).sp();
301 resultReg(op, PTXRegister.Type.PREDICATE);
302 csp();
303 reg(op.operands().getFirst());
304 csp();
305 reg(op.operands().get(1));
306 }
307
308 public void conv(JavaOp.ConvOp op) {
309 if (op.resultType().equals(JavaType.LONG)) {
310 if (isIndex(op)) {
311 mulKeword().wide().s32().sp().resultReg(op, PTXRegister.Type.U64).csp()
312 .reg(op.operands().getFirst()).csp().intVal(4);
313 } else {
314 cvt().u64().dot().regType(op.operands().getFirst()).sp()
315 .resultReg(op, PTXRegister.Type.U64).csp().reg(op.operands().getFirst()).ptxNl();
316 }
317 } else if (op.resultType().equals(JavaType.FLOAT)) {
318 cvt().rn().f32().dot().regType(op.operands().getFirst()).sp()
319 .resultReg(op, PTXRegister.Type.F32).csp().reg(op.operands().getFirst());
320 } else if (op.resultType().equals(JavaType.DOUBLE)) {
321 cvt();
322 if (op.operands().getFirst().type().equals(JavaType.INT)) {
323 rn();
324 }
325 f64().dot().regType(op.operands().getFirst()).sp()
326 .resultReg(op, PTXRegister.Type.F64).csp().reg(op.operands().getFirst());
327 } else if (op.resultType().equals(JavaType.INT)) {
328 cvt();
329 if (op.operands().getFirst().type().equals(JavaType.DOUBLE) || op.operands().getFirst().type().equals(JavaType.FLOAT)) {
330 rzi();
331 } else {
332 rn();
333 }
334 s32().dot().regType(op.operands().getFirst()).sp()
335 .resultReg(op, PTXRegister.Type.S32).csp().reg(op.operands().getFirst());
336 } else {
337 cvt().rn().s32().dot().regType(op.operands().getFirst()).sp()
338 .resultReg(op, PTXRegister.Type.S32).csp().reg(op.operands().getFirst());
339 }
340 }
341
342
343
344
345
346 public static class PTXRegister {
347 private String name;
348 private final Type type;
349
350 public enum Type {
351 S8 (8, BasicType.SIGNED, "s8", "%s"),
352 S16 (16, BasicType.SIGNED, "s16", "%s"),
353 S32 (32, BasicType.SIGNED, "s32", "%s"),
354 S64 (64, BasicType.SIGNED, "s64", "%sd"),
355 U8 (8, BasicType.UNSIGNED, "u8", "%r"),
356 U16 (16, BasicType.UNSIGNED, "u16", "%r"),
357 U32 (32, BasicType.UNSIGNED, "u32", "%r"),
358 U64 (64, BasicType.UNSIGNED, "u64", "%rd"),
359 F16 (16, BasicType.FLOATING, "f16", "%f"),
360 F16X2 (16, BasicType.FLOATING, "f16", "%f"),
361 F32 (32, BasicType.FLOATING, "f32", "%f"),
362 F64 (64, BasicType.FLOATING, "f64", "%fd"),
363 B8 (8, BasicType.BIT, "b8", "%b"),
364 B16 (16, BasicType.BIT, "b16", "%b"),
365 B32 (32, BasicType.BIT, "b32", "%b"),
366 B64 (64, BasicType.BIT, "b64", "%bd"),
367 B128 (128, BasicType.BIT, "b128", "%b"),
368 PREDICATE (1, BasicType.PREDICATE, "pred", "%p");
369
370 public enum BasicType {
371 SIGNED,
372 UNSIGNED,
373 FLOATING,
374 BIT,
375 PREDICATE
376 }
377
378 private final int size;
379 private final BasicType basicType;
380 private final String name;
381 private final String regPrefix;
382
383 Type(int size, BasicType type, String name, String regPrefix) {
384 this.size = size;
385 this.basicType = type;
386 this.name = name;
387 this.regPrefix = regPrefix;
388 }
389
390 public int getSize() {
391 return this.size;
392 }
393
394 public BasicType getBasicType() {
395 return this.basicType;
396 }
397
398 public String getName() {
399 return this.name;
400 }
401
402 public String getRegPrefix() {
403 return this.regPrefix;
404 }
405 }
406
407 public PTXRegister(int num, Type type) {
408 this.type = type;
409 this.name = type.regPrefix + num;
410 }
411
412 public String name() {
413 return this.name;
414 }
415
416 public void name(String name) {
417 this.name = name;
418 }
419
420 public Type type() {
421 return this.type;
422 }
423 }
424
425
426 private boolean isIndex(JavaOp.ConvOp op) {
427 for (Op.Result r : op.result().uses()) {
428 if (r.op() instanceof PTXPtrOp) return true;
429 }
430 return false;
431 }
432
433 public void constant(CoreOp.ConstantOp op) {
434 mov().resultType(op.resultType(), false).sp().resultReg(op, getResultType(op.resultType())).csp();
435 if (op.resultType().toString().equals("float")) {
436 if (op.value().toString().equals("0.0")) {
437 floatVal("00000000");
438 } else {
439 floatVal(Integer.toHexString(Float.floatToIntBits(Float.parseFloat(op.value().toString()))).toUpperCase());
440 }
441 } else {
442 constant(op.value().toString());
443 }
444 }
445
446 public void javaYield(CoreOp.YieldOp op) {
447 exit();
448 }
449
/**
 * Emits PTX for a method invocation, selected by the invoke reference's string form.
 * <ul>
 *   <li>S32Array / S32Array2D accessors are lowered inline to global loads/stores. Element
 *       access first forms {@code temp = operand0 + operand1} (add.s64) and then addresses
 *       {@code [temp+4]} (S32Array) or {@code [temp+8]} (S32Array2D). {@code length}/{@code width}
 *       are read at offset 0 of the receiver and {@code height} at offset 4.</li>
 *   <li>{@code Math.sqrt(double)} becomes {@code sqrt.rn.f64} and is terminated here.</li>
 *   <li>Anything else becomes a PTX call inside its own {@code { }} scope. Each argument is
 *       declared and stored as {@code .param paramN}, {@code call.uni (retval), name, (param0, ...)}
 *       is issued, and the result is loaded from {@code retval}.</li>
 * </ul>
 * The S32Array and S32Array2D cases can be deleted after the schema work is done.
 *
 * NOTE(review): the default branch always declares and loads {@code retval}, even for void
 * callees. Confirm.
 */
public void methodCall(Invoke invoke) {
    switch (invoke.op().invokeReference().toString()) {
        // S32Array functions
        case "hat.buffer.S32Array::array(long)int" -> {
            PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
            addKeyword().s64().sp().regName(temp).csp().reg(invoke.op().operands().getFirst()).csp().reg(invoke.op().operands().get(1)).ptxNl();
            ld().global().u32().sp().resultReg(invoke.op(), PTXRegister.Type.U32).csp().address(temp.name(), 4);
        }
        case "hat.buffer.S32Array::array(long, int)void" -> {
            PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
            addKeyword().s64().sp().regName(temp).csp().reg(invoke.op().operands().getFirst()).csp().reg(invoke.op().operands().get(1)).ptxNl();
            st().global().u32().sp().address(temp.name(), 4).csp().reg(invoke.op().operands().get(2));
        }
        case "hat.buffer.S32Array::length()int" -> {
            ld().global().u32().sp().resultReg(invoke.op(), PTXRegister.Type.U32).csp().address(getReg(invoke.op().operands().getFirst()).name());
        }
        // S32Array2D functions
        case "hat.buffer.S32Array2D::array(long, int)void" -> {
            PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
            addKeyword().s64().sp().regName(temp).csp().reg(invoke.op().operands().getFirst()).csp().reg(invoke.op().operands().get(1)).ptxNl();
            st().global().u32().sp().address(temp.name(), 8).csp().reg(invoke.op().operands().get(2));
        }
        case "hat.buffer.S32Array2D::width()int" -> {
            ld().global().u32().sp().resultReg(invoke.op(), PTXRegister.Type.U32).csp().address(getReg(invoke.op().operands().getFirst()).name());
        }
        case "hat.buffer.S32Array2D::height()int" -> {
            ld().global().u32().sp().resultReg(invoke.op(), PTXRegister.Type.U32).csp().address(getReg(invoke.op().operands().getFirst()).name(), 4);
        }
        // Java Math function
        case "java.lang.Math::sqrt(double)double" -> {
            sqrt().rn().f64().sp().resultReg(invoke.op(), PTXRegister.Type.F64).csp().reg(invoke.op().operands().getFirst()).semicolon();
        }
        default -> {
            // Generic call: pass each argument through a .param slot named param<i>.
            obrace().nl().ptxIndent();
            for (int i = 0; i < invoke.op().operands().size(); i++) {
                dot().param().sp().paramType(invoke.op().operands().get(i).type()).sp().param().intVal(i).ptxNl();
                st().dot().param().paramType(invoke.op().operands().get(i).type()).sp().osbrace().param().intVal(i).csbrace().csp().reg(invoke.op().operands().get(i)).ptxNl();
            }
            dot().param().sp().paramType(invoke.op().resultType()).sp().retVal().ptxNl();
            call().uni().sp().oparen().retVal().cparen().csp().id(invoke.name()).csp();
            // Argument list param0, param1, ... (one per operand).
            final int[] counter = {0};
            paren(_ ->
                    commaSpaceSeparated(
                            invoke.op().operands(),
                            _ -> param().intVal(counter[0]++)
                    )
            ).ptxNl();
            ld().dot().param().paramType(invoke.op().resultType()).sp().resultReg(invoke.op(), getResultType(invoke.op().resultType())).csp().osbrace().retVal().csbrace();
            ptxNl().cbrace();
        }
    }
}
504
505 public void varDeclaration(CoreOp.VarOp op) {
506 ld().dot().param().resultType(op.resultType(), false).sp().resultReg(op, addressType()).csp().reg(op.operands().getFirst());
507 }
508
509 public void varFuncDeclaration(CoreOp.VarOp op) {
510 ld().dot().param().resultType(op.resultType(), false).sp().resultReg(op, addressType()).csp().reg(op.operands().getFirst());
511 }
512
513 public void ret(CoreOp.ReturnOp op) {
514 if (!op.operands().isEmpty()) {
515 st().dot().param();
516 if (returnReg.type().equals(PTXRegister.Type.U32)) {
517 b32();
518 } else if (returnReg.type().equals(PTXRegister.Type.U64)) {
519 b64();
520 } else {
521 dot().regType(returnReg.type());
522 }
523 sp().osbrace().regName(returnReg).csbrace().csp().reg(op.operands().getFirst()).ptxNl();
524 }
525 ret();
526 }
527
528 public void javaBreak(JavaOp.BreakOp op) {
529 brkpt();
530 }
531
532 public void branch(CoreOp.BranchOp op) {
533 loadBlockParams(op.successors().getFirst());
534 bra().sp().block(op.successors().getFirst().targetBlock());
535 }
536
537 public void condBranch(CoreOp.ConditionalBranchOp op) {
538 loadBlockParams(op.successors().getFirst());
539 loadBlockParams(op.successors().getLast());
540 at().reg(op.operands().getFirst()).sp()
541 .bra().sp().block(op.successors().getFirst().targetBlock()).ptxNl();
542 bra().sp().block(op.successors().getLast().targetBlock());
543 }
544
545 public void neg(JavaOp.NegOp op) {
546 neg().resultType(op.resultType(), true).sp().reg(op.result(), getResultType(op.resultType())).csp().reg(op.operands().getFirst());
547 }
548
549 /*
550 * Helper functions for printing blocks and variables
551 */
552
553 public void loadBlockParams(Block.Reference block) {
554 for (int i = 0; i < block.arguments().size(); i++) {
555 Block.Parameter p = block.targetBlock().parameters().get(i);
556 mov().resultType(p.type(), false).sp().reg(p, getResultType(p.type()))
557 .csp().reg(block.arguments().get(i)).ptxNl();
558 }
559 }
560
561 public PTXHATKernelBuilder block(Block block) {
562 return type("block_").intVal(block.index());
563 }
564
565 public PTXHATKernelBuilder fieldReg(Field ref) {
566 if (fieldToRegMap.containsKey(ref)) {
567 return regName(fieldToRegMap.get(ref));
568 }
569 if (ref.isDestination()) {
570 fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(addressType()), addressType()));
571 } else {
572 fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
573 }
574 return regName(fieldToRegMap.get(ref));
575 }
576
/**
 * Emits the register for {@code ref} while also binding {@code value} to it.
 *
 * If {@code ref} already has a register, that register is emitted and {@code value} is not
 * bound. Otherwise the register for {@code ref} is created with the CURRENT ordinal
 * (getOrdinal, no increment). {@code value} is then registered through reg(...), whose
 * addReg/incrOrdinal takes that same ordinal, so {@code ref} and {@code value} end up with
 * the same register name.
 *
 * NOTE(review): the aliasing holds only if {@code value} has no register yet. Otherwise
 * reg() emits value's existing register, and ref's ordinal will be handed out again by the
 * next allocation of that type.
 */
public PTXHATKernelBuilder fieldReg(Field ref, Value value) {
    if (fieldToRegMap.containsKey(ref)) {
        return regName(fieldToRegMap.get(ref));
    }
    if (ref.isDestination()) {
        fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(addressType()), addressType()));
        return reg(value, addressType());
    } else {
        fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
        return reg(value, PTXRegister.Type.U32);
    }
}
589
590 public Field getFieldObj(String fieldName) {
591 for (Field f : fieldToRegMap.keySet()) {
592 if (f.toString().equals(fieldName)) return f;
593 }
594 throw new IllegalStateException("no existing field");
595 }
596
597 public PTXHATKernelBuilder resultReg(Op op, PTXRegister.Type type) {
598 return id(addReg(op.result(), type));
599 }
600
601 public PTXHATKernelBuilder reg(Value val, PTXRegister.Type type) {
602 if (varToRegMap.containsKey(val)) {
603 return regName(getReg(val));
604 } else {
605 return id(addReg(val, type));
606 }
607 }
608
609 public PTXHATKernelBuilder reg(Value val) {
610 return regName(getReg(val));
611 }
612
613 public PTXRegister getReg(Value val) {
614 if (varToRegMap.get(val) == null && val instanceof Op.Result result && result.op() instanceof JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
615 return fieldToRegMap.get(getFieldObj(fieldLoadOp.fieldReference().name()));
616 }
617 if (varToRegMap.containsKey(val)) {
618 return varToRegMap.get(val);
619 } else {
620 throw new IllegalStateException("var to reg mapping doesn't exist");
621 }
622 }
623
624 public String addReg(Value val, PTXRegister.Type type) {
625 if (varToRegMap.containsKey(val)) {
626 return varToRegMap.get(val).name();
627 }
628 varToRegMap.put(val, new PTXRegister(incrOrdinal(type), type));
629 return varToRegMap.get(val).name();
630 }
631
632 public Integer getOrdinal(PTXRegister.Type type) {
633 ordinalMap.putIfAbsent(type, 1);
634 return ordinalMap.get(type);
635 }
636
637 public Integer incrOrdinal(PTXRegister.Type type) {
638 ordinalMap.putIfAbsent(type, 1);
639 int out = ordinalMap.get(type);
640 ordinalMap.put(type, out + 1);
641 return out;
642 }
643
644 public PTXHATKernelBuilder size() {
645 return (addressSize == 32) ? u32() : u64();
646 }
647
648 public PTXRegister.Type addressType() {
649 return (addressSize == 32) ? PTXRegister.Type.U32 : PTXRegister.Type.U64;
650 }
651
652 public PTXHATKernelBuilder resultType(CodeType type, boolean signedResult) {
653 PTXRegister.Type res = getResultType(type);
654 if (signedResult && (res == PTXRegister.Type.U32)) return s32();
655 return dot().type(getResultType(type).getName());
656 }
657
658 public PTXHATKernelBuilder paramType(CodeType type) {
659 PTXRegister.Type res = getResultType(type);
660 if (res == PTXRegister.Type.U32) return b32();
661 if (res == PTXRegister.Type.U64) return b64();
662 return dot().type(getResultType(type).getName());
663 }
664
665 public PTXRegister.Type getResultType(CodeType type) {
666 switch (type.toString()) {
667 case "float" -> {
668 return PTXRegister.Type.F32;
669 }
670 case "double" -> {
671 return PTXRegister.Type.F64;
672 }
673 case "int" -> {
674 return PTXRegister.Type.U32;
675 }
676 case "boolean" -> {
677 return PTXRegister.Type.PREDICATE;
678 }
679 default -> {
680 return PTXRegister.Type.U64;
681 }
682 }
683 }
684
685 /*
686 * Basic CodeBuilder functions
687 */
688
// Former override of commaNlSeparated (used for the parameter list): it printed the items
// separated by a comma followed by a newline. It duplicated CodeBuilder's own code, so it
// is disabled below (grf).
692 /* @Override
693 public <I> PTXHATKernelBuilder commaNlSeparated(Iterable<I> iterable, Consumer<I> c) {
694 StreamCounter.of(iterable, (counter, t) -> {
695 if (counter.isNotFirst()) {
696 comma().nl();
697 }
698 c.accept(t);
699 });
700 return self();
701 }
702 */
703 public PTXHATKernelBuilder address(String address) {
704 return osbrace().constant(address).csbrace();
705 }
706
707 public PTXHATKernelBuilder address(String address, int offset) {
708 osbrace().constant(address);
709 if (offset == 0) {
710 return csbrace();
711 } else if (offset > 0) {
712 plus();
713 }
714 return intVal(offset).csbrace();
715 }
716
717 public PTXHATKernelBuilder ptxNl() {
718 return semicolon().nl().ptxIndent();
719 }
720
721
722 public PTXHATKernelBuilder param() {
723 return keyword("param");
724 }
725
726 public PTXHATKernelBuilder global() {
727 return dot().keyword("global");
728 }
729
730 public PTXHATKernelBuilder rn() {
731 return dot().keyword("rn");
732 }
733
734 public PTXHATKernelBuilder rm() {
735 return dot().keyword("rm");
736 }
737
738 public PTXHATKernelBuilder rzi() {
739 return dot().keyword("rzi");
740 }
741
742 public PTXHATKernelBuilder to() {
743 return dot().keyword("to");
744 }
745
746 public PTXHATKernelBuilder lo() {
747 return dot().keyword("lo");
748 }
749
750 public PTXHATKernelBuilder wide() {
751 return dot().keyword("wide");
752 }
753
754 public PTXHATKernelBuilder uni() {
755 return dot().keyword("uni");
756 }
757
758 public PTXHATKernelBuilder sat() {
759 return dot().keyword("sat");
760 }
761
762 public PTXHATKernelBuilder ftz() {
763 return dot().keyword("ftz");
764 }
765
766 public PTXHATKernelBuilder approx() {
767 return dot().keyword("approx");
768 }
769
770 public PTXHATKernelBuilder mov() {
771 return keyword("mov");
772 }
773
774 public PTXHATKernelBuilder setp() {
775 return keyword("setp");
776 }
777
778 public PTXHATKernelBuilder selp() {
779 return keyword("selp");
780 }
781
782 public PTXHATKernelBuilder ld() {
783 return keyword("ld");
784 }
785
786 public PTXHATKernelBuilder st() {
787 return keyword("st");
788 }
789
790 public PTXHATKernelBuilder cvt() {
791 return keyword("cvt");
792 }
793
794 public PTXHATKernelBuilder bra() {
795 return keyword("bra");
796 }
797
798 public PTXHATKernelBuilder ret() {
799 return keyword("ret");
800 }
801
802 public PTXHATKernelBuilder remKw() {
803 return keyword("rem");
804 }
805
806 public PTXHATKernelBuilder mulKeword() {
807 return keyword("mul");
808 }
809
810 public PTXHATKernelBuilder divKeyword() {
811 return keyword("div");
812 }
813
814 public PTXHATKernelBuilder rcp() {
815 return keyword("rcp");
816 }
817
818 public PTXHATKernelBuilder addKeyword() {
819 return keyword("add");
820 }
821
822 public PTXHATKernelBuilder subKeyword() {
823 return keyword("sub");
824 }
825
826 public PTXHATKernelBuilder ltKeyword() {
827 return keyword("lt");
828 }
829
830 public PTXHATKernelBuilder gtKeyword() {
831 return keyword("gt");
832 }
833
834 public PTXHATKernelBuilder le() {
835 return keyword("le");
836 }
837
838 public PTXHATKernelBuilder ge() {
839 return keyword("ge");
840 }
841
842 public PTXHATKernelBuilder geu() {
843 return keyword("geu");
844 }
845
846 public PTXHATKernelBuilder neKeyword() {
847 return keyword("ne");
848 }
849
850 public PTXHATKernelBuilder eqKeyword() {
851 return keyword("eq");
852 }
853
854 public PTXHATKernelBuilder xor() {
855 return keyword("xor");
856 }
857
858 public PTXHATKernelBuilder or() {
859 return keyword("or");
860 }
861
862 public PTXHATKernelBuilder and() {
863 return keyword("and");
864 }
865
866 public PTXHATKernelBuilder cvta() {
867 return keyword("cvta");
868 }
869
870 public PTXHATKernelBuilder mad() {
871 return keyword("mad");
872 }
873
874 public PTXHATKernelBuilder fma() {
875 return keyword("fma");
876 }
877
878 public PTXHATKernelBuilder sqrt() {
879 return keyword("sqrt");
880 }
881
882 public PTXHATKernelBuilder abs() {
883 return keyword("abs");
884 }
885
886 public PTXHATKernelBuilder ex2() {
887 return keyword("ex2");
888 }
889
890 public PTXHATKernelBuilder shl() {
891 return keyword("shl");
892 }
893
894 public PTXHATKernelBuilder shr() {
895 return keyword("shr");
896 }
897
898 public PTXHATKernelBuilder neg() {
899 return keyword("neg");
900 }
901
902 public PTXHATKernelBuilder call() {
903 return keyword("call");
904 }
905
906 public PTXHATKernelBuilder exit() {
907 return keyword("exit");
908 }
909
910 public PTXHATKernelBuilder brkpt() {
911 return keyword("brkpt");
912 }
913
914 public PTXHATKernelBuilder ptxIndent() {
915 return sp().sp().sp().sp();
916 }
917
918 public PTXHATKernelBuilder u32() {
919 return dot().type(PTXRegister.Type.U32.getName());
920 }
921
922 public PTXHATKernelBuilder s32() {
923 return dot().type(PTXRegister.Type.S32.getName());
924 }
925
926 public PTXHATKernelBuilder f32() {
927 return dot().type(PTXRegister.Type.F32.getName());
928 }
929
930 public PTXHATKernelBuilder b32() {
931 return dot().type(PTXRegister.Type.B32.getName());
932 }
933
934 public PTXHATKernelBuilder u64() {
935 return dot().type(PTXRegister.Type.U64.getName());
936 }
937
938 public PTXHATKernelBuilder s64() {
939 return dot().type(PTXRegister.Type.S64.getName());
940 }
941
942 public PTXHATKernelBuilder f64() {
943 return dot().type(PTXRegister.Type.F64.getName());
944 }
945
946 public PTXHATKernelBuilder b64() {
947 return dot().type(PTXRegister.Type.B64.getName());
948 }
949
950 public PTXHATKernelBuilder version() {
951 return dot().keyword("version");
952 }
953
954 public PTXHATKernelBuilder target() {
955 return dot().keyword("target");
956 }
957
958 public PTXHATKernelBuilder addressSize() {
959 return dot().keyword("address_size");
960 }
961
962 public PTXHATKernelBuilder major(int major) {
963 return intVal(major);
964 }
965
966 public PTXHATKernelBuilder minor(int minor) {
967 return intVal(minor);
968 }
969
970 public PTXHATKernelBuilder target(String target) {
971 return keyword(target);
972 }
973
974 public PTXHATKernelBuilder size(int addressSize) {
975 return intVal(addressSize);
976 }
977
978
979
980 public PTXHATKernelBuilder visible() {
981 return dot().keyword("visible");
982 }
983
984 public PTXHATKernelBuilder entry() {
985 return dot().keyword("entry");
986 }
987
988 public PTXHATKernelBuilder func() {
989 return dot().keyword("func");
990 }
991
992 public PTXHATKernelBuilder oabrace() {
993 return symbol("<");
994 }
995
996 public PTXHATKernelBuilder cabrace() {
997 return symbol(">");
998 }
999
1000 public PTXHATKernelBuilder regName(PTXRegister reg) {
1001 return id(reg.name());
1002 }
1003
1004 public PTXHATKernelBuilder regName(String regName) {
1005 return id(regName);
1006 }
1007
1008 public PTXHATKernelBuilder regType(Value val) {
1009 return keyword(getReg(val).type().getName());
1010 }
1011
1012 public PTXHATKernelBuilder regType(PTXRegister.Type t) {
1013 return keyword(t.getName());
1014 }
1015
1016 public PTXHATKernelBuilder regTypePrefix(PTXRegister.Type t) {
1017 return keyword(t.getRegPrefix());
1018 }
1019
1020 public PTXHATKernelBuilder reg() {
1021 return dot().keyword("reg");
1022 }
1023
1024 public PTXHATKernelBuilder retVal() {
1025 return keyword("retval");
1026 }
1027
1028 public PTXHATKernelBuilder intVal(int i) {
1029 return constant(String.valueOf(i));
1030 }
1031
1032 public PTXHATKernelBuilder floatVal(String s) {
1033 return constant("0f").constant(s);
1034 }
1035
1036 public PTXHATKernelBuilder doubleVal(String s) {
1037 return constant("0d").constant(s);
1038 }
1039 }