1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend.ffi;
 26 
 27 import hat.optools.*;
 28 import hat.text.CodeBuilder;
 29 import hat.util.StreamCounter;
 30 import jdk.incubator.code.Block;
 31 import jdk.incubator.code.Op;
 32 import jdk.incubator.code.TypeElement;
 33 import jdk.incubator.code.Value;
 34 import jdk.incubator.code.op.CoreOp;
 35 import jdk.incubator.code.type.JavaType;
 36 
 37 import java.lang.foreign.MemoryLayout;
 38 import java.util.ArrayList;
 39 import java.util.HashMap;
 40 import java.util.List;
 41 import java.util.Map;
 42 import java.util.function.Consumer;
 43 import java.util.stream.Stream;
 44 
 45 public class PTXHATKernelBuilder extends CodeBuilder<PTXHATKernelBuilder> {
 46 
 47     Map<Value, PTXRegister> varToRegMap;
 48     List<String> paramNames;
 49     List<Block.Parameter> paramObjects;
 50     Map<Field, PTXRegister> fieldToRegMap;
 51 
 52     HashMap<PTXRegister.Type, Integer> ordinalMap;
 53 
 54     PTXRegister returnReg;
 55     private int addressSize;
 56 
 57     public enum Field {
 58         NTID_X ("ntid.x", false),
 59         CTAID_X ("ctaid.x", false),
 60         TID_X ("tid.x", false),
 61         KC_X ("x", false),
 62         KC_ADDR("kc", true),
 63         KC_MAXX ("maxX", false);
 64 
 65         private final String name;
 66         private final boolean destination;
 67 
 68         Field(String name, boolean destination) {
 69             this.name = name;
 70             this.destination = destination;
 71         }
 72         public String toString() {
 73             return this.name;
 74         }
 75         public boolean isDestination() {return this.destination;}
 76     }
 77 
 78     public PTXHATKernelBuilder(int addressSize) {
 79         varToRegMap = new HashMap<>();
 80         paramNames = new ArrayList<>();
 81         fieldToRegMap = new HashMap<>();
 82         paramObjects = new ArrayList<>();
 83         ordinalMap = new HashMap<>();
 84         this.addressSize = addressSize;
 85     }
 86 
 87     public PTXHATKernelBuilder() {
 88         this(32);
 89     }
 90 
 91     public void ptxHeader(int major, int minor, String target, int addressSize) {
 92         this.addressSize = addressSize;
 93         version().space().major(major).dot().minor(minor).nl();
 94         target().space().target(target).nl();
 95         addressSize().space().size(addressSize);
 96     }
 97 
 98     public void functionHeader(String funcName, boolean entry, TypeElement yieldType) {
 99         if (entry) {
100             visible().space().entry().space();
101         } else {
102             func().space();
103         }
104         if (!yieldType.toString().equals("void")) {
105             returnReg = new PTXRegister(getOrdinal(getResultType(yieldType)), getResultType(yieldType));
106             returnReg.name("%retReg");
107             oparen().dot().param().space().paramType(yieldType);
108             space().regName(returnReg).cparen().space();
109         }
110         funcName(funcName);
111     }
112 
113     public PTXHATKernelBuilder parameters(List<FuncOpWrapper.ParamTable.Info> infoList) {
114         paren(_ -> nl().commaNlSeparated(infoList, (info) -> {
115             ptxIndent().dot().param().space().paramType(info.javaType);
116             space().regName(info.varOp.varName());
117             paramNames.add(info.varOp.varName());
118         }).nl()).nl();
119         return this;
120     }
121 
122     public void blockBody(Block block, Stream<OpWrapper<?>> ops) {
123         if (block.index() == 0) {
124             for (Block.Parameter p : block.parameters()) {
125                 ptxIndent().ld().dot().param();
126                 resultType(p.type(), false).ptxIndent().space();
127                 reg(p, getResultType(p.type())).commaSpace().osbrace().regName(paramNames.get(p.index())).csbrace().semicolon().nl();
128                 paramObjects.add(p);
129             }
130         }
131         nl();
132         block(block);
133         colon().nl();
134         ops.forEach(op -> {
135             if (op instanceof InvokeOpWrapper invoke && !invoke.isIfaceBufferMethod()) {
136                 ptxIndent().convert(op).nl();
137             } else {
138                 ptxIndent().convert(op).semicolon().nl();
139             }
140         });
141     }
142 
143     public void ptxRegisterDecl() {
144         for (PTXRegister.Type t : ordinalMap.keySet()) {
145             ptxIndent().reg().space();
146             if (t.equals(PTXRegister.Type.U32)) {
147                 b32();
148             } else if (t.equals(PTXRegister.Type.U64)) {
149                 b64();
150             } else {
151                 dot().regType(t);
152             }
153             ptxIndent().regTypePrefix(t).oabrace().intVal(ordinalMap.get(t)).cabrace().semicolon().nl();
154         }
155         nl();
156     }
157 
158     public void functionPrologue() {
159         obrace().nl();
160     }
161 
162     public void functionEpilogue() {
163         cbrace();
164     }
165 
166     public PTXHATKernelBuilder convert(OpWrapper<?> wrappedOp) {
167         switch (wrappedOp) {
168             case FieldLoadOpWrapper op -> fieldLoad(op);
169             case FieldStoreOpWrapper op -> fieldStore(op);
170             case BinaryArithmeticOrLogicOperation op -> binaryOperation(op);
171             case BinaryTestOpWrapper op -> binaryTest(op);
172             case ConvOpWrapper op -> conv(op);
173             case ConstantOpWrapper op -> constant(op);
174             case YieldOpWrapper op -> javaYield(op);
175             case InvokeOpWrapper op -> methodCall(op);
176             case VarDeclarationOpWrapper op -> varDeclaration(op);
177             case VarFuncDeclarationOpWrapper op -> varFuncDeclaration(op);
178             case ReturnOpWrapper op -> ret(op);
179             case JavaBreakOpWrapper op -> javaBreak(op);
180             default -> {
181                 switch (wrappedOp.op()){
182                     case CoreOp.BranchOp op -> branch(op);
183                     case CoreOp.ConditionalBranchOp op -> condBranch(op);
184                     case CoreOp.NegOp op -> neg(op);
185                     case PTXPtrOp op -> ptxPtr(op);
186                     default -> throw new IllegalStateException("op translation doesn't exist");
187                 }
188             }
189         }
190         return this;
191     }
192 
193     public void ptxPtr(PTXPtrOp op) {
194         PTXRegister source;
195         int offset = (int) op.boundSchema.groupLayout().byteOffset(MemoryLayout.PathElement.groupElement(op.fieldName));
196 
197         if (op.fieldName.equals("array")) {
198             source = new PTXRegister(incrOrdinal(addressType()), addressType());
199             add().s64().space().regName(source).commaSpace().reg(op.operands().get(0)).commaSpace().reg(op.operands().get(1)).ptxNl();
200         } else {
201             source = getReg(op.operands().getFirst());
202         }
203 
204         if (op.resultType.toString().equals("void")) {
205             st().global().dot().regType(op.operands().getLast()).space().address(source.name(), offset).commaSpace().reg(op.operands().getLast());
206         } else {
207             ld().global().resultType(op.resultType(), true).space().reg(op.result(), getResultType(op.resultType())).commaSpace().address(source.name(), offset);
208         }
209     }
210 
211     public void fieldLoad(FieldLoadOpWrapper op) {
212         if (op.fieldName().equals(Field.KC_X.toString())) {
213             if (!fieldToRegMap.containsKey(Field.KC_X)) {
214                 loadKcX(op.result());
215             } else {
216                 mov().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().fieldReg(Field.KC_X);
217             }
218         } else if (op.fieldName().equals(Field.KC_MAXX.toString())) {
219             if (!fieldToRegMap.containsKey(Field.KC_X)) {
220                 loadKcX(op.operandNAsValue(0));
221             }
222             ld().global().u32().space().fieldReg(Field.KC_MAXX, op.result()).commaSpace()
223                     .address(fieldToRegMap.get(Field.KC_ADDR).name(), 4);
224         } else {
225             ld().global().u32().space().resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operandNAsValue(0));
226         }
227     }
228 
229     public void loadKcX(Value value) {
230         cvta().to().global().size().space().fieldReg(Field.KC_ADDR).commaSpace()
231                 .reg(paramObjects.get(paramNames.indexOf(Field.KC_ADDR.toString())), addressType()).ptxNl();
232         mov().u32().space().fieldReg(Field.NTID_X).commaSpace().percent().regName(Field.NTID_X.toString()).ptxNl();
233         mov().u32().space().fieldReg(Field.CTAID_X).commaSpace().percent().regName(Field.CTAID_X.toString()).ptxNl();
234         mov().u32().space().fieldReg(Field.TID_X).commaSpace().percent().regName(Field.TID_X.toString()).ptxNl();
235         mad().lo().s32().space().fieldReg(Field.KC_X, value).commaSpace().fieldReg(Field.CTAID_X)
236                 .commaSpace().fieldReg(Field.NTID_X).commaSpace().fieldReg(Field.TID_X).ptxNl();
237         st().global().u32().space().address(fieldToRegMap.get(Field.KC_ADDR).name()).commaSpace().fieldReg(Field.KC_X);
238     }
239 
240     public void fieldStore(FieldStoreOpWrapper op) {
241         // TODO: fix
242         st().global().u64().space().resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operandNAsValue(0));
243     }
244 
245     PTXHATKernelBuilder symbol(Op op) {
246         return switch (op) {
247             case CoreOp.ModOp _ -> rem();
248             case CoreOp.MulOp _ -> mul();
249             case CoreOp.DivOp _ -> div();
250             case CoreOp.AddOp _ -> add();
251             case CoreOp.SubOp _ -> sub();
252             case CoreOp.LtOp _ -> lt();
253             case CoreOp.GtOp _ -> gt();
254             case CoreOp.LeOp _ -> le();
255             case CoreOp.GeOp _ -> ge();
256             case CoreOp.NeqOp _ -> ne();
257             case CoreOp.EqOp _ -> eq();
258             case CoreOp.OrOp _ -> or();
259             case CoreOp.AndOp _ -> and();
260             case CoreOp.XorOp _ -> xor();
261             case CoreOp.LshlOp _ -> shl();
262             case CoreOp.AshrOp _, CoreOp.LshrOp _ -> shr();
263             default -> throw new IllegalStateException("Unexpected value");
264         };
265     }
266 
267     public void binaryOperation(BinaryArithmeticOrLogicOperation op) {
268         symbol(op.op());
269         if (getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
270                 && (op.op() instanceof CoreOp.DivOp || op.op() instanceof CoreOp.MulOp)) {
271             rn();
272         } else if (!getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
273                 && op.op() instanceof CoreOp.MulOp) {
274             lo();
275         }
276         resultType(op.resultType(), true).space();
277         resultReg(op, getResultType(op.resultType()));
278         commaSpace();
279         reg(op.operandNAsValue(0));
280         commaSpace();
281         reg(op.operandNAsValue(1));
282     }
283 
284     public void binaryTest(BinaryTestOpWrapper op) {
285         setp().dot();
286         symbol(op.op()).resultType(op.operandNAsValue(0).type(), true).space();
287         resultReg(op, PTXRegister.Type.PREDICATE);
288         commaSpace();
289         reg(op.operandNAsValue(0));
290         commaSpace();
291         reg(op.operandNAsValue(1));
292     }
293 
294     public void conv(ConvOpWrapper op) {
295         if (op.resultJavaType().equals(JavaType.LONG)) {
296             if (isIndex(op)) {
297                 mul().wide().s32().space().resultReg(op, PTXRegister.Type.U64).commaSpace()
298                         .reg(op.operandNAsValue(0)).commaSpace().intVal(4);
299             } else {
300                 cvt().u64().dot().regType(op.operandNAsValue(0)).space()
301                         .resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operandNAsValue(0)).ptxNl();
302             }
303         } else if (op.resultJavaType().equals(JavaType.FLOAT)) {
304             cvt().rn().f32().dot().regType(op.operandNAsValue(0)).space()
305                     .resultReg(op, PTXRegister.Type.F32).commaSpace().reg(op.operandNAsValue(0));
306         } else if (op.resultJavaType().equals(JavaType.DOUBLE)) {
307             cvt();
308             if (op.operandNAsValue(0).type().equals(JavaType.INT)) {
309                 rn();
310             }
311             f64().dot().regType(op.operandNAsValue(0)).space()
312                     .resultReg(op, PTXRegister.Type.F64).commaSpace().reg(op.operandNAsValue(0));
313         } else if (op.resultJavaType().equals(JavaType.INT)) {
314             cvt();
315             if (op.operandNAsValue(0).type().equals(JavaType.DOUBLE) || op.operandNAsValue(0).type().equals(JavaType.FLOAT)) {
316                 rzi();
317             } else {
318                 rn();
319             }
320             s32().dot().regType(op.operandNAsValue(0)).space()
321                     .resultReg(op, PTXRegister.Type.S32).commaSpace().reg(op.operandNAsValue(0));
322         } else {
323             cvt().rn().s32().dot().regType(op.operandNAsValue(0)).space()
324                     .resultReg(op, PTXRegister.Type.S32).commaSpace().reg(op.operandNAsValue(0));
325         }
326     }
327 
328     private boolean isIndex(ConvOpWrapper op) {
329         for (Op.Result r : op.result().uses()) {
330             if (r.op() instanceof PTXPtrOp) return true;
331         }
332         return false;
333     }
334 
335     public void constant(ConstantOpWrapper op) {
336         mov().resultType(op.resultType(), false).space().resultReg(op, getResultType(op.resultType())).commaSpace();
337         if (op.resultType().toString().equals("float")) {
338             if (op.op().value().toString().equals("0.0")) {
339                 floatVal("00000000");
340             } else {
341                 floatVal(Integer.toHexString(Float.floatToIntBits(Float.parseFloat(op.op().value().toString()))).toUpperCase());
342             }
343         } else {
344             append(op.op().value().toString());
345         }
346     }
347 
348     public void javaYield(YieldOpWrapper op) {
349         exit();
350     }
351 
352     // S32Array and S32Array2D functions can be deleted after schema is done
353     public void methodCall(InvokeOpWrapper op) {
354         switch (op.methodRef().toString()) {
355             // S32Array functions
356             case "hat.buffer.S32Array::array(long)int" -> {
357                 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
358                 add().s64().space().regName(temp).commaSpace().reg(op.operandNAsValue(0)).commaSpace().reg(op.operandNAsValue(1)).ptxNl();
359                 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(temp.name(), 4);
360             }
361             case "hat.buffer.S32Array::array(long, int)void" -> {
362                 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
363                 add().s64().space().regName(temp).commaSpace().reg(op.operandNAsValue(0)).commaSpace().reg(op.operandNAsValue(1)).ptxNl();
364                 st().global().u32().space().address(temp.name(), 4).commaSpace().reg(op.operandNAsValue(2));
365             }
366             case "hat.buffer.S32Array::length()int" -> {
367                 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(getReg(op.operandNAsValue(0)).name());
368             }
369             // S32Array2D functions
370             case "hat.buffer.S32Array2D::array(long, int)void" -> {
371                 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
372                 add().s64().space().regName(temp).commaSpace().reg(op.operandNAsValue(0)).commaSpace().reg(op.operandNAsValue(1)).ptxNl();
373                 st().global().u32().space().address(temp.name(), 8).commaSpace().reg(op.operandNAsValue(2));
374             }
375             case "hat.buffer.S32Array2D::width()int" -> {
376                 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(getReg(op.operandNAsValue(0)).name());
377             }
378             case "hat.buffer.S32Array2D::height()int" -> {
379                 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(getReg(op.operandNAsValue(0)).name(), 4);
380             }
381             // Java Math function
382             case "java.lang.Math::sqrt(double)double" -> {
383                 sqrt().rn().f64().space().resultReg(op, PTXRegister.Type.F64).commaSpace().reg(op.operandNAsValue(0)).semicolon();
384             }
385             default -> {
386                 obrace().nl().ptxIndent();
387                 for (int i = 0; i < op.operands().size(); i++) {
388                     dot().param().space().paramType(op.operandNAsValue(i).type()).space().param().intVal(i).ptxNl();
389                     st().dot().param().paramType(op.operandNAsValue(i).type()).space().osbrace().param().intVal(i).csbrace().commaSpace().reg(op.operandNAsValue(i)).ptxNl();
390                 }
391                 dot().param().space().paramType(op.resultType()).space().retVal().ptxNl();
392                 call().uni().space().oparen().retVal().cparen().commaSpace().append(op.method().getName()).commaSpace();
393                 final int[] counter = {0};
394                 paren(_ -> commaSeparated(op.operands(), _ -> param().intVal(counter[0]++))).ptxNl();
395                 ld().dot().param().paramType(op.resultType()).space().resultReg(op, getResultType(op.resultType())).commaSpace().osbrace().retVal().csbrace();
396                 ptxNl().cbrace();
397             }
398         }
399     }
400 
401     public void varDeclaration(VarDeclarationOpWrapper op) {
402         ld().dot().param().resultType(op.resultType(), false).space().resultReg(op, addressType()).commaSpace().reg(op.operandNAsValue(0));
403     }
404 
405     public void varFuncDeclaration(VarFuncDeclarationOpWrapper op) {
406         ld().dot().param().resultType(op.resultType(), false).space().resultReg(op, addressType()).commaSpace().reg(op.operandNAsValue(0));
407     }
408 
409     public void ret(ReturnOpWrapper op) {
410         if (op.hasOperands()) {
411             st().dot().param();
412             if (returnReg.type().equals(PTXRegister.Type.U32)) {
413                 b32();
414             } else if (returnReg.type().equals(PTXRegister.Type.U64)) {
415                 b64();
416             } else {
417                 dot().regType(returnReg.type());
418             }
419             space().osbrace().regName(returnReg).csbrace().commaSpace().reg(op.operandNAsValue(0)).ptxNl();
420         }
421         ret();
422     }
423 
424     public void javaBreak(JavaBreakOpWrapper op) {
425         brkpt();
426     }
427 
428     public void branch(CoreOp.BranchOp op) {
429         loadBlockParams(op.successors().getFirst());
430         bra().space().block(op.successors().getFirst().targetBlock());
431     }
432 
433     public void condBranch(CoreOp.ConditionalBranchOp op) {
434         loadBlockParams(op.successors().getFirst());
435         loadBlockParams(op.successors().getLast());
436         at().reg(op.operands().getFirst()).space()
437                 .bra().space().block(op.successors().getFirst().targetBlock()).ptxNl();
438         bra().space().block(op.successors().getLast().targetBlock());
439     }
440 
441     public void neg(CoreOp.NegOp op) {
442         neg().resultType(op.resultType(), true).space().reg(op.result(), getResultType(op.resultType())).commaSpace().reg(op.operands().getFirst());
443     }
444 
445     /*
446      * Helper functions for printing blocks and variables
447      */
448 
449     public void loadBlockParams(Block.Reference block) {
450         for (int i = 0; i < block.arguments().size(); i++) {
451             Block.Parameter p = block.targetBlock().parameters().get(i);
452             mov().resultType(p.type(), false).space().reg(p, getResultType(p.type()))
453                     .commaSpace().reg(block.arguments().get(i)).ptxNl();
454         }
455     }
456 
457     public PTXHATKernelBuilder block(Block block) {
458         return append("block_").intVal(block.index());
459     }
460 
461     public PTXHATKernelBuilder fieldReg(Field ref) {
462         if (fieldToRegMap.containsKey(ref)) {
463             return regName(fieldToRegMap.get(ref));
464         }
465         if (ref.isDestination()) {
466             fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(addressType()), addressType()));
467         } else {
468             fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
469         }
470         return regName(fieldToRegMap.get(ref));
471     }
472 
473     public PTXHATKernelBuilder fieldReg(Field ref, Value value) {
474         if (fieldToRegMap.containsKey(ref)) {
475             return regName(fieldToRegMap.get(ref));
476         }
477         if (ref.isDestination()) {
478             fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(addressType()), addressType()));
479             return reg(value, addressType());
480         } else {
481             fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
482             return reg(value, PTXRegister.Type.U32);
483         }
484     }
485 
486     public Field getFieldObj(String fieldName) {
487         for (Field f : fieldToRegMap.keySet()) {
488             if (f.toString().equals(fieldName)) return f;
489         }
490         throw new IllegalStateException("no existing field");
491     }
492 
493     public PTXHATKernelBuilder resultReg(OpWrapper<?> opWrapper, PTXRegister.Type type) {
494         return append(addReg(opWrapper.result(), type));
495     }
496 
497     public PTXHATKernelBuilder reg(Value val, PTXRegister.Type type) {
498         if (varToRegMap.containsKey(val)) {
499             return regName(getReg(val));
500         } else {
501             return append(addReg(val, type));
502         }
503     }
504 
505     public PTXHATKernelBuilder reg(Value val) {
506         return regName(getReg(val));
507     }
508 
509     public PTXRegister getReg(Value val) {
510         if (varToRegMap.get(val) == null && val instanceof Op.Result result && result.op() instanceof CoreOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
511             return fieldToRegMap.get(getFieldObj(fieldLoadOp.fieldDescriptor().name()));
512         }
513         if (varToRegMap.containsKey(val)) {
514             return varToRegMap.get(val);
515         } else {
516             throw new IllegalStateException("var to reg mapping doesn't exist");
517         }
518     }
519 
520     public String addReg(Value val, PTXRegister.Type type) {
521         if (varToRegMap.containsKey(val)) {
522             return varToRegMap.get(val).name();
523         }
524         varToRegMap.put(val, new PTXRegister(incrOrdinal(type), type));
525         return varToRegMap.get(val).name();
526     }
527 
528     public Integer getOrdinal(PTXRegister.Type type) {
529         ordinalMap.putIfAbsent(type, 1);
530         return ordinalMap.get(type);
531     }
532 
533     public Integer incrOrdinal(PTXRegister.Type type) {
534         ordinalMap.putIfAbsent(type, 1);
535         int out = ordinalMap.get(type);
536         ordinalMap.put(type, out + 1);
537         return out;
538     }
539 
540     public PTXHATKernelBuilder size() {
541         return (addressSize == 32) ? u32() : u64();
542     }
543 
544     public PTXRegister.Type addressType() {
545         return (addressSize == 32) ? PTXRegister.Type.U32 : PTXRegister.Type.U64;
546     }
547 
548     public PTXHATKernelBuilder resultType(TypeElement type, boolean signedResult) {
549         PTXRegister.Type res = getResultType(type);
550         if (signedResult && (res == PTXRegister.Type.U32)) return s32();
551         return dot().append(getResultType(type).getName());
552     }
553 
554     public PTXHATKernelBuilder paramType(TypeElement type) {
555         PTXRegister.Type res = getResultType(type);
556         if (res == PTXRegister.Type.U32) return b32();
557         if (res == PTXRegister.Type.U64) return b64();
558         return dot().append(getResultType(type).getName());
559     }
560 
561     public PTXRegister.Type getResultType(TypeElement type) {
562         switch (type.toString()) {
563             case "float" -> {
564                 return PTXRegister.Type.F32;
565             }
566             case "double" -> {
567                 return PTXRegister.Type.F64;
568             }
569             case "int" -> {
570                 return PTXRegister.Type.U32;
571             }
572             case "boolean" -> {
573                 return PTXRegister.Type.PREDICATE;
574             }
575             default -> {
576                 return PTXRegister.Type.U64;
577             }
578         }
579     }
580 
581     /*
582      * Basic CodeBuilder functions
583      */
584 
585     // used for parameter list
586     // prints out items separated by a comma then new line
587     public <I> PTXHATKernelBuilder commaNlSeparated(Iterable<I> iterable, Consumer<I> c) {
588         StreamCounter.of(iterable, (counter, t) -> {
589             if (counter.isNotFirst()) {
590                 comma().nl();
591             }
592             c.accept(t);
593         });
594         return self();
595     }
596 
597     public PTXHATKernelBuilder address(String address) {
598         return osbrace().append(address).csbrace();
599     }
600 
601     public PTXHATKernelBuilder address(String address, int offset) {
602         osbrace().append(address);
603         if (offset == 0) {
604             return csbrace();
605         } else if (offset > 0) {
606             plus();
607         }
608         return intVal(offset).csbrace();
609     }
610 
611     public PTXHATKernelBuilder ptxNl() {
612         return semicolon().nl().ptxIndent();
613     }
614 
615     public PTXHATKernelBuilder commaSpace() {
616         return comma().space();
617     }
618 
619     public PTXHATKernelBuilder param() {
620         return append("param");
621     }
622 
623     public PTXHATKernelBuilder global() {
624         return dot().append("global");
625     }
626 
627     public PTXHATKernelBuilder rn() {
628         return dot().append("rn");
629     }
630 
631     public PTXHATKernelBuilder rm() {
632         return dot().append("rm");
633     }
634 
635     public PTXHATKernelBuilder rzi() {
636         return dot().append("rzi");
637     }
638 
639     public PTXHATKernelBuilder to() {
640         return dot().append("to");
641     }
642 
643     public PTXHATKernelBuilder lo() {
644         return dot().append("lo");
645     }
646 
647     public PTXHATKernelBuilder wide() {
648         return dot().append("wide");
649     }
650 
651     public PTXHATKernelBuilder uni() {
652         return dot().append("uni");
653     }
654 
655     public PTXHATKernelBuilder sat() {
656         return dot().append("sat");
657     }
658 
659     public PTXHATKernelBuilder ftz() {
660         return dot().append("ftz");
661     }
662 
663     public PTXHATKernelBuilder approx() {
664         return dot().append("approx");
665     }
666 
667     public PTXHATKernelBuilder mov() {
668         return append("mov");
669     }
670 
671     public PTXHATKernelBuilder setp() {
672         return append("setp");
673     }
674 
675     public PTXHATKernelBuilder selp() {
676         return append("selp");
677     }
678 
679     public PTXHATKernelBuilder ld() {
680         return append("ld");
681     }
682 
683     public PTXHATKernelBuilder st() {
684         return append("st");
685     }
686 
687     public PTXHATKernelBuilder cvt() {
688         return append("cvt");
689     }
690 
691     public PTXHATKernelBuilder bra() {
692         return append("bra");
693     }
694 
695     public PTXHATKernelBuilder ret() {
696         return append("ret");
697     }
698 
699     public PTXHATKernelBuilder rem() {
700         return append("rem");
701     }
702 
703     public PTXHATKernelBuilder mul() {
704         return append("mul");
705     }
706 
707     public PTXHATKernelBuilder div() {
708         return append("div");
709     }
710 
711     public PTXHATKernelBuilder rcp() {
712         return append("rcp");
713     }
714 
715     public PTXHATKernelBuilder add() {
716         return append("add");
717     }
718 
719     public PTXHATKernelBuilder sub() {
720         return append("sub");
721     }
722 
723     public PTXHATKernelBuilder lt() {
724         return append("lt");
725     }
726 
727     public PTXHATKernelBuilder gt() {
728         return append("gt");
729     }
730 
731     public PTXHATKernelBuilder le() {
732         return append("le");
733     }
734 
735     public PTXHATKernelBuilder ge() {
736         return append("ge");
737     }
738 
739     public PTXHATKernelBuilder geu() {
740         return append("geu");
741     }
742 
743     public PTXHATKernelBuilder ne() {
744         return append("ne");
745     }
746 
747     public PTXHATKernelBuilder eq() {
748         return append("eq");
749     }
750 
751     public PTXHATKernelBuilder xor() {
752         return append("xor");
753     }
754 
755     public PTXHATKernelBuilder or() {
756         return append("or");
757     }
758 
759     public PTXHATKernelBuilder and() {
760         return append("and");
761     }
762 
763     public PTXHATKernelBuilder cvta() {
764         return append("cvta");
765     }
766 
767     public PTXHATKernelBuilder mad() {
768         return append("mad");
769     }
770 
771     public PTXHATKernelBuilder fma() {
772         return append("fma");
773     }
774 
775     public PTXHATKernelBuilder sqrt() {
776         return append("sqrt");
777     }
778 
779     public PTXHATKernelBuilder abs() {
780         return append("abs");
781     }
782 
783     public PTXHATKernelBuilder ex2() {
784         return append("ex2");
785     }
786 
787     public PTXHATKernelBuilder shl() {
788         return append("shl");
789     }
790 
791     public PTXHATKernelBuilder shr() {
792         return append("shr");
793     }
794 
795     public PTXHATKernelBuilder neg() {
796         return append("neg");
797     }
798 
799     public PTXHATKernelBuilder call() {
800         return append("call");
801     }
802 
803     public PTXHATKernelBuilder exit() {
804         return append("exit");
805     }
806 
807     public PTXHATKernelBuilder brkpt() {
808         return append("brkpt");
809     }
810 
811     public PTXHATKernelBuilder ptxIndent() {
812         return append("    ");
813     }
814 
815     public PTXHATKernelBuilder u32() {
816         return dot().append(PTXRegister.Type.U32.getName());
817     }
818 
819     public PTXHATKernelBuilder s32() {
820         return dot().append(PTXRegister.Type.S32.getName());
821     }
822 
823     public PTXHATKernelBuilder f32() {
824         return dot().append(PTXRegister.Type.F32.getName());
825     }
826 
827     public PTXHATKernelBuilder b32() {
828         return dot().append(PTXRegister.Type.B32.getName());
829     }
830 
831     public PTXHATKernelBuilder u64() {
832         return dot().append(PTXRegister.Type.U64.getName());
833     }
834 
835     public PTXHATKernelBuilder s64() {
836         return dot().append(PTXRegister.Type.S64.getName());
837     }
838 
839     public PTXHATKernelBuilder f64() {
840         return dot().append(PTXRegister.Type.F64.getName());
841     }
842 
843     public PTXHATKernelBuilder b64() {
844         return dot().append(PTXRegister.Type.B64.getName());
845     }
846 
847     public PTXHATKernelBuilder version() {
848         return dot().append("version");
849     }
850 
851     public PTXHATKernelBuilder target() {
852         return dot().append("target");
853     }
854 
855     public PTXHATKernelBuilder addressSize() {
856         return dot().append("address_size");
857     }
858 
859     public PTXHATKernelBuilder major(int major) {
860         return intVal(major);
861     }
862 
863     public PTXHATKernelBuilder minor(int minor) {
864         return intVal(minor);
865     }
866 
867     public PTXHATKernelBuilder target(String target) {
868         return append(target);
869     }
870 
871     public PTXHATKernelBuilder size(int addressSize) {
872         return intVal(addressSize);
873     }
874 
875     public PTXHATKernelBuilder funcName(String funcName) {
876         return append(funcName);
877     }
878 
879     public PTXHATKernelBuilder visible() {
880         return dot().append("visible");
881     }
882 
883     public PTXHATKernelBuilder entry() {
884         return dot().append("entry");
885     }
886 
887     public PTXHATKernelBuilder func() {
888         return dot().append("func");
889     }
890 
891     public PTXHATKernelBuilder oabrace() {
892         return append("<");
893     }
894 
895     public PTXHATKernelBuilder cabrace() {
896         return append(">");
897     }
898 
899     public PTXHATKernelBuilder regName(PTXRegister reg) {
900         return append(reg.name());
901     }
902 
903     public PTXHATKernelBuilder regName(String regName) {
904         return append(regName);
905     }
906 
907     public PTXHATKernelBuilder regType(Value val) {
908         return append(getReg(val).type().getName());
909     }
910 
911     public PTXHATKernelBuilder regType(PTXRegister.Type t) {
912         return append(t.getName());
913     }
914 
915     public PTXHATKernelBuilder regTypePrefix(PTXRegister.Type t) {
916         return append(t.getRegPrefix());
917     }
918 
919     public PTXHATKernelBuilder reg() {
920         return dot().append("reg");
921     }
922 
923     public PTXHATKernelBuilder retVal() {
924         return append("retval");
925     }
926 
927     public PTXHATKernelBuilder temp() {
928         return append("temp");
929     }
930 
931     public PTXHATKernelBuilder intVal(int i) {
932         return append(String.valueOf(i));
933     }
934 
935     public PTXHATKernelBuilder floatVal(String s) {
936         return append("0f").append(s);
937     }
938 
939     public PTXHATKernelBuilder doubleVal(String s) {
940         return append("0d").append(s);
941     }
942 }