1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.triton; 27 28 import java.lang.invoke.MethodHandle; 29 import java.lang.invoke.MethodHandles; 30 import java.lang.reflect.Field; 31 import java.lang.reflect.Method; 32 import java.lang.reflect.code.*; 33 import java.lang.reflect.code.analysis.SSA; 34 import java.lang.reflect.code.op.CoreOp; 35 import java.lang.reflect.code.op.ExtendedOp; 36 import java.lang.reflect.code.type.JavaType; 37 import java.lang.reflect.code.type.VarType; 38 import java.util.*; 39 import java.util.concurrent.atomic.AtomicInteger; 40 import java.util.stream.Stream; 41 42 import static java.lang.reflect.code.op.CoreOp.*; 43 import static java.lang.reflect.code.type.FunctionType.functionType; 44 45 public final class TritonTransformer { 46 private TritonTransformer() {} 47 48 static final JavaType TYPE_Triton = JavaType.type(Triton.class); 49 50 static final JavaType TYPE_Triton_Test = JavaType.ofString("oracle.code.triton.TritonTest"); 51 52 static final JavaType TYPE_Tensor = JavaType.type(Tensor.class); 53 54 static final JavaType TYPE_J_L_MATH = JavaType.type(Math.class); 55 56 public static <O extends Op & Op.Invokable> 57 TritonOps.ModuleOp tritonModule(O kernel, 58 TypeElement rType, 59 List<? extends TypeElement> argTypes) { 60 Map<String, TritonOps.FuncOp> fsymTable = new LinkedHashMap<>(); 61 tritonFunction(kernel, rType, argTypes, fsymTable); 62 return TritonOps.module(fsymTable.values().stream().toList()); 63 } 64 65 public static <O extends Op & Op.Invokable> 66 TritonOps.FuncOp tritonFunction(O javaKernel, 67 TypeElement rType, 68 List<? extends TypeElement> argTypes, 69 Map<String, TritonOps.FuncOp> fsymTable) { 70 String name = (javaKernel instanceof FuncOp f) ? f.funcName() : "kernel"; 71 String signature = signature(name, rType, argTypes); 72 if (fsymTable.containsKey(signature)) { 73 return fsymTable.get(signature); 74 } 75 76 System.out.println(javaKernel.toText()); 77 78 Map<Value, TypeElement> valueTypeMap = new HashMap<>(); 79 Map<Op, Object> opData = new HashMap<>(); 80 TritonTransformer.typeCheckKernel(javaKernel, argTypes, valueTypeMap, opData); 81 TritonTransformer.printTypeMap(javaKernel, valueTypeMap); 82 83 return TritonTransformer.transformToTritonFunction(javaKernel, signature, 84 rType, valueTypeMap, opData, 85 fsymTable); 86 } 87 88 static String signature(String name, TypeElement rType, List<? extends TypeElement> argTypes) { 89 StringBuilder sb = new StringBuilder(name); 90 91 for (TypeElement argType : argTypes) { 92 sb.append("_"); 93 if (argType instanceof ConstantType ct) { 94 sb.append(ct.value()); 95 } else { 96 sb.append(argType); 97 } 98 } 99 sb.append("_"); 100 sb.append(rType); 101 return sb.toString(); 102 } 103 104 public static <O extends Op & Op.Invokable> void typeCheckKernel( 105 O kernel, List<? extends TypeElement> argTypes, 106 Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData) { 107 kernel.traverse(null, CodeElement.opVisitor((o, op) -> { 108 switch (op) { 109 case Op.Invokable fop -> { 110 List<Block.Parameter> parameters = fop.body().entryBlock().parameters(); 111 for (int i = 0; i < parameters.size(); i++) { 112 valueTypeMap.put(parameters.get(i), argTypes.get(i)); 113 } 114 } 115 case VarOp _, VarAccessOp.VarLoadOp _ -> { 116 Value init = op.operands().get(0); 117 valueTypeMap.put(op.result(), valueTypeMap.get(init)); 118 } 119 case VarAccessOp.VarStoreOp _ -> { 120 Value var = op.operands().get(0); 121 TypeElement varType = valueTypeMap.get(var); 122 Value v = op.operands().get(1); 123 TypeElement vType = valueTypeMap.get(v); 124 if (!varType.equals(vType)) { 125 throw new IllegalStateException("Storing to variable with different type: " 126 + varType + " <- " + vType); 127 } 128 129 valueTypeMap.put(op.result(), valueTypeMap.get(var)); 130 } 131 case ConstantOp cop -> { 132 valueTypeMap.put(op.result(), new ConstantType(op.result().type(), cop.value())); 133 } 134 case ArithmeticOperation _ -> { 135 TypeElement t = checkWithTypeInterpreter(op, op.opName(), valueTypeMap); 136 valueTypeMap.put(op.result(), t); 137 } 138 case FieldAccessOp.FieldLoadOp flop -> { 139 if (!flop.operands().isEmpty()) { 140 throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor()); 141 } 142 143 Field f; 144 try { 145 f = flop.fieldDescriptor().resolveToMember(MethodHandles.lookup()); 146 } catch (ReflectiveOperationException e) { 147 throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor(), e); 148 } 149 Object value; 150 try { 151 value = f.get(null); 152 } catch (IllegalAccessException e) { 153 throw new IllegalStateException("Unsupported field load: " + f, e); 154 } 155 valueTypeMap.put(op.result(), new ConstantType(JavaType.type(f.getType()), value)); 156 } 157 case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> { 158 // Box 159 if (iop.invokeDescriptor().name().equals("valueOf")) { 160 Value a = op.operands().get(0); 161 valueTypeMap.put(op.result(), valueTypeMap.get(a)); 162 } else { 163 throw new UnsupportedOperationException("Unsupported invocation on Integer: " + iop.invokeDescriptor()); 164 } 165 } 166 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> { 167 String name = iop.invokeDescriptor().name(); 168 if (name.equals("max") || name.equals("min")) { 169 Value a = op.operands().get(0); 170 valueTypeMap.put(op.result(), valueTypeMap.get(a)); 171 } else { 172 throw new UnsupportedOperationException("Unsupported invocation on Math: " + iop.invokeDescriptor()); 173 } 174 } 175 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> { 176 if (iop.invokeDescriptor().name().equals("type")) { 177 Value a = op.operands().get(0); 178 valueTypeMap.put(op.result(), valueTypeMap.get(a)); 179 } else { 180 throw new UnsupportedOperationException("Unsupported invocation on Tensor: " + iop.invokeDescriptor()); 181 } 182 } 183 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> { 184 TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap); 185 valueTypeMap.put(op.result(), t); 186 } 187 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> { 188 TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap); 189 valueTypeMap.put(op.result(), t); 190 } 191 case ExtendedOp.JavaForOp fop -> { 192 SimpleCountedForLoopInfo li = new SimpleCountedForLoopInfo(fop); 193 opData.put(fop, li); 194 195 TypeElement type = fop.init().yieldType(); 196 if (type instanceof VarType vt && vt.valueType().equals(JavaType.INT)) { 197 for (Body b : List.of(fop.cond(), fop.update(), fop.loopBody())) { 198 valueTypeMap.put(b.entryBlock().parameters().get(0), JavaType.INT); 199 } 200 } else { 201 throw new IllegalStateException(); 202 } 203 } 204 case TestOperation _ -> { 205 } 206 case ExtendedOp.JavaContinueOp _ -> { 207 } 208 case YieldOp _ -> { 209 } 210 case ReturnOp _ -> { 211 } 212 default -> throw new UnsupportedOperationException("Unsupported operation: " + op); 213 } 214 215 return null; 216 })); 217 } 218 219 static TypeElement checkWithTypeInterpreter(Op op, String name, Map<Value, TypeElement> valueTypeMap) { 220 // Obtain associated type-based method 221 MethodHandle mh; 222 try { 223 Optional<Method> om = Stream.of(TritonTypeInterpreter.class.getDeclaredMethods()) 224 .filter(m -> m.getName().equals(name)) 225 .findFirst(); 226 mh = MethodHandles.lookup().unreflect( 227 om.orElseThrow(() -> new NoSuchMethodException(name))); 228 } catch (ReflectiveOperationException e) { 229 throw new IllegalStateException(name, e); 230 } 231 232 // Invoke with the values' types 233 List<TypeElement> operandTypes = op.operands().stream().map(valueTypeMap::get).toList(); 234 try { 235 return (TypeElement) mh.invokeWithArguments(operandTypes.toArray(Object[]::new)); 236 } catch (Throwable e) { 237 throw new IllegalStateException(mh.toString(), e); 238 } 239 } 240 241 // @@@ type check tensor shapes 242 static class TritonTypeInterpreter { 243 private TritonTypeInterpreter() { 244 } 245 246 // int programId(@Constant int axis) { 247 public static JavaType programId(ConstantType axis) { 248 assert axis.cType().equals(JavaType.INT); 249 int axisValue = (int) axis.value(); 250 if (axisValue < 0 || axisValue > 3) { 251 throw new IllegalStateException(); 252 } 253 254 return JavaType.INT; 255 } 256 257 // Tensor arange(@Constant int start, @Constant int end) 258 public static TensorType arange(ConstantType start, ConstantType end) { 259 assert start.cType().equals(JavaType.INT); 260 assert end.cType().equals(JavaType.INT); 261 262 int startValue = (int) start.value(); 263 int endValue = (int) end.value(); 264 265 return new TensorType(JavaType.INT, List.of(endValue - startValue)); 266 } 267 268 // Tensor expand(Tensor a, int axis) { 269 public static TensorType expand(TensorType a, ConstantType axis) { 270 assert axis.cType().equals(JavaType.INT); 271 int axisValue = (int) axis.value(); 272 273 List<Integer> s = new ArrayList<>(a.shape()); 274 if (axisValue < s.size()) { 275 s.add(axisValue, 1); 276 } else { 277 for (int i = 0; i <= (axisValue - s.size()); i++) { 278 s.add(1); 279 } 280 } 281 return new TensorType(a.eType(), s); 282 } 283 284 // Tensor load(Tensor ptr, Tensor mask) 285 public static TensorType load(TensorType ptr, TensorType mask) { 286 checkTensorShape(ptr, mask); 287 if (ptr.eType() instanceof PtrType eptr) { 288 return new TensorType(eptr.rType(), ptr.shape()); 289 } 290 291 throw new IllegalStateException(); 292 } 293 294 // void store(Tensor ptr, Tensor value, Tensor mask) 295 public static void store(TensorType ptr, TensorType value, TensorType mask) { 296 if (!(ptr.eType() instanceof PtrType)) { 297 throw new IllegalStateException(); 298 } 299 } 300 301 // Tensor zeros(TensorType type) 302 public static TensorType zeros(ConstantType eType, ConstantType... cShape) { 303 List<Integer> shape = Stream.of(cShape).map(s -> (int) s.value()).toList(); 304 return new TensorType((TypeElement) eType.value(), shape); 305 } 306 307 // Tensor broadcast(Object o, TensorType type) 308 public static TensorType broadcast(TypeElement o, TensorType type) { 309 if (o instanceof TensorType ot) { 310 // @@@ 311 if (ot.shape().size() != type.shape().size()) { 312 throw new IllegalStateException(); 313 } 314 o = ot.eType(); 315 } if (o instanceof ConstantType oc) { 316 o = oc.cType(); 317 } 318 return new TensorType(o, type.shape()); 319 } 320 321 public static TensorType joinShape(TensorType a, TensorType b) { 322 return checkTensorTypes(a, b); 323 } 324 325 // Tensor add(Number a, Number b) 326 // Ptr add(Ptr a, int offset) 327 public static TypeElement add(TypeElement a, TypeElement b) { 328 // @@@ Pass additional argument for checking ptr 329 return binary(a, b); 330 } 331 332 public static TypeElement sub(TypeElement a, TypeElement b) { 333 return binary(a, b); 334 } 335 336 public static TypeElement mul(TypeElement a, TypeElement b) { 337 return binary(a, b); 338 } 339 340 public static TypeElement div(TypeElement a, TypeElement b) { 341 return binary(a, b); 342 } 343 344 public static TypeElement mod(TypeElement a, TypeElement b) { 345 return binary(a, b); 346 } 347 348 public static TypeElement and(TypeElement a, TypeElement b) { 349 return binary(a, b); 350 } 351 352 public static TypeElement cdiv(TypeElement a, TypeElement b) { 353 a = reduceScalarType(a); 354 b = reduceScalarType(b); 355 if (!a.equals(JavaType.INT) && !b.equals(JavaType.INT)) { 356 throw new IllegalStateException(); 357 } 358 return a; 359 } 360 361 // Number conv(Type t, Number a) { 362 public static TypeElement conv(ConstantType eType, TypeElement a) { 363 return convTypes(eType, a); 364 } 365 366 public static TypeElement convTypes(ConstantType eType, TypeElement a) { 367 if (a instanceof TensorType tb) { 368 TypeElement e = convScalarTypes(eType, tb.eType()); 369 return new TensorType(e, tb.shape()); 370 } else { 371 return convScalarTypes(eType, a); 372 } 373 } 374 375 public static TypeElement convScalarTypes(ConstantType eType, TypeElement a) { 376 TypeElement t = (TypeElement) eType.value(); 377 if (t.equals(Float16.FLOAT_16_TYPE) && a.equals(JavaType.FLOAT)) { 378 return Float16.FLOAT_16_TYPE; 379 } else if (t.equals(a)) { 380 return t; 381 } else { 382 // @@@ Conversions; 383 throw new IllegalStateException(); 384 } 385 } 386 387 // Tensor exp(Tensor a) 388 public static TypeElement exp(TypeElement a) { 389 return unary(a); 390 } 391 392 static TypeElement unary(TypeElement a) { 393 return a; 394 } 395 396 // Tensor compare(Number a, Number b, @Constant CompareKind ck) { 397 public static TypeElement compare(TypeElement a, TypeElement b, ConstantType kind) { 398 assert kind.cType().equals(JavaType.type(Triton.CompareKind.class)); 399 400 return binary(a, b); 401 } 402 403 // Tensor dot(Tensor a, Tensor b) 404 public static TensorType dot(TensorType a, TensorType b) { 405 if (a.shape().size() != 2 || b.shape().size() != 2) { 406 throw new IllegalStateException(); 407 } 408 409 if (!a.shape().get(1).equals(b.shape().get(0))) { 410 throw new IllegalStateException(); 411 } 412 413 if (a.eType() != b.eType()) { 414 // @@@ Conversion, type checking 415 throw new IllegalStateException(); 416 } 417 418 // Computed result is tensor of floats, regardless of inputs 419 return new TensorType(JavaType.FLOAT, List.of(a.shape().get(0), b.shape().get(1))); 420 } 421 422 423 // Tensor max(Tensor a, @Constant int axis) { 424 public static TypeElement max(TensorType a, ConstantType axis) { 425 return reduce(a, axis); 426 } 427 428 // Tensor sum(Tensor a, @Constant int axis) { 429 public static TypeElement sum(TensorType a, ConstantType axis) { 430 return reduce(a, axis); 431 } 432 433 static TypeElement reduce(TensorType a, ConstantType axis) { 434 assert axis.cType().equals(JavaType.INT); 435 int axisValue = (int) axis.value(); 436 if (axisValue < 0 || axisValue > 3) { 437 throw new IllegalStateException(); 438 } 439 440 List<Integer> reduceShape = new ArrayList<>(); 441 for (int i = 0; i < a.shape().size(); i++) { 442 if (i != axisValue) { 443 reduceShape.add(a.shape().get(i)); 444 } else { 445 reduceShape.add(1); 446 } 447 } 448 449 if (reduceShape.size() == 1 && reduceShape.getFirst() == 1) { 450 return a.eType(); 451 } else { 452 return new TensorType(a.eType(), reduceShape); 453 } 454 } 455 456 // @@@ Test 457 public static void consume(TypeElement a) { 458 } 459 460 461 static TypeElement binary(TypeElement a, TypeElement b) { 462 if (a instanceof TensorType ta && b instanceof TensorType tb) { 463 return checkTensorTypes(ta, tb); 464 } else if (a instanceof TensorType ta) { 465 return new TensorType(checkScalarTypes(ta.eType(), b), ta.shape()); 466 } else if (b instanceof TensorType tb) { 467 return new TensorType(checkScalarTypes(a, tb.eType()), tb.shape()); 468 } else { 469 return checkScalarTypes(a, b); 470 } 471 } 472 473 static TensorType checkTensorTypes(TensorType a, TensorType b) { 474 List<Integer> s = checkTensorShape(a, b); 475 TypeElement e = checkScalarTypes(a.eType(), b.eType()); 476 return new TensorType(e, s); 477 } 478 479 static List<Integer> checkTensorShape(TensorType a, TensorType b) { 480 if (a.shape().size() != b.shape().size()) { 481 // Shape mismatch 482 throw new IllegalStateException(); 483 } 484 485 List<Integer> s = new ArrayList<>(); 486 for (int i = 0; i < a.shape().size(); i++) { 487 int ad = a.shape().get(i); 488 int bd = b.shape().get(i); 489 490 // Expand dimensions 491 int d; 492 if (ad == bd) { 493 d = ad; 494 } else { 495 if (ad != 1 && bd == 1) { 496 d = ad; 497 } else if (ad == 1) { 498 d = bd; 499 } else { 500 // Shape mismatch 501 throw new IllegalStateException(); 502 } 503 } 504 505 s.add(d); 506 } 507 508 return s; 509 } 510 511 static TypeElement checkScalarTypes(TypeElement a, TypeElement b) { 512 // @@@ Optional ptr checking 513 if (a instanceof PtrType) { 514 if (!b.equals(JavaType.INT)) { 515 throw new IllegalStateException(); 516 } 517 } else if (b instanceof PtrType) { 518 // Pointer must be first argument 519 throw new IllegalStateException(); 520 } else if (a instanceof ConstantType || b instanceof ConstantType) { 521 return checkScalarTypes(reduceScalarType(a), reduceScalarType(b)); 522 } else if (!a.equals(b)) { 523 // @@@ Conversion 524 throw new IllegalStateException(); 525 } 526 return a; 527 } 528 529 static TypeElement reduceScalarType(TypeElement a) { 530 return a instanceof ConstantType ct ? ct.cType() : a; 531 } 532 } 533 534 public static <O extends Op & Op.Invokable> TritonOps.FuncOp transformToTritonFunction( 535 O kernel, 536 String signature, 537 TypeElement rType, 538 Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData, 539 Map<String, TritonOps.FuncOp> fsymTable) { 540 TritonOps.FuncOp ttKernel = TritonOps.func(signature, functionType(rType)) 541 .body(fblock -> { 542 // Process kernel parameters 543 List<Value> args = new ArrayList<>(); 544 for (Block.Parameter kp : kernel.body().entryBlock().parameters()) { 545 TypeElement type = valueTypeMap.get(kp); 546 if (type instanceof ConstantType ct) { 547 // Constant 548 Op.Result cr = fblock.op(ArithMathOps.constant( 549 ct.cType(), ct.value())); 550 args.add(cr); 551 } else { 552 args.add(fblock.parameter(type)); 553 } 554 } 555 556 // Transform kernel body 557 fblock.transformBody(kernel.body(), args, (kblock, op) -> { 558 return transformToTritonOperation(kblock, op, valueTypeMap, opData, fsymTable); 559 }); 560 }); 561 562 ttKernel = cleanup(ttKernel); 563 fsymTable.put(ttKernel.funcName(), ttKernel); 564 return ttKernel; 565 } 566 567 static Block.Builder transformToTritonOperation(Block.Builder kblock, Op op, 568 Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData, 569 Map<String, TritonOps.FuncOp> fsymTable) { 570 // @@@ Avoid constructing for each operation -- block builder passed as argument or a scoped value 571 TritonBuilderInterpreter tbi = new TritonBuilderInterpreter(fsymTable, kblock); 572 CopyContext cc = kblock.context(); 573 switch (op) { 574 case VarOp varOp -> { 575 // @@@ Cannot copy op because the result type 576 // is derived from init type 577 Value init = cc.getValue(op.operands().get(0)); 578 Op.Result r = kblock.op(var(varOp.varName(), init)); 579 cc.mapValue(op.result(), r); 580 } 581 case ConstantOp cop -> { 582 TypeElement t = valueTypeMap.get(cop.result()); 583 if (t instanceof ConstantType ct) { 584 Op.Result r = kblock.op(ArithMathOps.constant( 585 ct.cType(), ct.value())); 586 cc.mapValue(op.result(), r); 587 } else { 588 kblock.op(op); 589 } 590 } 591 case ArithmeticOperation _ -> { 592 Value result = tbi.build(op, op.opName(), valueTypeMap); 593 if (result != null) { 594 cc.mapValue(op.result(), result); 595 } 596 } 597 case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> { 598 // Replace box with its value 599 Value a = cc.getValue(op.operands().get(0)); 600 cc.mapValue(op.result(), a); 601 } 602 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> { 603 String name = iop.invokeDescriptor().name(); 604 if (name.equals("max")) { 605 Value a = cc.getValue(op.operands().get(0)); 606 Value b = cc.getValue(op.operands().get(1)); 607 608 Op.Result result = kblock.op(ArithMathOps.maximum(a, b)); 609 cc.mapValue(op.result(), result); 610 } else if (name.equals("min")) { 611 Value a = cc.getValue(op.operands().get(0)); 612 Value b = cc.getValue(op.operands().get(1)); 613 614 Op.Result result = kblock.op(ArithMathOps.minimum(a, b)); 615 cc.mapValue(op.result(), result); 616 } 617 } 618 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> { 619 if (iop.invokeDescriptor().name().equals("type")) { 620 // Replace with constant operation to produce tensor type. 621 // Result may be used, but transitively it will be removed due to no uses 622 // contributing to the computation 623 Value a = op.operands().get(0); 624 TensorType aType = (TensorType) valueTypeMap.get(a); 625 Op.Result result = kblock.op(CoreOp.constant(iop.resultType(), aType)); 626 cc.mapValue(op.result(), result); 627 valueTypeMap.put(result, aType); 628 } 629 // Remove 630 } 631 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> { 632 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap); 633 if (result != null) { 634 cc.mapValue(op.result(), result); 635 } 636 } 637 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> { 638 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap); 639 if (result != null) { 640 cc.mapValue(op.result(), result); 641 } 642 } 643 case ExtendedOp.JavaForOp fop -> { 644 transformToSCFFor(cc, kblock, fop, valueTypeMap, opData, fsymTable); 645 } 646 case ReturnOp rop -> { 647 if (rop.operands().isEmpty()) { 648 kblock.op(TritonOps.return_()); 649 } else { 650 kblock.op(TritonOps.return_( 651 cc.getValue(rop.returnValue()))); 652 } 653 } 654 default -> kblock.op(op); 655 } 656 return kblock; 657 } 658 659 static void transformToSCFFor(CopyContext cc, Block.Builder kblock, ExtendedOp.JavaForOp fop, 660 Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData, 661 Map<String, TritonOps.FuncOp> fsymTable) { 662 Body body = fop.loopBody(); 663 664 // Hoist expressions for start, end, and step 665 SimpleCountedForLoopInfo li = (SimpleCountedForLoopInfo) opData.get(fop); 666 Value start = null; 667 for (Op o : li.startExpression()) { 668 transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable); 669 start = cc.getValue(o.result()); 670 } 671 Value end = null; 672 for (Op o : li.endExpression()) { 673 transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable); 674 end = cc.getValue(o.result()); 675 } 676 Value step = null; 677 for (Op o : li.stepExpression()) { 678 transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable); 679 step = cc.getValue(o.result()); 680 } 681 682 // Obtain captured vars 683 // true == stores 684 // false == loads only 685 Map<Boolean, Set<Value>> capturedVars = capturedVars(body); 686 Set<Value> capturedAndStoredVars = capturedVars.get(true); 687 688 // Get load values 689 // Loaded values are hoisted out of the loop body 690 Map<Value, Value> loadValues = new HashMap<>(); 691 for (Value v : capturedVars.get(false)) { 692 Value load = kblock.op(varLoad(cc.getValue(v))); 693 valueTypeMap.put(load, valueTypeMap.get(v)); 694 loadValues.put(v, load); 695 } 696 697 // Get iteration values -- represented by captured vars that are stored to in the loop 698 // The SCF for operation returns the iteration values of the last loop iteration, which 699 // are then to be stored to the iteration variables 700 List<Value> iterValues = new ArrayList<>(); 701 for (Value v : capturedAndStoredVars) { 702 iterValues.add(kblock.op(varLoad(cc.getValue(v)))); 703 } 704 705 // @@@ Build in java code model, then transform? 706 SCFOps.ForOp scffor = SCFOps.for_(kblock.parentBody(), start, end, step, iterValues) 707 // Ensure existing context is used 708 .body(CopyContext.create(cc), builder -> { 709 // Create index var initialized from entry block parameter 710 Value index = builder.parameters().get(0); 711 valueTypeMap.put(index, JavaType.INT); 712 Value varIndex = builder.op(var("index", index)); 713 valueTypeMap.put(varIndex, JavaType.INT); 714 builder.context().mapValue(body.entryBlock().parameters().get(0), varIndex); 715 716 // Create iter vars initialized from entry block parameters 717 int pi = 1; 718 for (Value v : capturedAndStoredVars) { 719 TypeElement type = valueTypeMap.get(v); 720 Value iter = builder.parameters().get(pi++); 721 valueTypeMap.put(iter, type); 722 Value varIter = builder.op(var(Integer.toString(pi), iter)); 723 valueTypeMap.put(varIter, type); 724 builder.context().mapValue(v, varIter); 725 } 726 727 // Transform the Java for body into the SCF for body 728 builder.transformBody(body, List.of(), (block, op) -> { 729 // Yield iter values 730 if (op instanceof ExtendedOp.JavaContinueOp) { 731 // Replace with yield of loaded vars 732 List<Value> yieldValues = new ArrayList<>(); 733 for (Value value : capturedAndStoredVars) { 734 Value varIter = block.context().getValue(value); 735 Value v = block.op(varLoad(varIter)); 736 yieldValues.add(v); 737 } 738 block.op(SCFOps.yield_(yieldValues)); 739 } else if (op instanceof VarAccessOp.VarLoadOp) { 740 // Replace with value loaded immediately before loop 741 Value v = op.operands().get(0); 742 if (capturedVars.get(false).contains(v)) { 743 block.context().mapValue(op.result(), loadValues.get(v)); 744 } else { 745 block.op(op); 746 } 747 } else { 748 block = transformToTritonOperation(block, op, valueTypeMap, opData, fsymTable); 749 } 750 return block; 751 }); 752 }); 753 Op.Result forResult = kblock.op(scffor); 754 755 // Assign back result to iter vars 756 if (capturedAndStoredVars.size() == 1) { 757 for (Value v : capturedAndStoredVars) { 758 kblock.op(varStore(cc.getValue(v), forResult)); 759 } 760 } else { 761 int i = 0; 762 for (Value v : capturedAndStoredVars) { 763 kblock.op(varStore(cc.getValue(v), 764 kblock.op(tupleLoad(forResult, i++)))); 765 } 766 } 767 } 768 769 static Map<Boolean, Set<Value>> capturedVars(Body body) { 770 Map<Boolean, Set<Value>> capturedValues = new HashMap<>(); 771 capturedValues.put(false, new LinkedHashSet<>()); 772 capturedValues.put(true, new LinkedHashSet<>()); 773 774 capturedVars(capturedValues, new ArrayDeque<>(), body); 775 return capturedValues; 776 } 777 778 static void capturedVars(Map<Boolean, Set<Value>> capturedVars, Deque<Body> bodyStack, Body body) { 779 bodyStack.push(body); 780 781 for (Block b : body.blocks()) { 782 for (Op op : b.ops()) { 783 // @@@ Nested bodies 784 if (!op.bodies().isEmpty()) { 785 throw new IllegalStateException(); 786 } 787 // for (Body childBody : op.bodies()) { 788 // capturedAndUpdatedVars(capturedValues, bodyStack, childBody); 789 // } 790 791 if (op instanceof VarAccessOp) { 792 Value v = op.operands().get(0); 793 if (!bodyStack.contains(v.declaringBlock().parentBody())) { 794 if (op instanceof VarAccessOp.VarStoreOp) { 795 capturedVars.get(true).add(v); 796 capturedVars.get(false).remove(v); 797 } else if (!capturedVars.get(true).contains(v)) { 798 capturedVars.get(false).add(v); 799 } 800 } 801 } 802 } 803 } 804 805 bodyStack.pop(); 806 } 807 808 public static final ScopedValue<Boolean> SV_SSA = ScopedValue.newInstance(); 809 810 static TritonOps.FuncOp cleanup(TritonOps.FuncOp f) { 811 // Remove var ops 812 boolean doSSA = SV_SSA.isBound() ? SV_SSA.get() : true; 813 if (doSSA) { 814 f = SSA.transform(f); 815 } 816 // Remove unused ops 817 f = f.transform((fblock, op) -> { 818 if (op instanceof Op.Pure && op.result().uses().isEmpty()) { 819 return fblock; 820 } else if (op instanceof VarAccessOp.VarLoadOp && op.result().uses().isEmpty()) { 821 return fblock; 822 } 823 824 fblock.op(op); 825 return fblock; 826 }); 827 return f; 828 } 829 830 static class TritonBuilderInterpreter { 831 final Map<String, TritonOps.FuncOp> fsymTable; 832 final Block.Builder block; 833 834 TritonBuilderInterpreter(Map<String, TritonOps.FuncOp> fsymTable, Block.Builder block) { 835 this.fsymTable = fsymTable; 836 this.block = block; 837 } 838 839 Value build(Op op, String name, Map<Value, TypeElement> valueTypeMap) { 840 // Obtain associated type-based method 841 MethodHandle mh; 842 try { 843 Optional<Method> om = Stream.of(TritonBuilderInterpreter.class.getDeclaredMethods()) 844 .filter(m -> m.getName().equals(name)) 845 .findFirst(); 846 mh = MethodHandles.lookup().unreflect( 847 om.orElseThrow(() -> new NoSuchMethodException(name))); 848 } catch (ReflectiveOperationException e) { 849 throw new IllegalStateException(e); 850 } 851 852 List<Object> iArgs = new ArrayList<>(); 853 iArgs.add(this); 854 iArgs.add(valueTypeMap.get(op.result())); 855 iArgs.add(op.result()); 856 for (Value o : op.operands()) { 857 iArgs.add(valueTypeMap.get(o)); 858 iArgs.add(o); 859 } 860 try { 861 return (Value) mh.invokeWithArguments(iArgs.toArray(Object[]::new)); 862 } catch (Throwable e) { 863 throw new IllegalStateException(e); 864 } 865 } 866 867 868 public Value programId(TypeElement rType, Op.Result r, 869 ConstantType axisType, Value axis) { 870 return block.op(TritonOps.getProgramId( 871 (int) axisType.value())); 872 } 873 874 public Value arange(TensorType rType, Op.Result r, 875 ConstantType startType, Value start, 876 ConstantType endType, Value end) { 877 return block.op(TritonOps.makeRange( 878 (int) startType.value(), 879 (int) endType.value())); 880 } 881 882 public Value expand(TensorType rType, Op.Result r, 883 TensorType aType, Value a, 884 ConstantType axisType, Value axis) { 885 return block.op(TritonOps.expand( 886 (int) axisType.value(), 887 rType, 888 block.context().getValue(a))); 889 } 890 891 public Value zeros(TensorType rType, Op.Result r, 892 ConstantType aType, Value a, 893 Object... constantsAndValues) { 894 Object zero; 895 try { 896 JavaType zeroType = (JavaType) aType.value(); 897 zero = MethodHandles.zero((Class<?>) zeroType.resolve(MethodHandles.lookup())).invoke(); 898 } catch (Throwable e) { 899 throw new RuntimeException(e); 900 } 901 return block.op(ArithMathOps.constant(rType, zero)); 902 } 903 904 public Value load(TensorType rType, Op.Result r, 905 TensorType ptrType, Value ptr, 906 TensorType maskType, Value mask) { 907 broadcastConversionRight(ptrType, maskType, mask); 908 return block.op(TritonOps.load( 909 rType, 910 block.context().getValue(ptr), 911 block.context().getValue(mask))); 912 } 913 914 public Value store(TensorType rType, Op.Result r, 915 TensorType ptrType, Value ptr, 916 TensorType valueType, Value value, 917 TensorType maskType, Value mask) { 918 broadcastConversionRight(ptrType, valueType, value); 919 broadcastConversionRight(ptrType, maskType, mask); 920 return block.op(TritonOps.store( 921 block.context().getValue(ptr), 922 block.context().getValue(value), 923 block.context().getValue(mask))); 924 } 925 926 public Value broadcast(TensorType rType, Op.Result r, 927 TypeElement oType, Value o, 928 TensorType tensorTypeType, Value tensorType) { 929 // @@@ tt.splat with scalar operand, tt.broadcast with tensor operand 930 if (oType instanceof TensorType) { 931 return block.op(TritonOps.broadcast( 932 rType, 933 block.context().getValue(o))); 934 } else { 935 return block.op(TritonOps.splat( 936 rType, 937 block.context().getValue(o))); 938 } 939 } 940 941 public Value joinShape(TensorType rType, Op.Result r, 942 TensorType aType, Value a, 943 TensorType bType, Value b) { 944 // Replace with constant operation to produce tensor type. 945 // Result may be used, but transitively it will be removed due to no uses 946 // contributing to the computation 947 return block.op(CoreOp.constant(JavaType.type(TensorType.class), r.type())); 948 } 949 950 951 public Value add(TypeElement rType, Op.Result r, 952 TypeElement aType, Value a, 953 TypeElement bType, Value b) { 954 broadcastConversion(rType, aType, a, bType, b); 955 a = block.context().getValue(a); 956 b = block.context().getValue(b); 957 958 if (rType instanceof PtrType || 959 rType instanceof TensorType t && t.eType() instanceof PtrType) { 960 return block.op(TritonOps.addptr(a, b)); 961 } else { 962 return block.op(ArithMathOps.add(a, b)); 963 } 964 } 965 966 public Value sub(TypeElement rType, Op.Result r, 967 TypeElement aType, Value a, 968 TypeElement bType, Value b) { 969 broadcastConversion(rType, aType, a, bType, b); 970 a = block.context().getValue(a); 971 b = block.context().getValue(b); 972 973 return block.op(ArithMathOps.sub(a, b)); 974 } 975 976 public Value mul(TypeElement rType, Op.Result r, 977 TypeElement aType, Value a, 978 TypeElement bType, Value b) { 979 broadcastConversion(rType, aType, a, bType, b); 980 a = block.context().getValue(a); 981 b = block.context().getValue(b); 982 983 return block.op(ArithMathOps.mul(a, b)); 984 } 985 986 public Value div(TypeElement rType, Op.Result r, 987 TypeElement aType, Value a, 988 TypeElement bType, Value b) { 989 broadcastConversion(rType, aType, a, bType, b); 990 a = block.context().getValue(a); 991 b = block.context().getValue(b); 992 993 return block.op(ArithMathOps.div(a, b)); 994 } 995 996 public Value mod(TypeElement rType, Op.Result r, 997 TypeElement aType, Value a, 998 TypeElement bType, Value b) { 999 broadcastConversion(rType, aType, a, bType, b); 1000 a = block.context().getValue(a); 1001 b = block.context().getValue(b); 1002 1003 return block.op(ArithMathOps.rem(a, b)); 1004 } 1005 1006 public Value and(TypeElement rType, Op.Result r, 1007 TypeElement aType, Value a, 1008 TypeElement bType, Value b) { 1009 broadcastConversion(rType, aType, a, bType, b); 1010 a = block.context().getValue(a); 1011 b = block.context().getValue(b); 1012 1013 return block.op(ArithMathOps.and(a, b)); 1014 } 1015 1016 public Value dot(TensorType rType, Op.Result r, 1017 TypeElement aType, Value a, 1018 TypeElement bType, Value b) { 1019 a = block.context().getValue(a); 1020 b = block.context().getValue(b); 1021 1022 return block.op(TritonOps.dot(rType, a, b)); 1023 } 1024 1025 public Value cdiv(TypeElement rType, Op.Result r, 1026 TypeElement aType, Value a, 1027 TypeElement bType, Value b) { 1028 a = block.context().getValue(a); 1029 b = block.context().getValue(b); 1030 1031 TritonOps.FuncOp cdiv = tritonFunction(Functions.getJavaCodeModel("cdiv"), 1032 rType, List.of(aType, bType), 1033 fsymTable); 1034 // @@@ Generalize 1035 List<Value> args = new ArrayList<>(); 1036 if (!(aType instanceof ConstantType)) { 1037 args.add(a); 1038 } 1039 if (!(bType instanceof ConstantType)) { 1040 args.add(b); 1041 } 1042 return block.op(TritonOps.call(cdiv, args)); 1043 } 1044 1045 public Value conv(TypeElement rType, Op.Result r, 1046 ConstantType tType, Value t, 1047 TypeElement aType, Value a) { 1048 a = block.context().getValue(a); 1049 1050 TypeElement rScalarType; 1051 TypeElement aScalarType; 1052 if (rType instanceof TensorType rTensorType && aType instanceof TensorType aTensorType) { 1053 rScalarType = rTensorType.eType(); 1054 aScalarType = aTensorType.eType(); 1055 } else { 1056 rScalarType = rType; 1057 aScalarType = aType; 1058 } 1059 1060 if (rScalarType.equals(Float16.FLOAT_16_TYPE) && aScalarType.equals(JavaType.FLOAT)) { 1061 return block.op(ArithMathOps.trunc(rType, a)); 1062 } else if (rType.equals(aType)) { 1063 return a; 1064 } else { 1065 throw new IllegalStateException(); 1066 } 1067 } 1068 1069 public Value exp(TritonType rType, Op.Result r, 1070 TritonType aType, Value a) { 1071 return block.op(ArithMathOps.exp( 1072 block.context().getValue(a))); 1073 } 1074 1075 public Value compare(TensorType rType, Op.Result r, 1076 TypeElement aType, Value a, 1077 TypeElement bType, Value b, 1078 ConstantType compareType, Value compare) { 1079 Triton.CompareKind ck = (Triton.CompareKind) compareType.value(); 1080 1081 ArithMathOps.CompareOp.CompareKind ack = switch (ck) { 1082 case LessThan -> ArithMathOps.CompareOp.CompareKind.slt; 1083 default -> throw new UnsupportedOperationException("Unsupported comparison: " + ck); 1084 }; 1085 1086 broadcastConversion(rType, aType, a, bType, b); 1087 a = block.context().getValue(a); 1088 b = block.context().getValue(b); 1089 1090 return block.op(ArithMathOps.cmp(ack, a, b)); 1091 } 1092 1093 1094 public Value max(TypeElement rType, Op.Result r, 1095 TensorType xType, Value x, 1096 ConstantType axisType, Value axis) { 1097 TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("max"), 1098 rType, List.of(rType, rType), fsymTable); 1099 return reduce(rType, r, xType, x, axisType, axis, f); 1100 } 1101 1102 public Value sum(TypeElement rType, Op.Result r, 1103 TensorType xType, Value x, 1104 ConstantType axisType, Value axis) { 1105 TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("sum"), 1106 rType, List.of(rType, rType), fsymTable); 1107 return reduce(rType, r, xType, x, axisType, axis, f); 1108 } 1109 1110 Value reduce(TypeElement rType, Op.Result r, 1111 TensorType xType, Value x, 1112 ConstantType axisType, Value axis, 1113 TritonOps.FuncOp f) { 1114 int axisConstant = (int) axisType.value(); 1115 1116 String signature = "reduce_" + f.funcName() + "_" + axisConstant; 1117 TritonOps.FuncOp rf = fsymTable.computeIfAbsent(signature, 1118 s -> reduce(rType, xType, axisConstant, s, f)); 1119 1120 return block.op(TritonOps.call(rf, block.context().getValue(x))); 1121 } 1122 1123 static TritonOps.FuncOp reduce(TypeElement elementType, 1124 TensorType tensorType, 1125 int axisConstant, 1126 String name, TritonOps.FuncOp scalarFunc) { 1127 return TritonOps.func(name, 1128 functionType(elementType, tensorType)) 1129 .body(fblock -> { 1130 TritonOps.ReduceOp reduceOp = TritonOps.reduce(fblock.parentBody(), 1131 axisConstant, fblock.parameters().get(0), 1132 functionType(elementType, elementType, elementType)) 1133 .body(rblock -> { 1134 Block.Parameter a = rblock.parameters().get(0); 1135 Block.Parameter b = rblock.parameters().get(1); 1136 Op.Result _r = rblock.op(TritonOps.call(scalarFunc, a, b)); 1137 rblock.op(TritonOps.reduceReturn(_r)); 1138 }); 1139 1140 Op.Result opr = fblock.op(reduceOp); 1141 fblock.op(TritonOps.return_(opr)); 1142 }); 1143 } 1144 1145 // @@@ Test 1146 public Value consume(TypeElement rType, Op.Result r, 1147 TypeElement aType, Value a) { 1148 return block.op(TritonTestOps.consume(block.context().getValue(a))); 1149 } 1150 1151 void broadcastConversion(TypeElement rType, 1152 TypeElement aType, Value a, 1153 TypeElement bType, Value b) { 1154 Value ma = block.context().getValue(a); 1155 Value mb = block.context().getValue(b); 1156 if (aType instanceof TensorType at && bType instanceof TensorType bTensorType) { 1157 TensorType rTensorType = (TensorType) rType; 1158 if (!at.shape().equals(rTensorType.shape())) { 1159 ma = block.op(TritonOps.broadcast(rTensorType, ma)); 1160 } 1161 if (!bTensorType.shape().equals(rTensorType.shape())) { 1162 if (rTensorType.eType() instanceof PtrType) { 1163 bTensorType = new TensorType(bType, rTensorType.shape()); 1164 mb = block.op(TritonOps.broadcast(bTensorType, mb)); 1165 } else { 1166 mb = block.op(TritonOps.broadcast(rTensorType, mb)); 1167 } 1168 } 1169 } else if (aType instanceof TensorType) { 1170 TensorType rTensorType = (TensorType) rType; 1171 if (rTensorType.eType() instanceof PtrType) { 1172 TensorType bTensorType = new TensorType(bType, rTensorType.shape()); 1173 mb = block.op(TritonOps.splat(bTensorType, mb)); 1174 } else { 1175 mb = block.op(TritonOps.splat(rTensorType, mb)); 1176 } 1177 } else if (bType instanceof TensorType) { 1178 TensorType rTensorType = (TensorType) rType; 1179 ma = block.op(TritonOps.splat(rTensorType, ma)); 1180 } 1181 block.context().mapValue(a, ma); 1182 block.context().mapValue(b, mb); 1183 } 1184 1185 void broadcastConversionRight(TypeElement aType, 1186 TypeElement bType, Value b) { 1187 Value mb = block.context().getValue(b); 1188 if (aType instanceof TensorType aTensorType && bType instanceof TensorType bTensorType) { 1189 if (!bTensorType.shape().equals(aTensorType.shape())) { 1190 if (aTensorType.eType() instanceof PtrType) { 1191 bTensorType = new TensorType(bTensorType.eType(), aTensorType.shape()); 1192 mb = block.op(TritonOps.broadcast(bTensorType, mb)); 1193 } else { 1194 mb = block.op(TritonOps.broadcast(aTensorType, mb)); 1195 } 1196 } 1197 } else if (aType instanceof TensorType rTensorType) { 1198 if (rTensorType.eType() instanceof PtrType) { 1199 TensorType bTensorType = new TensorType(bType, rTensorType.shape()); 1200 mb = block.op(TritonOps.splat(bTensorType, mb)); 1201 } else { 1202 mb = block.op(TritonOps.splat(rTensorType, mb)); 1203 } 1204 } 1205 block.context().mapValue(b, mb); 1206 } 1207 } 1208 1209 public static <O extends Op & Op.Invokable> void printTypeMap( 1210 O kernel, Map<Value, TypeElement> valueTypeMap) { 1211 AtomicInteger valueId = new AtomicInteger(); 1212 Map<Value, Integer> valueIdMap = new LinkedHashMap<>(); 1213 kernel.traverse(null, (o, codeElement) -> { 1214 switch (codeElement) { 1215 case FuncOp _ -> { 1216 // Ignore 1217 } 1218 case Op op when !op.result().type().equals(JavaType.VOID) -> { 1219 valueIdMap.put(op.result(), valueId.getAndIncrement()); 1220 } 1221 case Block block -> { 1222 for (Block.Parameter parameter : block.parameters()) { 1223 valueIdMap.put(parameter, valueId.getAndIncrement()); 1224 } 1225 } 1226 default -> { 1227 } 1228 } 1229 return null; 1230 }); 1231 1232 valueIdMap.forEach((value, id) -> { 1233 TypeElement type = valueTypeMap.get(value); 1234 if (type != null) { 1235 System.out.println("%" + id + " : " + value.type() + " -> " + type); 1236 } 1237 }); 1238 } 1239 }