1 /*
   2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 package hat.backend.ffi;
  26 
  27 import optkl.FuncOpParams;
  28 import optkl.ParamVar;
  29 import optkl.codebuilders.CodeBuilder;
  30 
  31 import jdk.incubator.code.*;
  32 import jdk.incubator.code.dialect.core.CoreOp;
  33 import jdk.incubator.code.dialect.java.JavaOp;
  34 import jdk.incubator.code.dialect.java.JavaType;
  35 
  36 import java.lang.foreign.MemoryLayout;
  37 import java.lang.invoke.MethodHandles;
  38 import java.util.ArrayList;
  39 import java.util.HashMap;
  40 import java.util.List;
  41 import java.util.Map;
  42 import java.util.stream.Stream;
  43 
  44 import static optkl.OpHelper.FieldAccess.fieldAccess;
  45 import static optkl.OpHelper.Invoke;
  46 
  47 import static optkl.OpHelper.Invoke.invoke;
  48 
  49 
  50 public class PTXHATKernelBuilder extends CodeBuilder<PTXHATKernelBuilder> {
  51 
    // Maps each code-model value to the PTX register allocated for it (see addReg/getReg).
    Map<Value, PTXRegister> varToRegMap;
    // Kernel parameter names in declaration order, recorded by parameters().
    List<String> paramNames;
    // Entry-block parameters, appended by blockBody() as it loads them.
    List<Block.Parameter> paramObjects;
    // Registers allocated for the kernel-context fields and special registers (see fieldReg).
    Map<Field, PTXRegister> fieldToRegMap;

    // Next free register ordinal per register type; also read by ptxRegisterDecl().
    HashMap<PTXRegister.Type, Integer> ordinalMap;

    // Register renamed to "%retReg" for a non-void return value; assigned by functionHeader().
    PTXRegister returnReg;
    // Address width in bits: 32 selects u32 addressing, any other value selects u64.
    private int addressSize;
  61 
  62     public enum Field {
  63         NTID_X ("ntid.x", false),
  64         CTAID_X ("ctaid.x", false),
  65         TID_X ("tid.x", false),
  66         KC_X ("x", false),
  67         KC_ADDR("kc", true),
  68         KC_MAXX ("maxX", false);
  69 
  70         private final String name;
  71         private final boolean destination;
  72 
  73         Field(String name, boolean destination) {
  74             this.name = name;
  75             this.destination = destination;
  76         }
  77         public String toString() {
  78             return this.name;
  79         }
  80         public boolean isDestination() {return this.destination;}
  81     }
  82 
  83     public PTXHATKernelBuilder(int addressSize) {
  84         varToRegMap = new HashMap<>();
  85         paramNames = new ArrayList<>();
  86         fieldToRegMap = new HashMap<>();
  87         paramObjects = new ArrayList<>();
  88         ordinalMap = new HashMap<>();
  89         this.addressSize = addressSize;
  90     }
  91 
    /** Creates a builder that uses a 32-bit address size. */
    public PTXHATKernelBuilder() {
        this(32);
    }
  95 
  96     public void ptxHeader(int major, int minor, String target, int addressSize) {
  97         this.addressSize = addressSize;
  98         version().sp().major(major).dot().minor(minor).nl();
  99         target().sp().target(target).nl();
 100         addressSize().sp().size(addressSize);
 101     }
 102 
 103     public void functionHeader(String funcName, boolean entry, CodeType yieldType) {
 104         if (entry) {
 105             visible().sp().entry().sp();
 106         } else {
 107             func().sp();
 108         }
 109         if (!yieldType.toString().equals("void")) {
 110             returnReg = new PTXRegister(getOrdinal(getResultType(yieldType)), getResultType(yieldType));
 111             returnReg.name("%retReg");
 112             oparen().dot().param().sp().paramType(yieldType);
 113             sp().regName(returnReg).cparen().sp();
 114         }
 115         funcName(funcName);
 116     }
 117 
 118     public PTXHATKernelBuilder parameters(List<FuncOpParams.Info> infoList) {
 119         paren(_ ->
 120                 nl()
 121                         .commaNlSeparated(
 122                         infoList,
 123                         info -> {
 124                             ptxIndent().dot().param().sp().paramType(info.javaType);
 125                             sp().regName(info.varOp.varName());
 126                             paramNames.add(info.varOp.varName());
 127                         }
 128                         ).nl()).nl();
 129         return this;
 130     }
 131 
 132     public void blockBody(MethodHandles.Lookup lookup,Block block, Stream<Op> ops) {
 133         if (block.index() == 0) {
 134             for (Block.Parameter p : block.parameters()) {
 135                 ptxIndent().ld().dot().param();
 136                 resultType(p.type(), false).ptxIndent().sp();
 137                 reg(p, getResultType(p.type())).csp().osbrace().regName(paramNames.get(p.index())).csbrace().semicolon().nl();
 138                 paramObjects.add(p);
 139             }
 140         }
 141         nl();
 142         block(block);
 143         colon().nl();
 144         ops.forEach(op -> {
 145             if (invoke(lookup,op) instanceof Invoke invoke && !invoke.isMappableIface()) {
 146                 ptxIndent().convert(lookup,op).nl();
 147             } else {
 148                 ptxIndent().convert(lookup,op).semicolon().nl();
 149             }
 150         });
 151     }
 152 
 153     public void ptxRegisterDecl() {
 154         for (PTXRegister.Type t : ordinalMap.keySet()) {
 155             ptxIndent().reg().sp();
 156             if (t.equals(PTXRegister.Type.U32)) {
 157                 b32();
 158             } else if (t.equals(PTXRegister.Type.U64)) {
 159                 b64();
 160             } else {
 161                 dot().regType(t);
 162             }
 163             ptxIndent().regTypePrefix(t).oabrace().intVal(ordinalMap.get(t)).cabrace().semicolon().nl();
 164         }
 165         nl();
 166     }
 167 
    /** Opens the function body: emits {@code obrace()} followed by a newline. */
    public void functionPrologue() {
        obrace().nl();
    }
 171 
    /** Closes the function body by emitting {@code cbrace()}. */
    public void functionEpilogue() {
        cbrace();
    }
 175 
 176 
    /**
     * Dispatches a single op to the emitter method that translates it.
     * <p>
     * Cases are tried in order, so the guarded {@code VarOp} case (a var for which
     * {@code ParamVar.of} returns non-null) wins over the plain {@code VarOp} case.
     *
     * @param lookup lookup handed to the helpers that resolve field accesses and invocations
     * @param op     the op to translate
     * @return this builder
     * @throws IllegalStateException if no translation exists for the op
     */
    public PTXHATKernelBuilder convert(MethodHandles.Lookup lookup,Op op) {
        switch (op) {
            case JavaOp.FieldAccessOp.FieldLoadOp $ -> fieldLoad(lookup,$);
            case JavaOp.FieldAccessOp.FieldStoreOp $ -> fieldStore($);
            case JavaOp.BinaryOp $ -> binaryOperation($);
            case JavaOp.CompareOp $ -> compareOperation($);
            case JavaOp.ConvOp $ -> conv($);
            case CoreOp.ConstantOp $ -> constant($);
            case CoreOp.YieldOp $ -> javaYield($);
            case JavaOp.InvokeOp $ -> methodCall(invoke(lookup,$));
            case CoreOp.VarOp $ when ParamVar.of($) != null -> varFuncDeclaration($);
            case CoreOp.VarOp $ -> varDeclaration($);
            case CoreOp.ReturnOp $ -> ret($);
            case JavaOp.BreakOp $ -> javaBreak($);
            // Everything not matched above falls through to a second pattern switch.
            default -> { // Why are  these switch ops not just inlined above?
                switch (op){
                    case CoreOp.BranchOp $ -> branch($);
                    case CoreOp.ConditionalBranchOp $ -> condBranch($);
                    case JavaOp.NegOp $ -> neg($);
                    case PTXPtrOp $ -> ptxPtr($);
                    // No emitter is registered for this op type.
                    default -> throw new IllegalStateException("op translation doesn't exist");
                }
            }
        }
        return this;
    }
 203 
 204     public void ptxPtr(PTXPtrOp op) {
 205         PTXRegister source;
 206         int offset = (int) op.boundSchema.groupLayout().byteOffset(MemoryLayout.PathElement.groupElement(op.fieldName));
 207 
 208         if (op.fieldName.equals("array")) {
 209             source = new PTXRegister(incrOrdinal(addressType()), addressType());
 210             addKeyword().s64().sp().regName(source).csp().reg(op.operands().get(0)).csp().reg(op.operands().get(1)).ptxNl();
 211         } else {
 212             source = getReg(op.operands().getFirst());
 213         }
 214 
 215         if (op.resultType.toString().equals("void")) {
 216             st().global().dot().regType(op.operands().getLast()).sp().address(source.name(), offset).csp().reg(op.operands().getLast());
 217         } else {
 218             ld().global().resultType(op.resultType(), true).sp().reg(op.result(), getResultType(op.resultType())).csp().address(source.name(), offset);
 219         }
 220     }
 221 
 222     public void fieldLoad(MethodHandles.Lookup lookup,JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
 223 
 224         var fieldAccess = fieldAccess(lookup,fieldLoadOp);
 225         if (fieldAccess.named(Field.KC_X.toString())) {
 226             if (!fieldToRegMap.containsKey(Field.KC_X)) {
 227                 loadKcX(fieldLoadOp.result());
 228             } else {
 229                 mov().u32().sp().resultReg(fieldLoadOp, PTXRegister.Type.U32).csp().fieldReg(Field.KC_X);
 230             }
 231         } else if (fieldAccess.named(Field.KC_MAXX.toString())) {
 232             if (!fieldToRegMap.containsKey(Field.KC_X)) {
 233                 loadKcX(fieldLoadOp.operands().getFirst());
 234             }
 235             ld().global().u32().sp().fieldReg(Field.KC_MAXX, fieldLoadOp.result()).csp()
 236                     .address(fieldToRegMap.get(Field.KC_ADDR).name(), 4);
 237         } else {
 238             ld().global().u32().sp().resultReg(fieldLoadOp, PTXRegister.Type.U64).csp().reg(fieldLoadOp.operands().getFirst());
 239         }
 240     }
 241 
 242     public void loadKcX(Value value) {
 243         cvta().to().global().size().sp().fieldReg(Field.KC_ADDR).csp()
 244                 .reg(paramObjects.get(paramNames.indexOf(Field.KC_ADDR.toString())), addressType()).ptxNl();
 245         mov().u32().sp().fieldReg(Field.NTID_X).csp().percent().regName(Field.NTID_X.toString()).ptxNl();
 246         mov().u32().sp().fieldReg(Field.CTAID_X).csp().percent().regName(Field.CTAID_X.toString()).ptxNl();
 247         mov().u32().sp().fieldReg(Field.TID_X).csp().percent().regName(Field.TID_X.toString()).ptxNl();
 248         mad().lo().s32().sp().fieldReg(Field.KC_X, value).csp().fieldReg(Field.CTAID_X)
 249                 .csp().fieldReg(Field.NTID_X).csp().fieldReg(Field.TID_X).ptxNl();
 250         st().global().u32().sp().address(fieldToRegMap.get(Field.KC_ADDR).name()).csp().fieldReg(Field.KC_X);
 251     }
 252 
    /**
     * Emits a global u64 store for a field store op, using the op's result register and
     * the register of its first operand.
     * <p>
     * NOTE(review): marked by the author as still needing a fix; the operand layout emitted
     * here should be confirmed before relying on it.
     *
     * @param op the field store op to translate
     */
    public void fieldStore(JavaOp.FieldAccessOp.FieldStoreOp op) {
        // TODO: fix
        st().global().u64().sp().resultReg(op, PTXRegister.Type.U64).csp().reg(op.operands().getFirst());
    }
 257     // this might be duplication of CodeBuilder symbol....
 258 @Override public
 259 PTXHATKernelBuilder symbol(Op op) {
 260         return switch (op) {
 261             case JavaOp.ModOp _ -> remKw();
 262             case JavaOp.MulOp _ -> mulKeword();
 263             case JavaOp.DivOp _ -> divKeyword();
 264             case JavaOp.AddOp _ -> addKeyword();
 265             case JavaOp.SubOp _ -> subKeyword();
 266             case JavaOp.LtOp _ -> ltKeyword();
 267             case JavaOp.GtOp _ -> gtKeyword();
 268             case JavaOp.LeOp _ -> le();
 269             case JavaOp.GeOp _ -> ge();
 270             case JavaOp.NeqOp _ -> neKeyword();
 271             case JavaOp.EqOp _ -> eqKeyword();
 272             case JavaOp.OrOp _ -> or();
 273             case JavaOp.AndOp _ -> and();
 274             case JavaOp.XorOp _ -> xor();
 275             case JavaOp.LshlOp _ -> shl();
 276             case JavaOp.AshrOp _, JavaOp.LshrOp _ -> shr();
 277             default -> throw new IllegalStateException("Unexpected value");
 278         };
 279     }
 280 
 281     public void binaryOperation(JavaOp.BinaryOp op) {
 282         symbol(op);
 283         if (getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
 284                 && (op instanceof JavaOp.DivOp || op instanceof JavaOp.MulOp)) {
 285             rn();
 286         } else if (!getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING)
 287                 && op instanceof JavaOp.MulOp) {
 288             lo();
 289         }
 290         resultType(op.resultType(), true).sp();
 291         resultReg(op, getResultType(op.resultType()));
 292         csp();
 293         reg(op.operands().getFirst());
 294         csp();
 295         reg(op.operands().get(1));
 296     }
 297 
 298     public void compareOperation(JavaOp.CompareOp op) {
 299         setp().dot();
 300         symbol(op).resultType(op.operands().getFirst().type(), true).sp();
 301         resultReg(op, PTXRegister.Type.PREDICATE);
 302         csp();
 303         reg(op.operands().getFirst());
 304         csp();
 305         reg(op.operands().get(1));
 306     }
 307 
 308     public void conv(JavaOp.ConvOp op) {
 309         if (op.resultType().equals(JavaType.LONG)) {
 310             if (isIndex(op)) {
 311                 mulKeword().wide().s32().sp().resultReg(op, PTXRegister.Type.U64).csp()
 312                         .reg(op.operands().getFirst()).csp().intVal(4);
 313             } else {
 314                 cvt().u64().dot().regType(op.operands().getFirst()).sp()
 315                         .resultReg(op, PTXRegister.Type.U64).csp().reg(op.operands().getFirst()).ptxNl();
 316             }
 317         } else if (op.resultType().equals(JavaType.FLOAT)) {
 318             cvt().rn().f32().dot().regType(op.operands().getFirst()).sp()
 319                     .resultReg(op, PTXRegister.Type.F32).csp().reg(op.operands().getFirst());
 320         } else if (op.resultType().equals(JavaType.DOUBLE)) {
 321             cvt();
 322             if (op.operands().getFirst().type().equals(JavaType.INT)) {
 323                 rn();
 324             }
 325             f64().dot().regType(op.operands().getFirst()).sp()
 326                     .resultReg(op, PTXRegister.Type.F64).csp().reg(op.operands().getFirst());
 327         } else if (op.resultType().equals(JavaType.INT)) {
 328             cvt();
 329             if (op.operands().getFirst().type().equals(JavaType.DOUBLE) || op.operands().getFirst().type().equals(JavaType.FLOAT)) {
 330                 rzi();
 331             } else {
 332                 rn();
 333             }
 334             s32().dot().regType(op.operands().getFirst()).sp()
 335                     .resultReg(op, PTXRegister.Type.S32).csp().reg(op.operands().getFirst());
 336         } else {
 337             cvt().rn().s32().dot().regType(op.operands().getFirst()).sp()
 338                     .resultReg(op, PTXRegister.Type.S32).csp().reg(op.operands().getFirst());
 339         }
 340     }
 341 
 342 
 343 
 344 
 345 
 346     public static class PTXRegister {
 347         private String name;
 348         private final Type type;
 349 
 350         public enum Type {
 351             S8 (8, BasicType.SIGNED, "s8", "%s"),
 352             S16 (16, BasicType.SIGNED, "s16", "%s"),
 353             S32 (32, BasicType.SIGNED, "s32", "%s"),
 354             S64 (64, BasicType.SIGNED, "s64", "%sd"),
 355             U8 (8, BasicType.UNSIGNED, "u8", "%r"),
 356             U16 (16, BasicType.UNSIGNED, "u16", "%r"),
 357             U32 (32, BasicType.UNSIGNED, "u32", "%r"),
 358             U64 (64, BasicType.UNSIGNED, "u64", "%rd"),
 359             F16 (16, BasicType.FLOATING, "f16", "%f"),
 360             F16X2 (16, BasicType.FLOATING, "f16", "%f"),
 361             F32 (32, BasicType.FLOATING, "f32", "%f"),
 362             F64 (64, BasicType.FLOATING, "f64", "%fd"),
 363             B8 (8, BasicType.BIT, "b8", "%b"),
 364             B16 (16, BasicType.BIT, "b16", "%b"),
 365             B32 (32, BasicType.BIT, "b32", "%b"),
 366             B64 (64, BasicType.BIT, "b64", "%bd"),
 367             B128 (128, BasicType.BIT, "b128", "%b"),
 368             PREDICATE (1, BasicType.PREDICATE, "pred", "%p");
 369 
 370             public enum BasicType {
 371                 SIGNED,
 372                 UNSIGNED,
 373                 FLOATING,
 374                 BIT,
 375                 PREDICATE
 376             }
 377 
 378             private final int size;
 379             private final BasicType basicType;
 380             private final String name;
 381             private final String regPrefix;
 382 
 383             Type(int size, BasicType type, String name, String regPrefix) {
 384                 this.size = size;
 385                 this.basicType = type;
 386                 this.name = name;
 387                 this.regPrefix = regPrefix;
 388             }
 389 
 390             public int getSize() {
 391                 return this.size;
 392             }
 393 
 394             public BasicType getBasicType() {
 395                 return this.basicType;
 396             }
 397 
 398             public String getName() {
 399                 return this.name;
 400             }
 401 
 402             public String getRegPrefix() {
 403                 return this.regPrefix;
 404             }
 405         }
 406 
 407         public PTXRegister(int num, Type type) {
 408             this.type = type;
 409             this.name = type.regPrefix + num;
 410         }
 411 
 412         public String name() {
 413             return this.name;
 414         }
 415 
 416         public void name(String name) {
 417             this.name = name;
 418         }
 419 
 420         public Type type() {
 421             return this.type;
 422         }
 423     }
 424 
 425 
 426     private boolean isIndex(JavaOp.ConvOp op) {
 427         for (Op.Result r : op.result().uses()) {
 428             if (r.op() instanceof PTXPtrOp) return true;
 429         }
 430         return false;
 431     }
 432 
 433     public void constant(CoreOp.ConstantOp op) {
 434         mov().resultType(op.resultType(), false).sp().resultReg(op, getResultType(op.resultType())).csp();
 435         if (op.resultType().toString().equals("float")) {
 436             if (op.value().toString().equals("0.0")) {
 437                 floatVal("00000000");
 438             } else {
 439                 floatVal(Integer.toHexString(Float.floatToIntBits(Float.parseFloat(op.value().toString()))).toUpperCase());
 440             }
 441         } else {
 442             constant(op.value().toString());
 443         }
 444     }
 445 
    /**
     * Translates a yield op by emitting {@code exit()}; the op itself is not inspected.
     *
     * @param op the yield op
     */
    public void javaYield(CoreOp.YieldOp op) {
        exit();
    }
 449 
 450     // S32Array and S32Array2D functions can be deleted after schema is done
 451     public void methodCall(Invoke invoke) {
 452        // Invoke invoke = Invoke.invokeOpHelper(MethodHandles.lookup(),invokeOp);
 453         switch (invoke.op().invokeReference().toString()) {
 454             // S32Array functions
 455             case "hat.buffer.S32Array::array(long)int" -> {
 456                 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
 457                 addKeyword().s64().sp().regName(temp).csp().reg(invoke.op().operands().getFirst()).csp().reg(invoke.op().operands().get(1)).ptxNl();
 458                 ld().global().u32().sp().resultReg(invoke.op(), PTXRegister.Type.U32).csp().address(temp.name(), 4);
 459             }
 460             case "hat.buffer.S32Array::array(long, int)void" -> {
 461                 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
 462                 addKeyword().s64().sp().regName(temp).csp().reg(invoke.op().operands().getFirst()).csp().reg(invoke.op().operands().get(1)).ptxNl();
 463                 st().global().u32().sp().address(temp.name(), 4).csp().reg(invoke.op().operands().get(2));
 464             }
 465             case "hat.buffer.S32Array::length()int" -> {
 466                 ld().global().u32().sp().resultReg(invoke.op(), PTXRegister.Type.U32).csp().address(getReg(invoke.op().operands().getFirst()).name());
 467             }
 468             // S32Array2D functions
 469             case "hat.buffer.S32Array2D::array(long, int)void" -> {
 470                 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType());
 471                 addKeyword().s64().sp().regName(temp).csp().reg(invoke.op().operands().getFirst()).csp().reg(invoke.op().operands().get(1)).ptxNl();
 472                 st().global().u32().sp().address(temp.name(), 8).csp().reg(invoke.op().operands().get(2));
 473             }
 474             case "hat.buffer.S32Array2D::width()int" -> {
 475                 ld().global().u32().sp().resultReg(invoke.op(), PTXRegister.Type.U32).csp().address(getReg(invoke.op().operands().getFirst()).name());
 476             }
 477             case "hat.buffer.S32Array2D::height()int" -> {
 478                 ld().global().u32().sp().resultReg(invoke.op(), PTXRegister.Type.U32).csp().address(getReg(invoke.op().operands().getFirst()).name(), 4);
 479             }
 480             // Java Math function
 481             case "java.lang.Math::sqrt(double)double" -> {
 482                 sqrt().rn().f64().sp().resultReg(invoke.op(), PTXRegister.Type.F64).csp().reg(invoke.op().operands().getFirst()).semicolon();
 483             }
 484             default -> {
 485                 obrace().nl().ptxIndent();
 486                 for (int i = 0; i < invoke.op().operands().size(); i++) {
 487                     dot().param().sp().paramType(invoke.op().operands().get(i).type()).sp().param().intVal(i).ptxNl();
 488                     st().dot().param().paramType(invoke.op().operands().get(i).type()).sp().osbrace().param().intVal(i).csbrace().csp().reg(invoke.op().operands().get(i)).ptxNl();
 489                 }
 490                 dot().param().sp().paramType(invoke.op().resultType()).sp().retVal().ptxNl();
 491                 call().uni().sp().oparen().retVal().cparen().csp().id(invoke.name()).csp();
 492                 final int[] counter = {0};
 493                 paren(_ ->
 494                         commaSpaceSeparated(
 495                                 invoke.op().operands(),
 496                                 _ -> param().intVal(counter[0]++)
 497                         )
 498                 ).ptxNl();
 499                 ld().dot().param().paramType(invoke.op().resultType()).sp().resultReg(invoke.op(), getResultType(invoke.op().resultType())).csp().osbrace().retVal().csbrace();
 500                 ptxNl().cbrace();
 501             }
 502         }
 503     }
 504 
    /**
     * Translates a variable declaration: emits an {@code ld} with the {@code param}
     * qualifier and the var's result type, from the register of the initialising operand
     * into a newly bound address-typed result register.
     *
     * @param op the var op to translate
     */
    public void varDeclaration(CoreOp.VarOp op) {
        ld().dot().param().resultType(op.resultType(), false).sp().resultReg(op, addressType()).csp().reg(op.operands().getFirst());
    }
 508 
    /**
     * Translates a variable declaration for which {@code ParamVar.of} reported a match
     * (see {@code convert}). The emitted instruction is currently the same as the one
     * produced by {@code varDeclaration}.
     *
     * @param op the var op to translate
     */
    public void varFuncDeclaration(CoreOp.VarOp op) {
        ld().dot().param().resultType(op.resultType(), false).sp().resultReg(op, addressType()).csp().reg(op.operands().getFirst());
    }
 512 
 513     public void ret(CoreOp.ReturnOp op) {
 514         if (!op.operands().isEmpty()) {
 515             st().dot().param();
 516             if (returnReg.type().equals(PTXRegister.Type.U32)) {
 517                 b32();
 518             } else if (returnReg.type().equals(PTXRegister.Type.U64)) {
 519                 b64();
 520             } else {
 521                 dot().regType(returnReg.type());
 522             }
 523             sp().osbrace().regName(returnReg).csbrace().csp().reg(op.operands().getFirst()).ptxNl();
 524         }
 525         ret();
 526     }
 527 
    /**
     * Translates a break op by emitting {@code brkpt()}; the op itself is not inspected.
     *
     * @param op the break op
     */
    public void javaBreak(JavaOp.BreakOp op) {
        brkpt();
    }
 531 
 532     public void branch(CoreOp.BranchOp op) {
 533         loadBlockParams(op.successors().getFirst());
 534         bra().sp().block(op.successors().getFirst().targetBlock());
 535     }
 536 
 537     public void condBranch(CoreOp.ConditionalBranchOp op) {
 538         loadBlockParams(op.successors().getFirst());
 539         loadBlockParams(op.successors().getLast());
 540         at().reg(op.operands().getFirst()).sp()
 541                 .bra().sp().block(op.successors().getFirst().targetBlock()).ptxNl();
 542         bra().sp().block(op.successors().getLast().targetBlock());
 543     }
 544 
    /**
     * Translates a negation: emits the {@code neg()} keyword with the signed form of the
     * op's result type, the result register and the register of the single operand.
     *
     * @param op the negation op to translate
     */
    public void neg(JavaOp.NegOp op) {
        neg().resultType(op.resultType(), true).sp().reg(op.result(), getResultType(op.resultType())).csp().reg(op.operands().getFirst());
    }
 548 
 549     /*
 550      * Helper functions for printing blocks and variables
 551      */
 552 
 553     public void loadBlockParams(Block.Reference block) {
 554         for (int i = 0; i < block.arguments().size(); i++) {
 555             Block.Parameter p = block.targetBlock().parameters().get(i);
 556             mov().resultType(p.type(), false).sp().reg(p, getResultType(p.type()))
 557                     .csp().reg(block.arguments().get(i)).ptxNl();
 558         }
 559     }
 560 
    /**
     * Emits the label of a block: {@code block_} followed by the block's index.
     *
     * @param block the block to name
     * @return this builder
     */
    public PTXHATKernelBuilder block(Block block) {
        return type("block_").intVal(block.index());
    }
 564 
 565     public PTXHATKernelBuilder fieldReg(Field ref) {
 566         if (fieldToRegMap.containsKey(ref)) {
 567             return regName(fieldToRegMap.get(ref));
 568         }
 569         if (ref.isDestination()) {
 570             fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(addressType()), addressType()));
 571         } else {
 572             fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
 573         }
 574         return regName(fieldToRegMap.get(ref));
 575     }
 576 
 577     public PTXHATKernelBuilder fieldReg(Field ref, Value value) {
 578         if (fieldToRegMap.containsKey(ref)) {
 579             return regName(fieldToRegMap.get(ref));
 580         }
 581         if (ref.isDestination()) {
 582             fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(addressType()), addressType()));
 583             return reg(value, addressType());
 584         } else {
 585             fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32));
 586             return reg(value, PTXRegister.Type.U32);
 587         }
 588     }
 589 
 590     public Field getFieldObj(String fieldName) {
 591         for (Field f : fieldToRegMap.keySet()) {
 592             if (f.toString().equals(fieldName)) return f;
 593         }
 594         throw new IllegalStateException("no existing field");
 595     }
 596 
    /**
     * Emits the register bound to the result of {@code op}, allocating one of the given
     * type through {@code addReg} if the result has no register yet.
     *
     * @param op   the op whose result register is emitted
     * @param type register type used when a new register has to be allocated
     * @return this builder
     */
    public PTXHATKernelBuilder resultReg(Op op, PTXRegister.Type type) {
        return id(addReg(op.result(), type));
    }
 600 
 601     public PTXHATKernelBuilder reg(Value val, PTXRegister.Type type) {
 602         if (varToRegMap.containsKey(val)) {
 603             return regName(getReg(val));
 604         } else {
 605             return id(addReg(val, type));
 606         }
 607     }
 608 
    /**
     * Emits the name of the register already assigned to a value.
     *
     * @param val the value
     * @return this builder
     * @throws IllegalStateException propagated from {@code getReg} when no register exists
     */
    public PTXHATKernelBuilder reg(Value val) {
        return regName(getReg(val));
    }
 612 
 613     public PTXRegister getReg(Value val) {
 614         if (varToRegMap.get(val) == null && val instanceof Op.Result result && result.op() instanceof JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
 615             return fieldToRegMap.get(getFieldObj(fieldLoadOp.fieldReference().name()));
 616         }
 617         if (varToRegMap.containsKey(val)) {
 618             return varToRegMap.get(val);
 619         } else {
 620             throw new IllegalStateException("var to reg mapping doesn't exist");
 621         }
 622     }
 623 
 624     public String addReg(Value val, PTXRegister.Type type) {
 625         if (varToRegMap.containsKey(val)) {
 626             return varToRegMap.get(val).name();
 627         }
 628         varToRegMap.put(val, new PTXRegister(incrOrdinal(type), type));
 629         return varToRegMap.get(val).name();
 630     }
 631 
 632     public Integer getOrdinal(PTXRegister.Type type) {
 633         ordinalMap.putIfAbsent(type, 1);
 634         return ordinalMap.get(type);
 635     }
 636 
 637     public Integer incrOrdinal(PTXRegister.Type type) {
 638         ordinalMap.putIfAbsent(type, 1);
 639         int out = ordinalMap.get(type);
 640         ordinalMap.put(type, out + 1);
 641         return out;
 642     }
 643 
 644     public PTXHATKernelBuilder size() {
 645         return (addressSize == 32) ? u32() : u64();
 646     }
 647 
 648     public PTXRegister.Type addressType() {
 649         return (addressSize == 32) ? PTXRegister.Type.U32 : PTXRegister.Type.U64;
 650     }
 651 
 652     public PTXHATKernelBuilder resultType(CodeType type, boolean signedResult) {
 653         PTXRegister.Type res = getResultType(type);
 654         if (signedResult && (res == PTXRegister.Type.U32)) return s32();
 655         return dot().type(getResultType(type).getName());
 656     }
 657 
 658     public PTXHATKernelBuilder paramType(CodeType type) {
 659         PTXRegister.Type res = getResultType(type);
 660         if (res == PTXRegister.Type.U32) return b32();
 661         if (res == PTXRegister.Type.U64) return b64();
 662         return dot().type(getResultType(type).getName());
 663     }
 664 
 665     public PTXRegister.Type getResultType(CodeType type) {
 666         switch (type.toString()) {
 667             case "float" -> {
 668                 return PTXRegister.Type.F32;
 669             }
 670             case "double" -> {
 671                 return PTXRegister.Type.F64;
 672             }
 673             case "int" -> {
 674                 return PTXRegister.Type.U32;
 675             }
 676             case "boolean" -> {
 677                 return PTXRegister.Type.PREDICATE;
 678             }
 679             default -> {
 680                 return PTXRegister.Type.U64;
 681             }
 682         }
 683     }
 684 
 685     /*
 686      * Basic CodeBuilder functions
 687      */
 688 
 689     // used for parameter list
 690     // prints out items separated by a comma then new line
 691     // Don't know why this was overriding with the same code grf.
 692    /* @Override
 693     public <I> PTXHATKernelBuilder commaNlSeparated(Iterable<I> iterable, Consumer<I> c) {
 694         StreamCounter.of(iterable, (counter, t) -> {
 695             if (counter.isNotFirst()) {
 696                 comma().nl();
 697             }
 698             c.accept(t);
 699         });
 700         return self();
 701     }
 702 */
    /**
     * Emits an address operand: the given text between {@code osbrace()} and
     * {@code csbrace()}.
     *
     * @param address the address text (a register name at the call sites in this class)
     * @return this builder
     */
    public PTXHATKernelBuilder address(String address) {
        return osbrace().constant(address).csbrace();
    }
 706 
 707     public PTXHATKernelBuilder address(String address, int offset) {
 708         osbrace().constant(address);
 709         if (offset == 0) {
 710             return csbrace();
 711         } else if (offset > 0) {
 712             plus();
 713         }
 714         return intVal(offset).csbrace();
 715     }
 716 
    /**
     * Ends the current instruction and prepares the next line: emits {@code semicolon()},
     * {@code nl()} and {@code ptxIndent()} in that order.
     *
     * @return this builder
     */
    public PTXHATKernelBuilder ptxNl() {
        return semicolon().nl().ptxIndent();
    }
 720 
 721 
    /** Emits the {@code param} keyword (without a leading dot). */
    public PTXHATKernelBuilder param() {
        return keyword("param");
    }
 725 
    /** Emits {@code dot()} followed by the {@code global} keyword. */
    public PTXHATKernelBuilder global() {
        return dot().keyword("global");
    }
 729 
 730     public PTXHATKernelBuilder rn() {
 731         return dot().keyword("rn");
 732     }
 733 
 734     public PTXHATKernelBuilder rm() {
 735         return dot().keyword("rm");
 736     }
 737 
 738     public PTXHATKernelBuilder rzi() {
 739         return dot().keyword("rzi");
 740     }
 741 
 742     public PTXHATKernelBuilder to() {
 743         return dot().keyword("to");
 744     }
 745 
 746     public PTXHATKernelBuilder lo() {
 747         return dot().keyword("lo");
 748     }
 749 
 750     public PTXHATKernelBuilder wide() {
 751         return dot().keyword("wide");
 752     }
 753 
 754     public PTXHATKernelBuilder uni() {
 755         return dot().keyword("uni");
 756     }
 757 
 758     public PTXHATKernelBuilder sat() {
 759         return dot().keyword("sat");
 760     }
 761 
 762     public PTXHATKernelBuilder ftz() {
 763         return dot().keyword("ftz");
 764     }
 765 
 766     public PTXHATKernelBuilder approx() {
 767         return dot().keyword("approx");
 768     }
 769 
 770     public PTXHATKernelBuilder mov() {
 771         return keyword("mov");
 772     }
 773 
 774     public PTXHATKernelBuilder setp() {
 775         return keyword("setp");
 776     }
 777 
 778     public PTXHATKernelBuilder selp() {
 779         return keyword("selp");
 780     }
 781 
 782     public PTXHATKernelBuilder ld() {
 783         return keyword("ld");
 784     }
 785 
 786     public PTXHATKernelBuilder st() {
 787         return keyword("st");
 788     }
 789 
 790     public PTXHATKernelBuilder cvt() {
 791         return keyword("cvt");
 792     }
 793 
 794     public PTXHATKernelBuilder bra() {
 795         return keyword("bra");
 796     }
 797 
 798     public PTXHATKernelBuilder ret() {
 799         return keyword("ret");
 800     }
 801 
 802     public PTXHATKernelBuilder remKw() {
 803         return keyword("rem");
 804     }
 805 
 806     public PTXHATKernelBuilder mulKeword() {
 807         return keyword("mul");
 808     }
 809 
 810     public PTXHATKernelBuilder divKeyword() {
 811         return keyword("div");
 812     }
 813 
 814     public PTXHATKernelBuilder rcp() {
 815         return keyword("rcp");
 816     }
 817 
 818     public PTXHATKernelBuilder addKeyword() {
 819         return keyword("add");
 820     }
 821 
 822     public PTXHATKernelBuilder subKeyword() {
 823         return keyword("sub");
 824     }
 825 
 826     public PTXHATKernelBuilder ltKeyword() {
 827         return keyword("lt");
 828     }
 829 
 830     public PTXHATKernelBuilder gtKeyword() {
 831         return keyword("gt");
 832     }
 833 
 834     public PTXHATKernelBuilder le() {
 835         return keyword("le");
 836     }
 837 
 838     public PTXHATKernelBuilder ge() {
 839         return keyword("ge");
 840     }
 841 
 842     public PTXHATKernelBuilder geu() {
 843         return keyword("geu");
 844     }
 845 
 846     public PTXHATKernelBuilder neKeyword() {
 847         return keyword("ne");
 848     }
 849 
 850     public PTXHATKernelBuilder eqKeyword() {
 851         return keyword("eq");
 852     }
 853 
 854     public PTXHATKernelBuilder xor() {
 855         return keyword("xor");
 856     }
 857 
 858     public PTXHATKernelBuilder or() {
 859         return keyword("or");
 860     }
 861 
 862     public PTXHATKernelBuilder and() {
 863         return keyword("and");
 864     }
 865 
 866     public PTXHATKernelBuilder cvta() {
 867         return keyword("cvta");
 868     }
 869 
 870     public PTXHATKernelBuilder mad() {
 871         return keyword("mad");
 872     }
 873 
 874     public PTXHATKernelBuilder fma() {
 875         return keyword("fma");
 876     }
 877 
 878     public PTXHATKernelBuilder sqrt() {
 879         return keyword("sqrt");
 880     }
 881 
 882     public PTXHATKernelBuilder abs() {
 883         return keyword("abs");
 884     }
 885 
 886     public PTXHATKernelBuilder ex2() {
 887         return keyword("ex2");
 888     }
 889 
 890     public PTXHATKernelBuilder shl() {
 891         return keyword("shl");
 892     }
 893 
 894     public PTXHATKernelBuilder shr() {
 895         return keyword("shr");
 896     }
 897 
 898     public PTXHATKernelBuilder neg() {
 899         return keyword("neg");
 900     }
 901 
 902     public PTXHATKernelBuilder call() {
 903         return keyword("call");
 904     }
 905 
 906     public PTXHATKernelBuilder exit() {
 907         return keyword("exit");
 908     }
 909 
 910     public PTXHATKernelBuilder brkpt() {
 911         return keyword("brkpt");
 912     }
 913 
 914     public PTXHATKernelBuilder ptxIndent() {
 915         return sp().sp().sp().sp();
 916     }
 917 
 918     public PTXHATKernelBuilder u32() {
 919         return dot().type(PTXRegister.Type.U32.getName());
 920     }
 921 
 922     public PTXHATKernelBuilder s32() {
 923         return dot().type(PTXRegister.Type.S32.getName());
 924     }
 925 
 926     public PTXHATKernelBuilder f32() {
 927         return dot().type(PTXRegister.Type.F32.getName());
 928     }
 929 
 930     public PTXHATKernelBuilder b32() {
 931         return dot().type(PTXRegister.Type.B32.getName());
 932     }
 933 
 934     public PTXHATKernelBuilder u64() {
 935         return dot().type(PTXRegister.Type.U64.getName());
 936     }
 937 
 938     public PTXHATKernelBuilder s64() {
 939         return dot().type(PTXRegister.Type.S64.getName());
 940     }
 941 
 942     public PTXHATKernelBuilder f64() {
 943         return dot().type(PTXRegister.Type.F64.getName());
 944     }
 945 
 946     public PTXHATKernelBuilder b64() {
 947         return dot().type(PTXRegister.Type.B64.getName());
 948     }
 949 
 950     public PTXHATKernelBuilder version() {
 951         return dot().keyword("version");
 952     }
 953 
 954     public PTXHATKernelBuilder target() {
 955         return dot().keyword("target");
 956     }
 957 
 958     public PTXHATKernelBuilder addressSize() {
 959         return dot().keyword("address_size");
 960     }
 961 
 962     public PTXHATKernelBuilder major(int major) {
 963         return intVal(major);
 964     }
 965 
 966     public PTXHATKernelBuilder minor(int minor) {
 967         return intVal(minor);
 968     }
 969 
 970     public PTXHATKernelBuilder target(String target) {
 971         return keyword(target);
 972     }
 973 
 974     public PTXHATKernelBuilder size(int addressSize) {
 975         return intVal(addressSize);
 976     }
 977 
 978 
 979 
 980     public PTXHATKernelBuilder visible() {
 981         return dot().keyword("visible");
 982     }
 983 
 984     public PTXHATKernelBuilder entry() {
 985         return dot().keyword("entry");
 986     }
 987 
 988     public PTXHATKernelBuilder func() {
 989         return dot().keyword("func");
 990     }
 991 
 992     public PTXHATKernelBuilder oabrace() {
 993         return symbol("<");
 994     }
 995 
 996     public PTXHATKernelBuilder cabrace() {
 997         return symbol(">");
 998     }
 999 
1000     public PTXHATKernelBuilder regName(PTXRegister reg) {
1001         return id(reg.name());
1002     }
1003 
1004     public PTXHATKernelBuilder regName(String regName) {
1005         return id(regName);
1006     }
1007 
1008     public PTXHATKernelBuilder regType(Value val) {
1009         return keyword(getReg(val).type().getName());
1010     }
1011 
1012     public PTXHATKernelBuilder regType(PTXRegister.Type t) {
1013         return keyword(t.getName());
1014     }
1015 
1016     public PTXHATKernelBuilder regTypePrefix(PTXRegister.Type t) {
1017         return keyword(t.getRegPrefix());
1018     }
1019 
1020     public PTXHATKernelBuilder reg() {
1021         return dot().keyword("reg");
1022     }
1023 
1024     public PTXHATKernelBuilder retVal() {
1025         return keyword("retval");
1026     }
1027 
1028     public PTXHATKernelBuilder intVal(int i) {
1029         return constant(String.valueOf(i));
1030     }
1031 
1032     public PTXHATKernelBuilder floatVal(String s) {
1033         return constant("0f").constant(s);
1034     }
1035 
1036     public PTXHATKernelBuilder doubleVal(String s) {
1037         return constant("0d").constant(s);
1038     }
1039 }