1 /*
   2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 package hat.backend.ffi;
  26 
  27 import hat.dialect.HATOp;
  28 import hat.ifacemapper.BoundSchema;
  29 import hat.optools.*;
  30 import hat.codebuilders.CodeBuilder;
  31 
  32 import jdk.incubator.code.*;
  33 import jdk.incubator.code.dialect.core.CoreOp;
  34 import jdk.incubator.code.dialect.java.JavaOp;
  35 import jdk.incubator.code.dialect.java.JavaType;
  36 
  37 import java.lang.foreign.MemoryLayout;
  38 import java.lang.invoke.MethodHandles;
  39 import java.util.ArrayList;
  40 import java.util.HashMap;
  41 import java.util.List;
  42 import java.util.Map;
  43 import java.util.stream.Stream;
  44 
  45 public class PTXHATKernelBuilder extends CodeBuilder<PTXHATKernelBuilder> {
  46 
    // Register assigned to each code-model value (op result or block parameter) seen so far.
    Map<Value, PTXRegister> varToRegMap;
    // Names of the function parameters in declaration order, recorded by parameters().
    List<String> paramNames;
    // Entry-block (index 0) parameters, in the order blockBody() loaded them from param space.
    List<Block.Parameter> paramObjects;
    // Registers holding kernel-context fields and PTX special registers (see Field).
    Map<Field, PTXRegister> fieldToRegMap;

    // Next register ordinal to hand out for each register type; ptxRegisterDecl() sizes the
    // .reg declarations from these values.
    // NOTE(review): a HashMap keyed by enum constants has no stable iteration order.
    HashMap<PTXRegister.Type, Integer> ordinalMap;

    // Parameter register receiving the return value; set by functionHeader() for non-void functions.
    PTXRegister returnReg;
    // Address width: 32 selects u32 address registers, any other value selects u64.
    private int addressSize;
  56 
  57     public enum Field {
  58         NTID_X ("ntid.x", false),
  59         CTAID_X ("ctaid.x", false),
  60         TID_X ("tid.x", false),
  61         KC_X ("x", false),
  62         KC_ADDR("kc", true),
  63         KC_MAXX ("maxX", false);
  64 
  65         private final String name;
  66         private final boolean destination;
  67 
  68         Field(String name, boolean destination) {
  69             this.name = name;
  70             this.destination = destination;
  71         }
  72         public String toString() {
  73             return this.name;
  74         }
  75         public boolean isDestination() {return this.destination;}
  76     }
  77 
  78     public PTXHATKernelBuilder(int addressSize) {
  79         varToRegMap = new HashMap<>();
  80         paramNames = new ArrayList<>();
  81         fieldToRegMap = new HashMap<>();
  82         paramObjects = new ArrayList<>();
  83         ordinalMap = new HashMap<>();
  84         this.addressSize = addressSize;
  85     }
  86 
  87     public PTXHATKernelBuilder() {
  88         this(32);
  89     }
  90 
  91     public void ptxHeader(int major, int minor, String target, int addressSize) {
  92         this.addressSize = addressSize;
  93         version().space().major(major).dot().minor(minor).nl();
  94         target().space().target(target).nl();
  95         addressSize().space().size(addressSize);
  96     }
  97 
  98     public void functionHeader(String funcName, boolean entry, TypeElement yieldType) {
  99         if (entry) {
 100             visible().space().entry().space();
 101         } else {
 102             func().space();
 103         }
 104         if (!yieldType.toString().equals("void")) {
 105             returnReg = new PTXRegister(getOrdinal(getResultType(yieldType)), getResultType(yieldType));
 106             returnReg.name("%retReg");
 107             oparen().dot().param().space().paramType(yieldType);
 108             space().regName(returnReg).cparen().space();
 109         }
 110         funcName(funcName);
 111     }
 112 
 113     public PTXHATKernelBuilder parameters(List<FuncOpParams.Info> infoList) {
 114         paren(_ ->
 115                 nl().separated(infoList,(t)->t.comma().nl(), (info) -> {
 116             ptxIndent().dot().param().space().paramType(info.javaType);
 117             space().regName(info.varOp.varName());
 118             paramNames.add(info.varOp.varName());
 119         }).nl()).nl();
 120         return this;
 121     }
 122 
 123     public void blockBody(Block block, Stream<Op> ops) {
 124         if (block.index() == 0) {
 125             for (Block.Parameter p : block.parameters()) {
 126                 ptxIndent().ld().dot().param();
 127                 resultType(p.type(), false).ptxIndent().space();
 128                 reg(p, getResultType(p.type())).commaSpace().osbrace().regName(paramNames.get(p.index())).csbrace().semicolon().nl();
 129                 paramObjects.add(p);
 130             }
 131         }
 132         nl();
 133         block(block);
 134         colon().nl();
 135         ops.forEach(op -> {
 136             if (op instanceof JavaOp.InvokeOp invoke && !OpTk.isIfaceBufferMethod(MethodHandles.lookup(),invoke)) {  // We should pass lookup down
 137                 ptxIndent().convert(op).nl();
 138             } else {
 139                 ptxIndent().convert(op).semicolon().nl();
 140             }
 141         });
 142     }
 143 
 144     public void ptxRegisterDecl() {
 145         for (PTXRegister.Type t : ordinalMap.keySet()) {
 146             ptxIndent().reg().space();
 147             if (t.equals(PTXRegister.Type.U32)) {
 148                 b32();
 149             } else if (t.equals(PTXRegister.Type.U64)) {
 150                 b64();
 151             } else {
 152                 dot().regType(t);
 153             }
 154             ptxIndent().regTypePrefix(t).oabrace().intVal(ordinalMap.get(t)).cabrace().semicolon().nl();
 155         }
 156         nl();
 157     }
 158 
 159     public void functionPrologue() {
 160         obrace().nl();
 161     }
 162 
 163     public void functionEpilogue() {
 164         cbrace();
 165     }
 166 
 167     public static class PTXPtrOp extends HATOp {
 168         public String fieldName;
 169         public static final String NAME = "ptxPtr";
 170         final TypeElement resultType;
 171         public BoundSchema<?> boundSchema;
 172 
 173         PTXPtrOp(TypeElement resultType, String fieldName, List<Value> operands, BoundSchema<?> boundSchema) {
 174             super(operands);
 175             this.resultType = resultType;
 176             this.fieldName = fieldName;
 177             this.boundSchema = boundSchema;
 178         }
 179 
 180         PTXPtrOp(PTXPtrOp that, CopyContext cc) {
 181             super(that, cc);
 182             this.resultType = that.resultType;
 183             this.fieldName = that.fieldName;
 184             this.boundSchema = that.boundSchema;
 185         }
 186 
 187         @Override
 188         public PTXPtrOp transform(CopyContext cc, OpTransformer ot) {
 189             return new PTXPtrOp(this, cc);
 190         }
 191 
 192         @Override
 193         public TypeElement resultType() {
 194             return resultType;
 195         }
 196 
 197         @Override
 198         public String externalizeOpName() {
 199             return NAME;
 200         }
 201     }
 202 
 203 
 204     public PTXHATKernelBuilder convert(Op op) {
 205         switch (op) {
 206             case JavaOp.FieldAccessOp.FieldLoadOp $ -> fieldLoad($);
 207             case JavaOp.FieldAccessOp.FieldStoreOp $ -> fieldStore($);
 208             case JavaOp.BinaryOp $ -> binaryOperation($);
 209             case JavaOp.BinaryTestOp $ -> binaryTest($);
 210             case JavaOp.ConvOp $ -> conv($);
 211             case CoreOp.ConstantOp $ -> constant($);
 212             case CoreOp.YieldOp $ -> javaYield($);
 213             case JavaOp.InvokeOp $ -> methodCall($);
 214             case CoreOp.VarOp $ when OpTk.paramVar($) != null -> varFuncDeclaration($);
 215             case CoreOp.VarOp $ -> varDeclaration($);
 216             case CoreOp.ReturnOp $ -> ret($);
 217             case JavaOp.BreakOp $ -> javaBreak($);
 218             default -> { // Why are  these switch ops not just inlined above?
 219                 switch (op){
 220                     case CoreOp.BranchOp $ -> branch($);
 221                     case CoreOp.ConditionalBranchOp $ -> condBranch($);
 222                     case JavaOp.NegOp $ -> neg($);
 223                     case PTXPtrOp $ -> ptxPtr($);
 224                     default -> throw new IllegalStateException("op translation doesn't exist");
 225                 }
 226             }
 227         }
 228         return this;
 229     }
 230 
 231     public void ptxPtr(PTXPtrOp op) {
 232         PTXRegister source;
 233         int offset = (int) op.boundSchema.groupLayout().byteOffset(MemoryLayout.PathElement.groupElement(op.fieldName));
 234 
 235         if (op.fieldName.equals("array")) {
 236             source = new PTXRegister(incrOrdinal(addressType()), addressType());
 237             add().s64().space().regName(source).commaSpace().reg(op.operands().get(0)).commaSpace().reg(op.operands().get(1)).ptxNl();
 238         } else {
 239             source = getReg(op.operands().getFirst());
 240         }
 241 
 242         if (op.resultType.toString().equals("void")) {
 243             st().global().dot().regType(op.operands().getLast()).space().address(source.name(), offset).commaSpace().reg(op.operands().getLast());
 244         } else {
 245             ld().global().resultType(op.resultType(), true).space().reg(op.result(), getResultType(op.resultType())).commaSpace().address(source.name(), offset);
 246         }
 247     }
 248 
 249     public void fieldLoad(JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
 250         if (OpTk.fieldName(fieldLoadOp).equals(Field.KC_X.toString())) {
 251             if (!fieldToRegMap.containsKey(Field.KC_X)) {
 252                 loadKcX(fieldLoadOp.result());
 253             } else {
 254                 mov().u32().space().resultReg(fieldLoadOp, PTXRegister.Type.U32).commaSpace().fieldReg(Field.KC_X);
 255             }
 256         } else if (OpTk.fieldName(fieldLoadOp).equals(Field.KC_MAXX.toString())) {
 257             if (!fieldToRegMap.containsKey(Field.KC_X)) {
 258                 loadKcX(fieldLoadOp.operands().getFirst());
 259             }
 260             ld().global().u32().space().fieldReg(Field.KC_MAXX, fieldLoadOp.result()).commaSpace()
 261                     .address(fieldToRegMap.get(Field.KC_ADDR).name(), 4);
 262         } else {
 263             ld().global().u32().space().resultReg(fieldLoadOp, PTXRegister.Type.U64).commaSpace().reg(fieldLoadOp.operands().getFirst());
 264         }
 265     }
 266 
    /**
     * Computes the kernel-context {@code x} value and writes it back into the context.
     *
     * <p>Converts the {@code kc} parameter to a global address (cvta.to.global), copies
     * {@code %ntid.x}, {@code %ctaid.x} and {@code %tid.x} into registers, computes
     * {@code x = ctaid.x * ntid.x + tid.x} (mad.lo.s32), and stores {@code x} at offset 0 of
     * the context. The final store is left without {@code ;}; the caller must terminate it.
     * Statement order matters: it determines the register ordinals handed out.
     *
     * @param value value whose register also becomes the KC_X register (see fieldReg(Field, Value))
     */
    public void loadKcX(Value value) {
        // NOTE(review): throws if no parameter is named "kc" (indexOf returns -1).
        cvta().to().global().size().space().fieldReg(Field.KC_ADDR).commaSpace()
                .reg(paramObjects.get(paramNames.indexOf(Field.KC_ADDR.toString())), addressType()).ptxNl();
        mov().u32().space().fieldReg(Field.NTID_X).commaSpace().percent().regName(Field.NTID_X.toString()).ptxNl();
        mov().u32().space().fieldReg(Field.CTAID_X).commaSpace().percent().regName(Field.CTAID_X.toString()).ptxNl();
        mov().u32().space().fieldReg(Field.TID_X).commaSpace().percent().regName(Field.TID_X.toString()).ptxNl();
        mad().lo().s32().space().fieldReg(Field.KC_X, value).commaSpace().fieldReg(Field.CTAID_X)
                .commaSpace().fieldReg(Field.NTID_X).commaSpace().fieldReg(Field.TID_X).ptxNl();
        st().global().u32().space().address(fieldToRegMap.get(Field.KC_ADDR).name()).commaSpace().fieldReg(Field.KC_X);
    }
 277 
    /**
     * Emits a 64-bit global store for a field store op.
     */
    public void fieldStore(JavaOp.FieldAccessOp.FieldStoreOp op) {
        // TODO: fix
        // NOTE(review): allocates a register for the store op's own result and prints it as the
        // (unbracketed) destination, always with .u64; this does not look like a valid field store.
        st().global().u64().space().resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operands().getFirst());
    }
 282 
 283     PTXHATKernelBuilder symbol(Op op) {
 284         return switch (op) {
 285             case JavaOp.ModOp _ -> rem();
 286             case JavaOp.MulOp _ -> mul();
 287             case JavaOp.DivOp _ -> div();
 288             case JavaOp.AddOp _ -> add();
 289             case JavaOp.SubOp _ -> sub();
 290             case JavaOp.LtOp _ -> lt();
 291             case JavaOp.GtOp _ -> gt();
 292             case JavaOp.LeOp _ -> le();
 293             case JavaOp.GeOp _ -> ge();
 294             case JavaOp.NeqOp _ -> ne();
 295             case JavaOp.EqOp _ -> eq();
 296             case JavaOp.OrOp _ -> or();
 297             case JavaOp.AndOp _ -> and();
 298             case JavaOp.XorOp _ -> xor();
 299             case JavaOp.LshlOp _ -> shl();
 300             case JavaOp.AshrOp _, JavaOp.LshrOp _ -> shr();
 301             default -> throw new IllegalStateException("Unexpected value");
 302         };
 303     }
 304 
 305     public void binaryOperation(JavaOp.BinaryOp op) {
 306         symbol(op);
 307         if (getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
 308                 && (op instanceof JavaOp.DivOp || op instanceof JavaOp.MulOp)) {
 309             rn();
 310         } else if (!getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
 311                 && op instanceof JavaOp.MulOp) {
 312             lo();
 313         }
 314         resultType(op.resultType(), true).space();
 315         resultReg(op, getResultType(op.resultType()));
 316         commaSpace();
 317         reg(op.operands().getFirst());
 318         commaSpace();
 319         reg(op.operands().get(1));
 320     }
 321 
 322     public void binaryTest(JavaOp.BinaryTestOp op) {
 323         setp().dot();
 324         symbol(op).resultType(op.operands().getFirst().type(), true).space();
 325         resultReg(op, PTXRegister.Type.PREDICATE);
 326         commaSpace();
 327         reg(op.operands().getFirst());
 328         commaSpace();
 329         reg(op.operands().get(1));
 330     }
 331 
 332     public void conv(JavaOp.ConvOp op) {
 333         if (op.resultType().equals(JavaType.LONG)) {
 334             if (isIndex(op)) {
 335                 mul().wide().s32().space().resultReg(op, PTXRegister.Type.U64).commaSpace()
 336                         .reg(op.operands().getFirst()).commaSpace().intVal(4);
 337             } else {
 338                 cvt().u64().dot().regType(op.operands().getFirst()).space()
 339                         .resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operands().getFirst()).ptxNl();
 340             }
 341         } else if (op.resultType().equals(JavaType.FLOAT)) {
 342             cvt().rn().f32().dot().regType(op.operands().getFirst()).space()
 343                     .resultReg(op, PTXRegister.Type.F32).commaSpace().reg(op.operands().getFirst());
 344         } else if (op.resultType().equals(JavaType.DOUBLE)) {
 345             cvt();
 346             if (op.operands().getFirst().type().equals(JavaType.INT)) {
 347                 rn();
 348             }
 349             f64().dot().regType(op.operands().getFirst()).space()
 350                     .resultReg(op, PTXRegister.Type.F64).commaSpace().reg(op.operands().getFirst());
 351         } else if (op.resultType().equals(JavaType.INT)) {
 352             cvt();
 353             if (op.operands().getFirst().type().equals(JavaType.DOUBLE) || op.operands().getFirst().type().equals(JavaType.FLOAT)) {
 354                 rzi();
 355             } else {
 356                 rn();
 357             }
 358             s32().dot().regType(op.operands().getFirst()).space()
 359                     .resultReg(op, PTXRegister.Type.S32).commaSpace().reg(op.operands().getFirst());
 360         } else {
 361             cvt().rn().s32().dot().regType(op.operands().getFirst()).space()
 362                     .resultReg(op, PTXRegister.Type.S32).commaSpace().reg(op.operands().getFirst());
 363         }
 364     }
 365 
 366 
 367 
 368 
 369 
    /**
     * A PTX virtual register: a name plus its {@link Type}.
     * The default name is the type's register prefix followed by an ordinal (e.g. {@code %r3});
     * it can be replaced via {@link #name(String)} (functionHeader uses this for {@code %retReg}).
     */
    public static class PTXRegister {
        private String name;
        private final Type type;

        /**
         * PTX types with bit width, category, PTX type name and the prefix used for register names.
         *
         * NOTE(review): several types share a prefix (%s: S8/S16/S32; %r: U8/U16/U32;
         * %f: F16/F16X2/F32; %b: B8/B16/B32/B128). If two types with the same prefix are both
         * allocated, ptxRegisterDecl would declare overlapping names - confirm that cannot happen.
         * NOTE(review): F16X2 has size 16 and name "f16"; a packed f16x2 value is 32 bits - confirm.
         * Constant order is kept as-is in case anything depends on values()/ordinal().
         */
        public enum Type {
            S8 (8, BasicType.SIGNED, "s8", "%s"),
            S16 (16, BasicType.SIGNED, "s16", "%s"),
            S32 (32, BasicType.SIGNED, "s32", "%s"),
            S64 (64, BasicType.SIGNED, "s64", "%sd"),
            U8 (8, BasicType.UNSIGNED, "u8", "%r"),
            U16 (16, BasicType.UNSIGNED, "u16", "%r"),
            U32 (32, BasicType.UNSIGNED, "u32", "%r"),
            U64 (64, BasicType.UNSIGNED, "u64", "%rd"),
            F16 (16, BasicType.FLOATING, "f16", "%f"),
            F16X2 (16, BasicType.FLOATING, "f16", "%f"),
            F32 (32, BasicType.FLOATING, "f32", "%f"),
            F64 (64, BasicType.FLOATING, "f64", "%fd"),
            B8 (8, BasicType.BIT, "b8", "%b"),
            B16 (16, BasicType.BIT, "b16", "%b"),
            B32 (32, BasicType.BIT, "b32", "%b"),
            B64 (64, BasicType.BIT, "b64", "%bd"),
            B128 (128, BasicType.BIT, "b128", "%b"),
            PREDICATE (1, BasicType.PREDICATE, "pred", "%p");

            /** Coarse category of a type; used e.g. to choose floating vs integer instruction forms. */
            public enum BasicType {
                SIGNED,
                UNSIGNED,
                FLOATING,
                BIT,
                PREDICATE
            }

            // Width in bits.
            private final int size;
            private final BasicType basicType;
            // PTX type name without the leading dot (e.g. "u32").
            private final String name;
            // Prefix for register names of this type (e.g. "%r").
            private final String regPrefix;

            Type(int size, BasicType type, String name, String regPrefix) {
                this.size = size;
                this.basicType = type;
                this.name = name;
                this.regPrefix = regPrefix;
            }

            /** Returns the width in bits. */
            public int getSize() {
                return this.size;
            }

            public BasicType getBasicType() {
                return this.basicType;
            }

            /** Returns the PTX type name without the leading dot. */
            public String getName() {
                return this.name;
            }

            public String getRegPrefix() {
                return this.regPrefix;
            }
        }

        /** Creates a register named {@code type.regPrefix + num}. */
        public PTXRegister(int num, Type type) {
            this.type = type;
            this.name = type.regPrefix + num;
        }

        public String name() {
            return this.name;
        }

        /** Overrides the generated name (e.g. for the return-value parameter). */
        public void name(String name) {
            this.name = name;
        }

        public Type type() {
            return this.type;
        }
    }
 448 
 449 
 450     private boolean isIndex(JavaOp.ConvOp op) {
 451         for (Op.Result r : op.result().uses()) {
 452             if (r.op() instanceof PTXPtrOp) return true;
 453         }
 454         return false;
 455     }
 456 
 457     public void constant(CoreOp.ConstantOp op) {
 458         mov().resultType(op.resultType(), false).space().resultReg(op, getResultType(op.resultType())).commaSpace();
 459         if (op.resultType().toString().equals("float")) {
 460             if (op.value().toString().equals("0.0")) {
 461                 floatVal("00000000");
 462             } else {
 463                 floatVal(Integer.toHexString(Float.floatToIntBits(Float.parseFloat(op.value().toString()))).toUpperCase());
 464             }
 465         } else {
 466             constant(op.value().toString());
 467         }
 468     }
 469 
 470     public void javaYield(CoreOp.YieldOp op) {
 471         exit();
 472     }
 473 
 474     // S32Array and S32Array2D functions can be deleted after schema is done
 475     public void methodCall(JavaOp.InvokeOp op) {
 476         switch (op.invokeDescriptor().toString()) {
 477             // S32Array functions
 478             case "hat.buffer.S32Array::array(long)int" -> {
 479                 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
 480                 add().s64().space().regName(temp).commaSpace().reg(op.operands().getFirst()).commaSpace().reg(op.operands().get(1)).ptxNl();
 481                 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(temp.name(), 4);
 482             }
 483             case "hat.buffer.S32Array::array(long, int)void" -> {
 484                 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
 485                 add().s64().space().regName(temp).commaSpace().reg(op.operands().getFirst()).commaSpace().reg(op.operands().get(1)).ptxNl();
 486                 st().global().u32().space().address(temp.name(), 4).commaSpace().reg(op.operands().get(2));
 487             }
 488             case "hat.buffer.S32Array::length()int" -> {
 489                 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(getReg(op.operands().getFirst()).name());
 490             }
 491             // S32Array2D functions
 492             case "hat.buffer.S32Array2D::array(long, int)void" -> {
 493                 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
 494                 add().s64().space().regName(temp).commaSpace().reg(op.operands().getFirst()).commaSpace().reg(op.operands().get(1)).ptxNl();
 495                 st().global().u32().space().address(temp.name(), 8).commaSpace().reg(op.operands().get(2));
 496             }
 497             case "hat.buffer.S32Array2D::width()int" -> {
 498                 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(getReg(op.operands().getFirst()).name());
 499             }
 500             case "hat.buffer.S32Array2D::height()int" -> {
 501                 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(getReg(op.operands().getFirst()).name(), 4);
 502             }
 503             // Java Math function
 504             case "java.lang.Math::sqrt(double)double" -> {
 505                 sqrt().rn().f64().space().resultReg(op, PTXRegister.Type.F64).commaSpace().reg(op.operands().getFirst()).semicolon();
 506             }
 507             default -> {
 508                 obrace().nl().ptxIndent();
 509                 for (int i = 0; i < op.operands().size(); i++) {
 510                     dot().param().space().paramType(op.operands().get(i).type()).space().param().intVal(i).ptxNl();
 511                     st().dot().param().paramType(op.operands().get(i).type()).space().osbrace().param().intVal(i).csbrace().commaSpace().reg(op.operands().get(i)).ptxNl();
 512                 }
 513                 dot().param().space().paramType(op.resultType()).space().retVal().ptxNl();
 514                 call().uni().space().oparen().retVal().cparen().commaSpace().identifier(OpTk.methodOrThrow(MethodHandles.lookup(),op).getName()).commaSpace();
 515                 final int[] counter = {0};
 516                 paren(_ ->
 517                         separated(op.operands(),(_)->commaSpace(),
 518                         //commaSeparated(op.operands(),
 519                                 _ -> param().intVal(counter[0]++))).ptxNl();
 520                 ld().dot().param().paramType(op.resultType()).space().resultReg(op, getResultType(op.resultType())).commaSpace().osbrace().retVal().csbrace();
 521                 ptxNl().cbrace();
 522             }
 523         }
 524     }
 525 
 526     public void varDeclaration(CoreOp.VarOp op) {
 527         ld().dot().param().resultType(op.resultType(), false).space().resultReg(op, addressType()).commaSpace().reg(op.operands().getFirst());
 528     }
 529 
 530     public void varFuncDeclaration(CoreOp.VarOp op) {
 531         ld().dot().param().resultType(op.resultType(), false).space().resultReg(op, addressType()).commaSpace().reg(op.operands().getFirst());
 532     }
 533 
 534     public void ret(CoreOp.ReturnOp op) {
 535         if (!op.operands().isEmpty()) {
 536             st().dot().param();
 537             if (returnReg.type().equals(PTXRegister.Type.U32)) {
 538                 b32();
 539             } else if (returnReg.type().equals(PTXRegister.Type.U64)) {
 540                 b64();
 541             } else {
 542                 dot().regType(returnReg.type());
 543             }
 544             space().osbrace().regName(returnReg).csbrace().commaSpace().reg(op.operands().getFirst()).ptxNl();
 545         }
 546         ret();
 547     }
 548 
 549     public void javaBreak(JavaOp.BreakOp op) {
 550         brkpt();
 551     }
 552 
 553     public void branch(CoreOp.BranchOp op) {
 554         loadBlockParams(op.successors().getFirst());
 555         bra().space().block(op.successors().getFirst().targetBlock());
 556     }
 557 
 558     public void condBranch(CoreOp.ConditionalBranchOp op) {
 559         loadBlockParams(op.successors().getFirst());
 560         loadBlockParams(op.successors().getLast());
 561         at().reg(op.operands().getFirst()).space()
 562                 .bra().space().block(op.successors().getFirst().targetBlock()).ptxNl();
 563         bra().space().block(op.successors().getLast().targetBlock());
 564     }
 565 
 566     public void neg(JavaOp.NegOp op) {
 567         neg().resultType(op.resultType(), true).space().reg(op.result(), getResultType(op.resultType())).commaSpace().reg(op.operands().getFirst());
 568     }
 569 
 570     /*
 571      * Helper functions for printing blocks and variables
 572      */
 573 
 574     public void loadBlockParams(Block.Reference block) {
 575         for (int i = 0; i < block.arguments().size(); i++) {
 576             Block.Parameter p = block.targetBlock().parameters().get(i);
 577             mov().resultType(p.type(), false).space().reg(p, getResultType(p.type()))
 578                     .commaSpace().reg(block.arguments().get(i)).ptxNl();
 579         }
 580     }
 581 
 582     public PTXHATKernelBuilder block(Block block) {
 583         return typeName("block_").intVal(block.index());
 584     }
 585 
 586     public PTXHATKernelBuilder fieldReg(Field ref) {
 587         if (fieldToRegMap.containsKey(ref)) {
 588             return regName(fieldToRegMap.get(ref));
 589         }
 590         if (ref.isDestination()) {
 591             fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(addressType()), addressType()));
 592         } else {
 593             fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
 594         }
 595         return regName(fieldToRegMap.get(ref));
 596     }
 597 
    /**
     * Emits the register holding {@code ref}; on first use, makes {@code ref} share a register
     * with {@code value}.
     *
     * <p>getOrdinal only peeks at the next ordinal without consuming it. When {@code value} has no
     * register yet, reg(value, type) then allocates one through incrOrdinal with that same ordinal,
     * so the field entry and the value end up with the same register name. The statement order
     * here matters for that.
     * NOTE(review): if {@code value} is already mapped, the field gets the peeked (unallocated)
     * name while the emitted text is value's existing register - the two diverge.
     */
    public PTXHATKernelBuilder fieldReg(Field ref, Value value) {
        if (fieldToRegMap.containsKey(ref)) {
            return regName(fieldToRegMap.get(ref));
        }
        if (ref.isDestination()) {
            fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(addressType()), addressType()));
            return reg(value, addressType());
        } else {
            fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
            return reg(value, PTXRegister.Type.U32);
        }
    }
 610 
 611     public Field getFieldObj(String fieldName) {
 612         for (Field f : fieldToRegMap.keySet()) {
 613             if (f.toString().equals(fieldName)) return f;
 614         }
 615         throw new IllegalStateException("no existing field");
 616     }
 617 
 618     public PTXHATKernelBuilder resultReg(Op op, PTXRegister.Type type) {
 619         return identifier(addReg(op.result(), type));
 620     }
 621 
 622     public PTXHATKernelBuilder reg(Value val, PTXRegister.Type type) {
 623         if (varToRegMap.containsKey(val)) {
 624             return regName(getReg(val));
 625         } else {
 626             return identifier(addReg(val, type));
 627         }
 628     }
 629 
 630     public PTXHATKernelBuilder reg(Value val) {
 631         return regName(getReg(val));
 632     }
 633 
 634     public PTXRegister getReg(Value val) {
 635         if (varToRegMap.get(val) == null && val instanceof Op.Result result && result.op() instanceof JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
 636             return fieldToRegMap.get(getFieldObj(fieldLoadOp.fieldDescriptor().name()));
 637         }
 638         if (varToRegMap.containsKey(val)) {
 639             return varToRegMap.get(val);
 640         } else {
 641             throw new IllegalStateException("var to reg mapping doesn't exist");
 642         }
 643     }
 644 
 645     public String addReg(Value val, PTXRegister.Type type) {
 646         if (varToRegMap.containsKey(val)) {
 647             return varToRegMap.get(val).name();
 648         }
 649         varToRegMap.put(val, new PTXRegister(incrOrdinal(type), type));
 650         return varToRegMap.get(val).name();
 651     }
 652 
 653     public Integer getOrdinal(PTXRegister.Type type) {
 654         ordinalMap.putIfAbsent(type, 1);
 655         return ordinalMap.get(type);
 656     }
 657 
 658     public Integer incrOrdinal(PTXRegister.Type type) {
 659         ordinalMap.putIfAbsent(type, 1);
 660         int out = ordinalMap.get(type);
 661         ordinalMap.put(type, out + 1);
 662         return out;
 663     }
 664 
 665     public PTXHATKernelBuilder size() {
 666         return (addressSize == 32) ? u32() : u64();
 667     }
 668 
 669     public PTXRegister.Type addressType() {
 670         return (addressSize == 32) ? PTXRegister.Type.U32 : PTXRegister.Type.U64;
 671     }
 672 
 673     public PTXHATKernelBuilder resultType(TypeElement type, boolean signedResult) {
 674         PTXRegister.Type res = getResultType(type);
 675         if (signedResult && (res == PTXRegister.Type.U32)) return s32();
 676         return dot().typeName(getResultType(type).getName());
 677     }
 678 
 679     public PTXHATKernelBuilder paramType(TypeElement type) {
 680         PTXRegister.Type res = getResultType(type);
 681         if (res == PTXRegister.Type.U32) return b32();
 682         if (res == PTXRegister.Type.U64) return b64();
 683         return dot().typeName(getResultType(type).getName());
 684     }
 685 
 686     public PTXRegister.Type getResultType(TypeElement type) {
 687         switch (type.toString()) {
 688             case "float" -> {
 689                 return PTXRegister.Type.F32;
 690             }
 691             case "double" -> {
 692                 return PTXRegister.Type.F64;
 693             }
 694             case "int" -> {
 695                 return PTXRegister.Type.U32;
 696             }
 697             case "boolean" -> {
 698                 return PTXRegister.Type.PREDICATE;
 699             }
 700             default -> {
 701                 return PTXRegister.Type.U64;
 702             }
 703         }
 704     }
 705 
 706     /*
 707      * Basic CodeBuilder functions
 708      */
 709 
    // Used for the parameter list: prints the items separated by a comma followed by a newline.
    // Disabled (grf): it is unclear why this override existed, since it repeated the same code
    // as the inherited implementation.
 713    /* @Override
 714     public <I> PTXHATKernelBuilder commaNlSeparated(Iterable<I> iterable, Consumer<I> c) {
 715         StreamCounter.of(iterable, (counter, t) -> {
 716             if (counter.isNotFirst()) {
 717                 comma().nl();
 718             }
 719             c.accept(t);
 720         });
 721         return self();
 722     }
 723 */
 724     public PTXHATKernelBuilder address(String address) {
 725         return osbrace().constant(address).csbrace();
 726     }
 727 
 728     public PTXHATKernelBuilder address(String address, int offset) {
 729         osbrace().constant(address);
 730         if (offset == 0) {
 731             return csbrace();
 732         } else if (offset > 0) {
 733             plus();
 734         }
 735         return intVal(offset).csbrace();
 736     }
 737 
 738     public PTXHATKernelBuilder ptxNl() {
 739         return semicolon().nl().ptxIndent();
 740     }
 741 
 742 
 743     public PTXHATKernelBuilder param() {
 744         return keyword("param");
 745     }
 746 
 747     public PTXHATKernelBuilder global() {
 748         return dot().keyword("global");
 749     }
 750 
 751     public PTXHATKernelBuilder rn() {
 752         return dot().keyword("rn");
 753     }
 754 
 755     public PTXHATKernelBuilder rm() {
 756         return dot().keyword("rm");
 757     }
 758 
 759     public PTXHATKernelBuilder rzi() {
 760         return dot().keyword("rzi");
 761     }
 762 
 763     public PTXHATKernelBuilder to() {
 764         return dot().keyword("to");
 765     }
 766 
 767     public PTXHATKernelBuilder lo() {
 768         return dot().keyword("lo");
 769     }
 770 
 771     public PTXHATKernelBuilder wide() {
 772         return dot().keyword("wide");
 773     }
 774 
 775     public PTXHATKernelBuilder uni() {
 776         return dot().keyword("uni");
 777     }
 778 
 779     public PTXHATKernelBuilder sat() {
 780         return dot().keyword("sat");
 781     }
 782 
 783     public PTXHATKernelBuilder ftz() {
 784         return dot().keyword("ftz");
 785     }
 786 
 787     public PTXHATKernelBuilder approx() {
 788         return dot().keyword("approx");
 789     }
 790 
 791     public PTXHATKernelBuilder mov() {
 792         return keyword("mov");
 793     }
 794 
 795     public PTXHATKernelBuilder setp() {
 796         return keyword("setp");
 797     }
 798 
 799     public PTXHATKernelBuilder selp() {
 800         return keyword("selp");
 801     }
 802 
 803     public PTXHATKernelBuilder ld() {
 804         return keyword("ld");
 805     }
 806 
 807     public PTXHATKernelBuilder st() {
 808         return keyword("st");
 809     }
 810 
 811     public PTXHATKernelBuilder cvt() {
 812         return keyword("cvt");
 813     }
 814 
 815     public PTXHATKernelBuilder bra() {
 816         return keyword("bra");
 817     }
 818 
 819     public PTXHATKernelBuilder ret() {
 820         return keyword("ret");
 821     }
 822 
 823     public PTXHATKernelBuilder rem() {
 824         return keyword("rem");
 825     }
 826 
 827     public PTXHATKernelBuilder mul() {
 828         return keyword("mul");
 829     }
 830 
 831     public PTXHATKernelBuilder div() {
 832         return keyword("div");
 833     }
 834 
 835     public PTXHATKernelBuilder rcp() {
 836         return keyword("rcp");
 837     }
 838 
 839     public PTXHATKernelBuilder add() {
 840         return keyword("add");
 841     }
 842 
 843     public PTXHATKernelBuilder sub() {
 844         return keyword("sub");
 845     }
 846 
 847     public PTXHATKernelBuilder lt() {
 848         return keyword("lt");
 849     }
 850 
 851     public PTXHATKernelBuilder gt() {
 852         return keyword("gt");
 853     }
 854 
 855     public PTXHATKernelBuilder le() {
 856         return keyword("le");
 857     }
 858 
 859     public PTXHATKernelBuilder ge() {
 860         return keyword("ge");
 861     }
 862 
 863     public PTXHATKernelBuilder geu() {
 864         return keyword("geu");
 865     }
 866 
 867     public PTXHATKernelBuilder ne() {
 868         return keyword("ne");
 869     }
 870 
 871     public PTXHATKernelBuilder eq() {
 872         return keyword("eq");
 873     }
 874 
 875     public PTXHATKernelBuilder xor() {
 876         return keyword("xor");
 877     }
 878 
 879     public PTXHATKernelBuilder or() {
 880         return keyword("or");
 881     }
 882 
 883     public PTXHATKernelBuilder and() {
 884         return keyword("and");
 885     }
 886 
 887     public PTXHATKernelBuilder cvta() {
 888         return keyword("cvta");
 889     }
 890 
 891     public PTXHATKernelBuilder mad() {
 892         return keyword("mad");
 893     }
 894 
 895     public PTXHATKernelBuilder fma() {
 896         return keyword("fma");
 897     }
 898 
 899     public PTXHATKernelBuilder sqrt() {
 900         return keyword("sqrt");
 901     }
 902 
 903     public PTXHATKernelBuilder abs() {
 904         return keyword("abs");
 905     }
 906 
 907     public PTXHATKernelBuilder ex2() {
 908         return keyword("ex2");
 909     }
 910 
 911     public PTXHATKernelBuilder shl() {
 912         return keyword("shl");
 913     }
 914 
 915     public PTXHATKernelBuilder shr() {
 916         return keyword("shr");
 917     }
 918 
 919     public PTXHATKernelBuilder neg() {
 920         return keyword("neg");
 921     }
 922 
 923     public PTXHATKernelBuilder call() {
 924         return keyword("call");
 925     }
 926 
 927     public PTXHATKernelBuilder exit() {
 928         return keyword("exit");
 929     }
 930 
 931     public PTXHATKernelBuilder brkpt() {
 932         return keyword("brkpt");
 933     }
 934 
 935     public PTXHATKernelBuilder ptxIndent() {
 936         return space().space().space().space();
 937     }
 938 
 939     public PTXHATKernelBuilder u32() {
 940         return dot().typeName(PTXRegister.Type.U32.getName());
 941     }
 942 
 943     public PTXHATKernelBuilder s32() {
 944         return dot().typeName(PTXRegister.Type.S32.getName());
 945     }
 946 
 947     public PTXHATKernelBuilder f32() {
 948         return dot().typeName(PTXRegister.Type.F32.getName());
 949     }
 950 
 951     public PTXHATKernelBuilder b32() {
 952         return dot().typeName(PTXRegister.Type.B32.getName());
 953     }
 954 
 955     public PTXHATKernelBuilder u64() {
 956         return dot().typeName(PTXRegister.Type.U64.getName());
 957     }
 958 
 959     public PTXHATKernelBuilder s64() {
 960         return dot().typeName(PTXRegister.Type.S64.getName());
 961     }
 962 
 963     public PTXHATKernelBuilder f64() {
 964         return dot().typeName(PTXRegister.Type.F64.getName());
 965     }
 966 
 967     public PTXHATKernelBuilder b64() {
 968         return dot().typeName(PTXRegister.Type.B64.getName());
 969     }
 970 
 971     public PTXHATKernelBuilder version() {
 972         return dot().keyword("version");
 973     }
 974 
 975     public PTXHATKernelBuilder target() {
 976         return dot().keyword("target");
 977     }
 978 
 979     public PTXHATKernelBuilder addressSize() {
 980         return dot().keyword("address_size");
 981     }
 982 
 983     public PTXHATKernelBuilder major(int major) {
 984         return intVal(major);
 985     }
 986 
 987     public PTXHATKernelBuilder minor(int minor) {
 988         return intVal(minor);
 989     }
 990 
 991     public PTXHATKernelBuilder target(String target) {
 992         return keyword(target);
 993     }
 994 
 995     public PTXHATKernelBuilder size(int addressSize) {
 996         return intVal(addressSize);
 997     }
 998 
 999     public PTXHATKernelBuilder funcName(String funcName) {
1000         return identifier(funcName);
1001     }
1002 
1003     public PTXHATKernelBuilder visible() {
1004         return dot().keyword("visible");
1005     }
1006 
1007     public PTXHATKernelBuilder entry() {
1008         return dot().keyword("entry");
1009     }
1010 
1011     public PTXHATKernelBuilder func() {
1012         return dot().keyword("func");
1013     }
1014 
1015     public PTXHATKernelBuilder oabrace() {
1016         return symbol("<");
1017     }
1018 
1019     public PTXHATKernelBuilder cabrace() {
1020         return symbol(">");
1021     }
1022 
1023     public PTXHATKernelBuilder regName(PTXRegister reg) {
1024         return identifier(reg.name());
1025     }
1026 
1027     public PTXHATKernelBuilder regName(String regName) {
1028         return identifier(regName);
1029     }
1030 
1031     public PTXHATKernelBuilder regType(Value val) {
1032         return keyword(getReg(val).type().getName());
1033     }
1034 
1035     public PTXHATKernelBuilder regType(PTXRegister.Type t) {
1036         return keyword(t.getName());
1037     }
1038 
1039     public PTXHATKernelBuilder regTypePrefix(PTXRegister.Type t) {
1040         return keyword(t.getRegPrefix());
1041     }
1042 
1043     public PTXHATKernelBuilder reg() {
1044         return dot().keyword("reg");
1045     }
1046 
1047     public PTXHATKernelBuilder retVal() {
1048         return keyword("retval");
1049     }
1050 
1051     public PTXHATKernelBuilder intVal(int i) {
1052         return constant(String.valueOf(i));
1053     }
1054 
1055     public PTXHATKernelBuilder floatVal(String s) {
1056         return constant("0f").constant(s);
1057     }
1058 
1059     public PTXHATKernelBuilder doubleVal(String s) {
1060         return constant("0d").constant(s);
1061     }
1062 }