1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 package hat.backend.ffi; 26 27 import hat.optools.*; 28 import hat.text.CodeBuilder; 29 import hat.util.StreamCounter; 30 import jdk.incubator.code.Block; 31 import jdk.incubator.code.Op; 32 import jdk.incubator.code.TypeElement; 33 import jdk.incubator.code.Value; 34 import jdk.incubator.code.op.CoreOp; 35 import jdk.incubator.code.type.JavaType; 36 37 import java.lang.foreign.MemoryLayout; 38 import java.util.ArrayList; 39 import java.util.HashMap; 40 import java.util.List; 41 import java.util.Map; 42 import java.util.function.Consumer; 43 import java.util.stream.Stream; 44 45 public class PTXHATKernelBuilder extends CodeBuilder<PTXHATKernelBuilder> { 46 47 Map<Value, PTXRegister> varToRegMap; 48 List<String> paramNames; 49 List<Block.Parameter> paramObjects; 50 Map<Field, PTXRegister> fieldToRegMap; 51 52 HashMap<PTXRegister.Type, Integer> ordinalMap; 53 54 PTXRegister returnReg; 55 private int addressSize; 56 57 public enum Field { 58 NTID_X ("ntid.x", false), 59 CTAID_X ("ctaid.x", false), 60 TID_X ("tid.x", false), 61 KC_X ("x", false), 62 KC_ADDR("kc", true), 63 KC_MAXX ("maxX", false); 64 65 private final String name; 66 private final boolean destination; 67 68 Field(String name, boolean destination) { 69 this.name = name; 70 this.destination = destination; 71 } 72 public String toString() { 73 return this.name; 74 } 75 public boolean isDestination() {return this.destination;} 76 } 77 78 public PTXHATKernelBuilder(int addressSize) { 79 varToRegMap = new HashMap<>(); 80 paramNames = new ArrayList<>(); 81 fieldToRegMap = new HashMap<>(); 82 paramObjects = new ArrayList<>(); 83 ordinalMap = new HashMap<>(); 84 this.addressSize = addressSize; 85 } 86 87 public PTXHATKernelBuilder() { 88 this(32); 89 } 90 91 public void ptxHeader(int major, int minor, String target, int addressSize) { 92 this.addressSize = addressSize; 93 version().space().major(major).dot().minor(minor).nl(); 94 target().space().target(target).nl(); 95 addressSize().space().size(addressSize); 96 } 97 98 public void functionHeader(String funcName, boolean entry, TypeElement yieldType) { 99 if (entry) { 100 visible().space().entry().space(); 101 } else { 102 func().space(); 103 } 104 if (!yieldType.toString().equals("void")) { 105 returnReg = new PTXRegister(getOrdinal(getResultType(yieldType)), getResultType(yieldType)); 106 returnReg.name("%retReg"); 107 oparen().dot().param().space().paramType(yieldType); 108 space().regName(returnReg).cparen().space(); 109 } 110 funcName(funcName); 111 } 112 113 public PTXHATKernelBuilder parameters(List<FuncOpWrapper.ParamTable.Info> infoList) { 114 paren(_ -> nl().commaNlSeparated(infoList, (info) -> { 115 ptxIndent().dot().param().space().paramType(info.javaType); 116 space().regName(info.varOp.varName()); 117 paramNames.add(info.varOp.varName()); 118 }).nl()).nl(); 119 return this; 120 } 121 122 public void blockBody(Block block, Stream<OpWrapper<?>> ops) { 123 if (block.index() == 0) { 124 for (Block.Parameter p : block.parameters()) { 125 ptxIndent().ld().dot().param(); 126 resultType(p.type(), false).ptxIndent().space(); 127 reg(p, getResultType(p.type())).commaSpace().osbrace().regName(paramNames.get(p.index())).csbrace().semicolon().nl(); 128 paramObjects.add(p); 129 } 130 } 131 nl(); 132 block(block); 133 colon().nl(); 134 ops.forEach(op -> { 135 if (op instanceof InvokeOpWrapper invoke && !invoke.isIfaceBufferMethod()) { 136 ptxIndent().convert(op).nl(); 137 } else { 138 ptxIndent().convert(op).semicolon().nl(); 139 } 140 }); 141 } 142 143 public void ptxRegisterDecl() { 144 for (PTXRegister.Type t : ordinalMap.keySet()) { 145 ptxIndent().reg().space(); 146 if (t.equals(PTXRegister.Type.U32)) { 147 b32(); 148 } else if (t.equals(PTXRegister.Type.U64)) { 149 b64(); 150 } else { 151 dot().regType(t); 152 } 153 ptxIndent().regTypePrefix(t).oabrace().intVal(ordinalMap.get(t)).cabrace().semicolon().nl(); 154 } 155 nl(); 156 } 157 158 public void functionPrologue() { 159 obrace().nl(); 160 } 161 162 public void functionEpilogue() { 163 cbrace(); 164 } 165 166 public PTXHATKernelBuilder convert(OpWrapper<?> wrappedOp) { 167 switch (wrappedOp) { 168 case FieldLoadOpWrapper op -> fieldLoad(op); 169 case FieldStoreOpWrapper op -> fieldStore(op); 170 case BinaryArithmeticOrLogicOperation op -> binaryOperation(op); 171 case BinaryTestOpWrapper op -> binaryTest(op); 172 case ConvOpWrapper op -> conv(op); 173 case ConstantOpWrapper op -> constant(op); 174 case YieldOpWrapper op -> javaYield(op); 175 case InvokeOpWrapper op -> methodCall(op); 176 case VarDeclarationOpWrapper op -> varDeclaration(op); 177 case VarFuncDeclarationOpWrapper op -> varFuncDeclaration(op); 178 case ReturnOpWrapper op -> ret(op); 179 case JavaBreakOpWrapper op -> javaBreak(op); 180 default -> { 181 switch (wrappedOp.op()){ 182 case CoreOp.BranchOp op -> branch(op); 183 case CoreOp.ConditionalBranchOp op -> condBranch(op); 184 case CoreOp.NegOp op -> neg(op); 185 case PTXPtrOp op -> ptxPtr(op); 186 default -> throw new IllegalStateException("op translation doesn't exist"); 187 } 188 } 189 } 190 return this; 191 } 192 193 public void ptxPtr(PTXPtrOp op) { 194 PTXRegister source; 195 int offset = (int) op.boundSchema.groupLayout().byteOffset(MemoryLayout.PathElement.groupElement(op.fieldName)); 196 197 if (op.fieldName.equals("array")) { 198 source = new PTXRegister(incrOrdinal(addressType()), addressType()); 199 add().s64().space().regName(source).commaSpace().reg(op.operands().get(0)).commaSpace().reg(op.operands().get(1)).ptxNl(); 200 } else { 201 source = getReg(op.operands().getFirst()); 202 } 203 204 if (op.resultType.toString().equals("void")) { 205 st().global().dot().regType(op.operands().getLast()).space().address(source.name(), offset).commaSpace().reg(op.operands().getLast()); 206 } else { 207 ld().global().resultType(op.resultType(), true).space().reg(op.result(), getResultType(op.resultType())).commaSpace().address(source.name(), offset); 208 } 209 } 210 211 public void fieldLoad(FieldLoadOpWrapper op) { 212 if (op.fieldName().equals(Field.KC_X.toString())) { 213 if (!fieldToRegMap.containsKey(Field.KC_X)) { 214 loadKcX(op.result()); 215 } else { 216 mov().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().fieldReg(Field.KC_X); 217 } 218 } else if (op.fieldName().equals(Field.KC_MAXX.toString())) { 219 if (!fieldToRegMap.containsKey(Field.KC_X)) { 220 loadKcX(op.operandNAsValue(0)); 221 } 222 ld().global().u32().space().fieldReg(Field.KC_MAXX, op.result()).commaSpace() 223 .address(fieldToRegMap.get(Field.KC_ADDR).name(), 4); 224 } else { 225 ld().global().u32().space().resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operandNAsValue(0)); 226 } 227 } 228 229 public void loadKcX(Value value) { 230 cvta().to().global().size().space().fieldReg(Field.KC_ADDR).commaSpace() 231 .reg(paramObjects.get(paramNames.indexOf(Field.KC_ADDR.toString())), addressType()).ptxNl(); 232 mov().u32().space().fieldReg(Field.NTID_X).commaSpace().percent().regName(Field.NTID_X.toString()).ptxNl(); 233 mov().u32().space().fieldReg(Field.CTAID_X).commaSpace().percent().regName(Field.CTAID_X.toString()).ptxNl(); 234 mov().u32().space().fieldReg(Field.TID_X).commaSpace().percent().regName(Field.TID_X.toString()).ptxNl(); 235 mad().lo().s32().space().fieldReg(Field.KC_X, value).commaSpace().fieldReg(Field.CTAID_X) 236 .commaSpace().fieldReg(Field.NTID_X).commaSpace().fieldReg(Field.TID_X).ptxNl(); 237 st().global().u32().space().address(fieldToRegMap.get(Field.KC_ADDR).name()).commaSpace().fieldReg(Field.KC_X); 238 } 239 240 public void fieldStore(FieldStoreOpWrapper op) { 241 // TODO: fix 242 st().global().u64().space().resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operandNAsValue(0)); 243 } 244 245 PTXHATKernelBuilder symbol(Op op) { 246 return switch (op) { 247 case CoreOp.ModOp _ -> rem(); 248 case CoreOp.MulOp _ -> mul(); 249 case CoreOp.DivOp _ -> div(); 250 case CoreOp.AddOp _ -> add(); 251 case CoreOp.SubOp _ -> sub(); 252 case CoreOp.LtOp _ -> lt(); 253 case CoreOp.GtOp _ -> gt(); 254 case CoreOp.LeOp _ -> le(); 255 case CoreOp.GeOp _ -> ge(); 256 case CoreOp.NeqOp _ -> ne(); 257 case CoreOp.EqOp _ -> eq(); 258 case CoreOp.OrOp _ -> or(); 259 case CoreOp.AndOp _ -> and(); 260 case CoreOp.XorOp _ -> xor(); 261 case CoreOp.LshlOp _ -> shl(); 262 case CoreOp.AshrOp _, CoreOp.LshrOp _ -> shr(); 263 default -> throw new IllegalStateException("Unexpected value"); 264 }; 265 } 266 267 public void binaryOperation(BinaryArithmeticOrLogicOperation op) { 268 symbol(op.op()); 269 if (getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING) 270 && (op.op() instanceof CoreOp.DivOp || op.op() instanceof CoreOp.MulOp)) { 271 rn(); 272 } else if (!getResultType(op.resultType()).getBasicType().equals(PTXRegister.Type.BasicType.FLOATING) 273 && op.op() instanceof CoreOp.MulOp) { 274 lo(); 275 } 276 resultType(op.resultType(), true).space(); 277 resultReg(op, getResultType(op.resultType())); 278 commaSpace(); 279 reg(op.operandNAsValue(0)); 280 commaSpace(); 281 reg(op.operandNAsValue(1)); 282 } 283 284 public void binaryTest(BinaryTestOpWrapper op) { 285 setp().dot(); 286 symbol(op.op()).resultType(op.operandNAsValue(0).type(), true).space(); 287 resultReg(op, PTXRegister.Type.PREDICATE); 288 commaSpace(); 289 reg(op.operandNAsValue(0)); 290 commaSpace(); 291 reg(op.operandNAsValue(1)); 292 } 293 294 public void conv(ConvOpWrapper op) { 295 if (op.resultJavaType().equals(JavaType.LONG)) { 296 if (isIndex(op)) { 297 mul().wide().s32().space().resultReg(op, PTXRegister.Type.U64).commaSpace() 298 .reg(op.operandNAsValue(0)).commaSpace().intVal(4); 299 } else { 300 cvt().u64().dot().regType(op.operandNAsValue(0)).space() 301 .resultReg(op, PTXRegister.Type.U64).commaSpace().reg(op.operandNAsValue(0)).ptxNl(); 302 } 303 } else if (op.resultJavaType().equals(JavaType.FLOAT)) { 304 cvt().rn().f32().dot().regType(op.operandNAsValue(0)).space() 305 .resultReg(op, PTXRegister.Type.F32).commaSpace().reg(op.operandNAsValue(0)); 306 } else if (op.resultJavaType().equals(JavaType.DOUBLE)) { 307 cvt(); 308 if (op.operandNAsValue(0).type().equals(JavaType.INT)) { 309 rn(); 310 } 311 f64().dot().regType(op.operandNAsValue(0)).space() 312 .resultReg(op, PTXRegister.Type.F64).commaSpace().reg(op.operandNAsValue(0)); 313 } else if (op.resultJavaType().equals(JavaType.INT)) { 314 cvt(); 315 if (op.operandNAsValue(0).type().equals(JavaType.DOUBLE) || op.operandNAsValue(0).type().equals(JavaType.FLOAT)) { 316 rzi(); 317 } else { 318 rn(); 319 } 320 s32().dot().regType(op.operandNAsValue(0)).space() 321 .resultReg(op, PTXRegister.Type.S32).commaSpace().reg(op.operandNAsValue(0)); 322 } else { 323 cvt().rn().s32().dot().regType(op.operandNAsValue(0)).space() 324 .resultReg(op, PTXRegister.Type.S32).commaSpace().reg(op.operandNAsValue(0)); 325 } 326 } 327 328 private boolean isIndex(ConvOpWrapper op) { 329 for (Op.Result r : op.result().uses()) { 330 if (r.op() instanceof PTXPtrOp) return true; 331 } 332 return false; 333 } 334 335 public void constant(ConstantOpWrapper op) { 336 mov().resultType(op.resultType(), false).space().resultReg(op, getResultType(op.resultType())).commaSpace(); 337 if (op.resultType().toString().equals("float")) { 338 if (op.op().value().toString().equals("0.0")) { 339 floatVal("00000000"); 340 } else { 341 floatVal(Integer.toHexString(Float.floatToIntBits(Float.parseFloat(op.op().value().toString()))).toUpperCase()); 342 } 343 } else { 344 append(op.op().value().toString()); 345 } 346 } 347 348 public void javaYield(YieldOpWrapper op) { 349 exit(); 350 } 351 352 // S32Array and S32Array2D functions can be deleted after schema is done 353 public void methodCall(InvokeOpWrapper op) { 354 switch (op.methodRef().toString()) { 355 // S32Array functions 356 case "hat.buffer.S32Array::array(long)int" -> { 357 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType()); 358 add().s64().space().regName(temp).commaSpace().reg(op.operandNAsValue(0)).commaSpace().reg(op.operandNAsValue(1)).ptxNl(); 359 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(temp.name(), 4); 360 } 361 case "hat.buffer.S32Array::array(long, int)void" -> { 362 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType()); 363 add().s64().space().regName(temp).commaSpace().reg(op.operandNAsValue(0)).commaSpace().reg(op.operandNAsValue(1)).ptxNl(); 364 st().global().u32().space().address(temp.name(), 4).commaSpace().reg(op.operandNAsValue(2)); 365 } 366 case "hat.buffer.S32Array::length()int" -> { 367 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(getReg(op.operandNAsValue(0)).name()); 368 } 369 // S32Array2D functions 370 case "hat.buffer.S32Array2D::array(long, int)void" -> { 371 PTXRegister temp = new PTXRegister(incrOrdinal(addressType()), addressType()); 372 add().s64().space().regName(temp).commaSpace().reg(op.operandNAsValue(0)).commaSpace().reg(op.operandNAsValue(1)).ptxNl(); 373 st().global().u32().space().address(temp.name(), 8).commaSpace().reg(op.operandNAsValue(2)); 374 } 375 case "hat.buffer.S32Array2D::width()int" -> { 376 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(getReg(op.operandNAsValue(0)).name()); 377 } 378 case "hat.buffer.S32Array2D::height()int" -> { 379 ld().global().u32().space().resultReg(op, PTXRegister.Type.U32).commaSpace().address(getReg(op.operandNAsValue(0)).name(), 4); 380 } 381 // Java Math function 382 case "java.lang.Math::sqrt(double)double" -> { 383 sqrt().rn().f64().space().resultReg(op, PTXRegister.Type.F64).commaSpace().reg(op.operandNAsValue(0)).semicolon(); 384 } 385 default -> { 386 obrace().nl().ptxIndent(); 387 for (int i = 0; i < op.operands().size(); i++) { 388 dot().param().space().paramType(op.operandNAsValue(i).type()).space().param().intVal(i).ptxNl(); 389 st().dot().param().paramType(op.operandNAsValue(i).type()).space().osbrace().param().intVal(i).csbrace().commaSpace().reg(op.operandNAsValue(i)).ptxNl(); 390 } 391 dot().param().space().paramType(op.resultType()).space().retVal().ptxNl(); 392 call().uni().space().oparen().retVal().cparen().commaSpace().append(op.method().getName()).commaSpace(); 393 final int[] counter = {0}; 394 paren(_ -> commaSeparated(op.operands(), _ -> param().intVal(counter[0]++))).ptxNl(); 395 ld().dot().param().paramType(op.resultType()).space().resultReg(op, getResultType(op.resultType())).commaSpace().osbrace().retVal().csbrace(); 396 ptxNl().cbrace(); 397 } 398 } 399 } 400 401 public void varDeclaration(VarDeclarationOpWrapper op) { 402 ld().dot().param().resultType(op.resultType(), false).space().resultReg(op, addressType()).commaSpace().reg(op.operandNAsValue(0)); 403 } 404 405 public void varFuncDeclaration(VarFuncDeclarationOpWrapper op) { 406 ld().dot().param().resultType(op.resultType(), false).space().resultReg(op, addressType()).commaSpace().reg(op.operandNAsValue(0)); 407 } 408 409 public void ret(ReturnOpWrapper op) { 410 if (op.hasOperands()) { 411 st().dot().param(); 412 if (returnReg.type().equals(PTXRegister.Type.U32)) { 413 b32(); 414 } else if (returnReg.type().equals(PTXRegister.Type.U64)) { 415 b64(); 416 } else { 417 dot().regType(returnReg.type()); 418 } 419 space().osbrace().regName(returnReg).csbrace().commaSpace().reg(op.operandNAsValue(0)).ptxNl(); 420 } 421 ret(); 422 } 423 424 public void javaBreak(JavaBreakOpWrapper op) { 425 brkpt(); 426 } 427 428 public void branch(CoreOp.BranchOp op) { 429 loadBlockParams(op.successors().getFirst()); 430 bra().space().block(op.successors().getFirst().targetBlock()); 431 } 432 433 public void condBranch(CoreOp.ConditionalBranchOp op) { 434 loadBlockParams(op.successors().getFirst()); 435 loadBlockParams(op.successors().getLast()); 436 at().reg(op.operands().getFirst()).space() 437 .bra().space().block(op.successors().getFirst().targetBlock()).ptxNl(); 438 bra().space().block(op.successors().getLast().targetBlock()); 439 } 440 441 public void neg(CoreOp.NegOp op) { 442 neg().resultType(op.resultType(), true).space().reg(op.result(), getResultType(op.resultType())).commaSpace().reg(op.operands().getFirst()); 443 } 444 445 /* 446 * Helper functions for printing blocks and variables 447 */ 448 449 public void loadBlockParams(Block.Reference block) { 450 for (int i = 0; i < block.arguments().size(); i++) { 451 Block.Parameter p = block.targetBlock().parameters().get(i); 452 mov().resultType(p.type(), false).space().reg(p, getResultType(p.type())) 453 .commaSpace().reg(block.arguments().get(i)).ptxNl(); 454 } 455 } 456 457 public PTXHATKernelBuilder block(Block block) { 458 return append("block_").intVal(block.index()); 459 } 460 461 public PTXHATKernelBuilder fieldReg(Field ref) { 462 if (fieldToRegMap.containsKey(ref)) { 463 return regName(fieldToRegMap.get(ref)); 464 } 465 if (ref.isDestination()) { 466 fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(addressType()), addressType())); 467 } else { 468 fieldToRegMap.putIfAbsent(ref, new PTXRegister(incrOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32)); 469 } 470 return regName(fieldToRegMap.get(ref)); 471 } 472 473 public PTXHATKernelBuilder fieldReg(Field ref, Value value) { 474 if (fieldToRegMap.containsKey(ref)) { 475 return regName(fieldToRegMap.get(ref)); 476 } 477 if (ref.isDestination()) { 478 fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(addressType()), addressType())); 479 return reg(value, addressType()); 480 } else { 481 fieldToRegMap.putIfAbsent(ref, new PTXRegister(getOrdinal(PTXRegister.Type.U32), PTXRegister.Type.U32)); 482 return reg(value, PTXRegister.Type.U32); 483 } 484 } 485 486 public Field getFieldObj(String fieldName) { 487 for (Field f : fieldToRegMap.keySet()) { 488 if (f.toString().equals(fieldName)) return f; 489 } 490 throw new IllegalStateException("no existing field"); 491 } 492 493 public PTXHATKernelBuilder resultReg(OpWrapper<?> opWrapper, PTXRegister.Type type) { 494 return append(addReg(opWrapper.result(), type)); 495 } 496 497 public PTXHATKernelBuilder reg(Value val, PTXRegister.Type type) { 498 if (varToRegMap.containsKey(val)) { 499 return regName(getReg(val)); 500 } else { 501 return append(addReg(val, type)); 502 } 503 } 504 505 public PTXHATKernelBuilder reg(Value val) { 506 return regName(getReg(val)); 507 } 508 509 public PTXRegister getReg(Value val) { 510 if (varToRegMap.get(val) == null && val instanceof Op.Result result && result.op() instanceof CoreOp.FieldAccessOp.FieldLoadOp fieldLoadOp) { 511 return fieldToRegMap.get(getFieldObj(fieldLoadOp.fieldDescriptor().name())); 512 } 513 if (varToRegMap.containsKey(val)) { 514 return varToRegMap.get(val); 515 } else { 516 throw new IllegalStateException("var to reg mapping doesn't exist"); 517 } 518 } 519 520 public String addReg(Value val, PTXRegister.Type type) { 521 if (varToRegMap.containsKey(val)) { 522 return varToRegMap.get(val).name(); 523 } 524 varToRegMap.put(val, new PTXRegister(incrOrdinal(type), type)); 525 return varToRegMap.get(val).name(); 526 } 527 528 public Integer getOrdinal(PTXRegister.Type type) { 529 ordinalMap.putIfAbsent(type, 1); 530 return ordinalMap.get(type); 531 } 532 533 public Integer incrOrdinal(PTXRegister.Type type) { 534 ordinalMap.putIfAbsent(type, 1); 535 int out = ordinalMap.get(type); 536 ordinalMap.put(type, out + 1); 537 return out; 538 } 539 540 public PTXHATKernelBuilder size() { 541 return (addressSize == 32) ? u32() : u64(); 542 } 543 544 public PTXRegister.Type addressType() { 545 return (addressSize == 32) ? PTXRegister.Type.U32 : PTXRegister.Type.U64; 546 } 547 548 public PTXHATKernelBuilder resultType(TypeElement type, boolean signedResult) { 549 PTXRegister.Type res = getResultType(type); 550 if (signedResult && (res == PTXRegister.Type.U32)) return s32(); 551 return dot().append(getResultType(type).getName()); 552 } 553 554 public PTXHATKernelBuilder paramType(TypeElement type) { 555 PTXRegister.Type res = getResultType(type); 556 if (res == PTXRegister.Type.U32) return b32(); 557 if (res == PTXRegister.Type.U64) return b64(); 558 return dot().append(getResultType(type).getName()); 559 } 560 561 public PTXRegister.Type getResultType(TypeElement type) { 562 switch (type.toString()) { 563 case "float" -> { 564 return PTXRegister.Type.F32; 565 } 566 case "double" -> { 567 return PTXRegister.Type.F64; 568 } 569 case "int" -> { 570 return PTXRegister.Type.U32; 571 } 572 case "boolean" -> { 573 return PTXRegister.Type.PREDICATE; 574 } 575 default -> { 576 return PTXRegister.Type.U64; 577 } 578 } 579 } 580 581 /* 582 * Basic CodeBuilder functions 583 */ 584 585 // used for parameter list 586 // prints out items separated by a comma then new line 587 public <I> PTXHATKernelBuilder commaNlSeparated(Iterable<I> iterable, Consumer<I> c) { 588 StreamCounter.of(iterable, (counter, t) -> { 589 if (counter.isNotFirst()) { 590 comma().nl(); 591 } 592 c.accept(t); 593 }); 594 return self(); 595 } 596 597 public PTXHATKernelBuilder address(String address) { 598 return osbrace().append(address).csbrace(); 599 } 600 601 public PTXHATKernelBuilder address(String address, int offset) { 602 osbrace().append(address); 603 if (offset == 0) { 604 return csbrace(); 605 } else if (offset > 0) { 606 plus(); 607 } 608 return intVal(offset).csbrace(); 609 } 610 611 public PTXHATKernelBuilder ptxNl() { 612 return semicolon().nl().ptxIndent(); 613 } 614 615 public PTXHATKernelBuilder commaSpace() { 616 return comma().space(); 617 } 618 619 public PTXHATKernelBuilder param() { 620 return append("param"); 621 } 622 623 public PTXHATKernelBuilder global() { 624 return dot().append("global"); 625 } 626 627 public PTXHATKernelBuilder rn() { 628 return dot().append("rn"); 629 } 630 631 public PTXHATKernelBuilder rm() { 632 return dot().append("rm"); 633 } 634 635 public PTXHATKernelBuilder rzi() { 636 return dot().append("rzi"); 637 } 638 639 public PTXHATKernelBuilder to() { 640 return dot().append("to"); 641 } 642 643 public PTXHATKernelBuilder lo() { 644 return dot().append("lo"); 645 } 646 647 public PTXHATKernelBuilder wide() { 648 return dot().append("wide"); 649 } 650 651 public PTXHATKernelBuilder uni() { 652 return dot().append("uni"); 653 } 654 655 public PTXHATKernelBuilder sat() { 656 return dot().append("sat"); 657 } 658 659 public PTXHATKernelBuilder ftz() { 660 return dot().append("ftz"); 661 } 662 663 public PTXHATKernelBuilder approx() { 664 return dot().append("approx"); 665 } 666 667 public PTXHATKernelBuilder mov() { 668 return append("mov"); 669 } 670 671 public PTXHATKernelBuilder setp() { 672 return append("setp"); 673 } 674 675 public PTXHATKernelBuilder selp() { 676 return append("selp"); 677 } 678 679 public PTXHATKernelBuilder ld() { 680 return append("ld"); 681 } 682 683 public PTXHATKernelBuilder st() { 684 return append("st"); 685 } 686 687 public PTXHATKernelBuilder cvt() { 688 return append("cvt"); 689 } 690 691 public PTXHATKernelBuilder bra() { 692 return append("bra"); 693 } 694 695 public PTXHATKernelBuilder ret() { 696 return append("ret"); 697 } 698 699 public PTXHATKernelBuilder rem() { 700 return append("rem"); 701 } 702 703 public PTXHATKernelBuilder mul() { 704 return append("mul"); 705 } 706 707 public PTXHATKernelBuilder div() { 708 return append("div"); 709 } 710 711 public PTXHATKernelBuilder rcp() { 712 return append("rcp"); 713 } 714 715 public PTXHATKernelBuilder add() { 716 return append("add"); 717 } 718 719 public PTXHATKernelBuilder sub() { 720 return append("sub"); 721 } 722 723 public PTXHATKernelBuilder lt() { 724 return append("lt"); 725 } 726 727 public PTXHATKernelBuilder gt() { 728 return append("gt"); 729 } 730 731 public PTXHATKernelBuilder le() { 732 return append("le"); 733 } 734 735 public PTXHATKernelBuilder ge() { 736 return append("ge"); 737 } 738 739 public PTXHATKernelBuilder geu() { 740 return append("geu"); 741 } 742 743 public PTXHATKernelBuilder ne() { 744 return append("ne"); 745 } 746 747 public PTXHATKernelBuilder eq() { 748 return append("eq"); 749 } 750 751 public PTXHATKernelBuilder xor() { 752 return append("xor"); 753 } 754 755 public PTXHATKernelBuilder or() { 756 return append("or"); 757 } 758 759 public PTXHATKernelBuilder and() { 760 return append("and"); 761 } 762 763 public PTXHATKernelBuilder cvta() { 764 return append("cvta"); 765 } 766 767 public PTXHATKernelBuilder mad() { 768 return append("mad"); 769 } 770 771 public PTXHATKernelBuilder fma() { 772 return append("fma"); 773 } 774 775 public PTXHATKernelBuilder sqrt() { 776 return append("sqrt"); 777 } 778 779 public PTXHATKernelBuilder abs() { 780 return append("abs"); 781 } 782 783 public PTXHATKernelBuilder ex2() { 784 return append("ex2"); 785 } 786 787 public PTXHATKernelBuilder shl() { 788 return append("shl"); 789 } 790 791 public PTXHATKernelBuilder shr() { 792 return append("shr"); 793 } 794 795 public PTXHATKernelBuilder neg() { 796 return append("neg"); 797 } 798 799 public PTXHATKernelBuilder call() { 800 return append("call"); 801 } 802 803 public PTXHATKernelBuilder exit() { 804 return append("exit"); 805 } 806 807 public PTXHATKernelBuilder brkpt() { 808 return append("brkpt"); 809 } 810 811 public PTXHATKernelBuilder ptxIndent() { 812 return append(" "); 813 } 814 815 public PTXHATKernelBuilder u32() { 816 return dot().append(PTXRegister.Type.U32.getName()); 817 } 818 819 public PTXHATKernelBuilder s32() { 820 return dot().append(PTXRegister.Type.S32.getName()); 821 } 822 823 public PTXHATKernelBuilder f32() { 824 return dot().append(PTXRegister.Type.F32.getName()); 825 } 826 827 public PTXHATKernelBuilder b32() { 828 return dot().append(PTXRegister.Type.B32.getName()); 829 } 830 831 public PTXHATKernelBuilder u64() { 832 return dot().append(PTXRegister.Type.U64.getName()); 833 } 834 835 public PTXHATKernelBuilder s64() { 836 return dot().append(PTXRegister.Type.S64.getName()); 837 } 838 839 public PTXHATKernelBuilder f64() { 840 return dot().append(PTXRegister.Type.F64.getName()); 841 } 842 843 public PTXHATKernelBuilder b64() { 844 return dot().append(PTXRegister.Type.B64.getName()); 845 } 846 847 public PTXHATKernelBuilder version() { 848 return dot().append("version"); 849 } 850 851 public PTXHATKernelBuilder target() { 852 return dot().append("target"); 853 } 854 855 public PTXHATKernelBuilder addressSize() { 856 return dot().append("address_size"); 857 } 858 859 public PTXHATKernelBuilder major(int major) { 860 return intVal(major); 861 } 862 863 public PTXHATKernelBuilder minor(int minor) { 864 return intVal(minor); 865 } 866 867 public PTXHATKernelBuilder target(String target) { 868 return append(target); 869 } 870 871 public PTXHATKernelBuilder size(int addressSize) { 872 return intVal(addressSize); 873 } 874 875 public PTXHATKernelBuilder funcName(String funcName) { 876 return append(funcName); 877 } 878 879 public PTXHATKernelBuilder visible() { 880 return dot().append("visible"); 881 } 882 883 public PTXHATKernelBuilder entry() { 884 return dot().append("entry"); 885 } 886 887 public PTXHATKernelBuilder func() { 888 return dot().append("func"); 889 } 890 891 public PTXHATKernelBuilder oabrace() { 892 return append("<"); 893 } 894 895 public PTXHATKernelBuilder cabrace() { 896 return append(">"); 897 } 898 899 public PTXHATKernelBuilder regName(PTXRegister reg) { 900 return append(reg.name()); 901 } 902 903 public PTXHATKernelBuilder regName(String regName) { 904 return append(regName); 905 } 906 907 public PTXHATKernelBuilder regType(Value val) { 908 return append(getReg(val).type().getName()); 909 } 910 911 public PTXHATKernelBuilder regType(PTXRegister.Type t) { 912 return append(t.getName()); 913 } 914 915 public PTXHATKernelBuilder regTypePrefix(PTXRegister.Type t) { 916 return append(t.getRegPrefix()); 917 } 918 919 public PTXHATKernelBuilder reg() { 920 return dot().append("reg"); 921 } 922 923 public PTXHATKernelBuilder retVal() { 924 return append("retval"); 925 } 926 927 public PTXHATKernelBuilder temp() { 928 return append("temp"); 929 } 930 931 public PTXHATKernelBuilder intVal(int i) { 932 return append(String.valueOf(i)); 933 } 934 935 public PTXHATKernelBuilder floatVal(String s) { 936 return append("0f").append(s); 937 } 938 939 public PTXHATKernelBuilder doubleVal(String s) { 940 return append("0d").append(s); 941 } 942 }