1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.triton; 27 28 import java.lang.reflect.code.*; 29 import java.lang.reflect.code.op.*; 30 import java.lang.reflect.code.type.*; 31 import java.util.*; 32 import java.util.function.Consumer; 33 34 public class TritonOps { 35 36 static abstract class TritonOp extends ExternalizableOp { 37 final TypeElement resultType; 38 39 public TritonOp(ExternalizedOp def) { 40 super(def); 41 42 this.resultType = def.resultType(); 43 } 44 45 TritonOp(TritonOp that, CopyContext cc) { 46 super(that, cc); 47 48 this.resultType = that.resultType; 49 } 50 51 TritonOp(String name, TypeElement resultType, List<? extends Value> operands) { 52 super(name, operands); 53 54 this.resultType = resultType; 55 } 56 57 @Override 58 public TypeElement resultType() { 59 return resultType; 60 } 61 } 62 63 @OpFactory.OpDeclaration(ModuleOp.NAME) 64 public static final class ModuleOp extends TritonOp implements Op.Isolated { 65 public static final String NAME = "module"; 66 67 final Map<String, FuncOp> table; 68 final Body body; 69 70 public ModuleOp(ExternalizedOp def) { 71 super(def); 72 73 this.body = def.bodyDefinitions().get(0).build(this); 74 this.table = createTable(body); 75 } 76 77 ModuleOp(ModuleOp that, CopyContext cc, OpTransformer ot) { 78 super(that, cc); 79 80 this.body = that.body.transform(cc, ot).build(this); 81 this.table = createTable(body); 82 } 83 84 static Map<String, FuncOp> createTable(Body body) { 85 Map<String, FuncOp> table = new HashMap<>(); 86 for (var op : body.entryBlock().ops()) { 87 if (op instanceof FuncOp fop) { 88 table.put(fop.funcName(), fop); 89 } else if (op instanceof CoreOp.UnreachableOp _) { 90 // no operation 91 } else { 92 throw new IllegalArgumentException("Bad operation in module: " + op); 93 } 94 } 95 return Collections.unmodifiableMap(table); 96 } 97 98 @Override 99 public ModuleOp transform(CopyContext cc, OpTransformer ot) { 100 return new ModuleOp(this, cc, ot); 101 } 102 103 public ModuleOp transform(OpTransformer ot) { 104 return new ModuleOp(this, CopyContext.create(), ot); 105 } 106 107 ModuleOp(List<FuncOp> functions) { 108 super(NAME, JavaType.VOID, 109 List.of()); 110 111 Body.Builder bodyC = Body.Builder.of(null, FunctionType.VOID); 112 Block.Builder entryBlock = bodyC.entryBlock(); 113 Map<String, FuncOp> table = new HashMap<>(); 114 for (FuncOp f : functions) { 115 entryBlock.op(f); 116 table.put(f.funcName(), f); 117 } 118 entryBlock.op(CoreOp.unreachable()); 119 this.table = Collections.unmodifiableMap(table); 120 this.body = bodyC.build(this); 121 } 122 123 @Override 124 public List<Body> bodies() { 125 return List.of(body); 126 } 127 128 public Map<String, FuncOp> functionTable() { 129 return table; 130 } 131 } 132 133 @OpFactory.OpDeclaration(FuncOp.NAME) 134 public static final class FuncOp extends TritonOp implements Op.Invokable, Op.Isolated, Op.Lowerable { 135 136 public static class Builder { 137 final Body.Builder ancestorBody; 138 final String funcName; 139 final FunctionType funcType; 140 141 Builder(Body.Builder ancestorBody, String funcName, FunctionType funcType) { 142 this.ancestorBody = ancestorBody; 143 this.funcName = funcName; 144 this.funcType = funcType; 145 } 146 147 public FuncOp body(Consumer<Block.Builder> c) { 148 Body.Builder body = Body.Builder.of(ancestorBody, funcType); 149 c.accept(body.entryBlock()); 150 return new FuncOp(funcName, body); 151 } 152 } 153 154 public static final String NAME = "tt.func"; 155 public static final String ATTRIBUTE_FUNC_NAME = NAME + ".name"; 156 157 final String funcName; 158 final Body body; 159 160 public static FuncOp create(ExternalizedOp def) { 161 if (!def.operands().isEmpty()) { 162 throw new IllegalStateException("Bad op " + def.name()); 163 } 164 165 String funcName = def.extractAttributeValue(ATTRIBUTE_FUNC_NAME, true, 166 v -> switch (v) { 167 case String s -> s; 168 default -> throw new UnsupportedOperationException("Unsupported func name value:" + v); 169 }); 170 return new FuncOp(def, funcName); 171 } 172 173 FuncOp(ExternalizedOp def, String funcName) { 174 super(def); 175 176 this.funcName = funcName; 177 this.body = def.bodyDefinitions().get(0).build(this); 178 } 179 180 FuncOp(FuncOp that, CopyContext cc, OpTransformer oa) { 181 this(that, that.funcName, cc, oa); 182 } 183 184 FuncOp(FuncOp that, String funcName, CopyContext cc, OpTransformer ot) { 185 super(that, cc); 186 187 this.funcName = funcName; 188 this.body = that.body.transform(cc, ot).build(this); 189 } 190 191 @Override 192 public FuncOp transform(CopyContext cc, OpTransformer ot) { 193 return new FuncOp(this, cc, ot); 194 } 195 196 public FuncOp transform(OpTransformer ot) { 197 return new FuncOp(this, CopyContext.create(), ot); 198 } 199 200 public FuncOp transform(String funcName, OpTransformer ot) { 201 return new FuncOp(this, funcName, CopyContext.create(), ot); 202 } 203 204 FuncOp(String funcName, Body.Builder bodyBuilder) { 205 super(NAME, JavaType.VOID, 206 List.of()); 207 208 this.funcName = funcName; 209 this.body = bodyBuilder.build(this); 210 } 211 212 @Override 213 public List<Body> bodies() { 214 return List.of(body); 215 } 216 217 @Override 218 public Map<String, Object> attributes() { 219 HashMap<String, Object> m = new HashMap<>(super.attributes()); 220 m.put("", funcName); 221 return Collections.unmodifiableMap(m); 222 } 223 224 @Override 225 public FunctionType invokableType() { 226 return body.bodyType(); 227 } 228 229 public String funcName() { 230 return funcName; 231 } 232 233 @Override 234 public Body body() { 235 return body; 236 } 237 238 @Override 239 public Block.Builder lower(Block.Builder b, OpTransformer _ignore) { 240 // Isolate body with respect to ancestor transformations 241 // and copy directly without lowering descendant operations 242 b.op(this, OpTransformer.COPYING_TRANSFORMER); 243 return b; 244 } 245 } 246 247 @OpFactory.OpDeclaration(CallOp.NAME) 248 public static final class CallOp extends TritonOp { 249 public static final String NAME = "tt.call"; 250 public static final String ATTRIBUTE_FUNC_NAME = NAME + ".name"; 251 252 final String funcName; 253 254 public static CallOp create(ExternalizedOp def) { 255 String funcName = def.extractAttributeValue(ATTRIBUTE_FUNC_NAME, true, 256 v -> switch (v) { 257 case String s -> s; 258 default -> throw new UnsupportedOperationException("Unsupported func name value:" + v); 259 }); 260 261 return new CallOp(def, funcName); 262 } 263 264 CallOp(ExternalizedOp def, String funcName) { 265 super(def); 266 267 this.funcName = funcName; 268 } 269 270 CallOp(CallOp that, CopyContext cc) { 271 super(that, cc); 272 273 this.funcName = that.funcName; 274 } 275 276 @Override 277 public CallOp transform(CopyContext cc, OpTransformer ot) { 278 return new CallOp(this, cc); 279 } 280 281 CallOp(String funcName, TypeElement resultType, List<Value> args) { 282 super(NAME, resultType, args); 283 284 this.funcName = funcName; 285 } 286 287 @Override 288 public Map<String, Object> attributes() { 289 HashMap<String, Object> m = new HashMap<>(super.attributes()); 290 m.put("", funcName); 291 return Collections.unmodifiableMap(m); 292 } 293 294 public String funcName() { 295 return funcName; 296 } 297 } 298 299 @OpFactory.OpDeclaration(ReduceOp.NAME) 300 public static final class ReduceOp extends TritonOp { 301 // @@@ SSA transformation does not work with nested ops 302 // implements Op.Nested { 303 304 public static class Builder { 305 final Body.Builder ancestorBody; 306 final int axis; 307 final Value v; 308 final FunctionType reduceType; 309 310 Builder(Body.Builder ancestorBody, int axis, Value v, FunctionType reduceType) { 311 this.ancestorBody = ancestorBody; 312 this.axis = axis; 313 this.v = v; 314 this.reduceType = reduceType; 315 } 316 317 public ReduceOp body(Consumer<Block.Builder> c) { 318 Body.Builder body = Body.Builder.of(ancestorBody, reduceType); 319 c.accept(body.entryBlock()); 320 return new ReduceOp(axis, v, body); 321 } 322 } 323 324 public static final String NAME = "tt.reduce"; 325 public static final String ATTRIBUTE_AXIS = "axis"; 326 327 final int axis; 328 final Body reducer; 329 330 public static ReduceOp create(ExternalizedOp def) { 331 int axis = def.extractAttributeValue(ATTRIBUTE_AXIS, true, 332 v -> switch (v) { 333 case String s -> Integer.valueOf(s); 334 case Integer i -> i; 335 default -> throw new UnsupportedOperationException("Unsupported axis value:" + v); 336 }); 337 return new ReduceOp(def, axis); 338 } 339 340 ReduceOp(ExternalizedOp def, int axis) { 341 super(def); 342 343 this.axis = axis; 344 this.reducer = def.bodyDefinitions().get(0).build(this); 345 } 346 347 ReduceOp(ReduceOp that, CopyContext cc, OpTransformer ot) { 348 super(that, cc); 349 350 this.axis = that.axis; 351 this.reducer = that.reducer.transform(cc, ot).build(this); 352 } 353 354 @Override 355 public ReduceOp transform(CopyContext cc, OpTransformer ot) { 356 return new ReduceOp(this, cc, ot); 357 } 358 359 ReduceOp(int axis, Value tensor, Body.Builder reducerBuilder) { 360 super(NAME, reducerBuilder.bodyType().returnType(), List.of(tensor)); 361 362 this.axis = axis; 363 this.reducer = reducerBuilder.build(this); 364 } 365 366 @Override 367 public List<Body> bodies() { 368 return List.of(reducer); 369 } 370 371 @Override 372 public Map<String, Object> attributes() { 373 HashMap<String, Object> m = new HashMap<>(super.attributes()); 374 m.put(ATTRIBUTE_AXIS, axis); 375 return Collections.unmodifiableMap(m); 376 } 377 378 public int axis() { 379 return axis; 380 } 381 382 public Body reducer() { 383 return reducer; 384 } 385 } 386 387 @OpFactory.OpDeclaration(ReduceReturnOp.NAME) 388 public static class ReduceReturnOp extends TritonOp implements Op.Terminating { 389 public static final String NAME = "tt.reduce.return"; 390 391 public ReduceReturnOp(ExternalizedOp def) { 392 super(def); 393 } 394 395 ReduceReturnOp(ReduceReturnOp that, CopyContext cc) { 396 super(that, cc); 397 } 398 399 @Override 400 public ReduceReturnOp transform(CopyContext cc, OpTransformer ot) { 401 return new ReduceReturnOp(this, cc); 402 } 403 404 ReduceReturnOp(Value r) { 405 super(NAME, JavaType.VOID, List.of(r)); 406 } 407 } 408 409 @OpFactory.OpDeclaration(GetProgramIdOp.NAME) 410 public static class GetProgramIdOp extends TritonOp implements Op.Pure { 411 public static final String NAME = "tt.get_program_id"; 412 public static final String ATTRIBUTE_AXIS = "axis"; 413 414 final int axis; 415 416 public static GetProgramIdOp create(ExternalizedOp def) { 417 int axis = def.extractAttributeValue(ATTRIBUTE_AXIS, true, 418 v -> switch (v) { 419 case String s -> Integer.valueOf(s); 420 case Integer i -> i; 421 default -> throw new UnsupportedOperationException("Unsupported axis value:" + v); 422 }); 423 return new GetProgramIdOp(def, axis); 424 } 425 426 GetProgramIdOp(ExternalizedOp def, int axis) { 427 super(def); 428 429 this.axis = axis; 430 } 431 432 GetProgramIdOp(GetProgramIdOp that, CopyContext cc) { 433 super(that, cc); 434 435 this.axis = that.axis; 436 } 437 438 @Override 439 public GetProgramIdOp transform(CopyContext cc, OpTransformer ot) { 440 return new GetProgramIdOp(this, cc); 441 } 442 443 GetProgramIdOp(int axis) { 444 super(NAME, JavaType.INT, List.of()); 445 446 this.axis = axis; 447 } 448 449 @Override 450 public Map<String, Object> attributes() { 451 HashMap<String, Object> m = new HashMap<>(super.attributes()); 452 m.put("", axis); 453 return Collections.unmodifiableMap(m); 454 } 455 456 public int axis() { 457 return axis; 458 } 459 } 460 461 @OpFactory.OpDeclaration(MakeRangeOp.NAME) 462 public static class MakeRangeOp extends TritonOp implements Op.Pure { 463 public static final String NAME = "tt.make_range"; 464 public static final String ATTRIBUTE_START = "start"; 465 public static final String ATTRIBUTE_END = "end"; 466 467 final int start; 468 final int end; 469 470 public static MakeRangeOp create(ExternalizedOp def) { 471 int start = def.extractAttributeValue(ATTRIBUTE_START, false, 472 v -> switch (v) { 473 case String s -> Integer.valueOf(s); 474 case Integer i -> i; 475 default -> throw new UnsupportedOperationException("Unsupported start value:" + v); 476 }); 477 int end = def.extractAttributeValue(ATTRIBUTE_END, false, 478 v -> switch (v) { 479 case String s -> Integer.valueOf(s); 480 case Integer i -> i; 481 default -> throw new UnsupportedOperationException("Unsupported end value:" + v); 482 }); 483 return new MakeRangeOp(def, start, end); 484 } 485 486 MakeRangeOp(ExternalizedOp def, int start, int end) { 487 super(def); 488 489 this.start = start; 490 this.end = end; 491 } 492 493 MakeRangeOp(MakeRangeOp that, CopyContext cc) { 494 super(that, cc); 495 496 this.start = that.start; 497 this.end = that.end; 498 } 499 500 @Override 501 public MakeRangeOp transform(CopyContext cc, OpTransformer ot) { 502 return new MakeRangeOp(this, cc); 503 } 504 505 MakeRangeOp(int start, int end) { 506 super(NAME, tensorType(start, end), List.of()); 507 508 this.start = start; 509 this.end = end; 510 } 511 512 static TensorType tensorType(int start, int end) { 513 return new TensorType(JavaType.INT, List.of(end - start)); 514 } 515 516 @Override 517 public Map<String, Object> attributes() { 518 HashMap<String, Object> m = new HashMap<>(super.attributes()); 519 m.put(ATTRIBUTE_START, start); 520 m.put(ATTRIBUTE_END, end); 521 return Collections.unmodifiableMap(m); 522 } 523 } 524 525 @OpFactory.OpDeclaration(ExpandOp.NAME) 526 public static class ExpandOp extends TritonOp implements Op.Pure { 527 public static final String NAME = "tt.expand_dims"; 528 public static final String ATTRIBUTE_AXIS = "axis"; 529 530 final int axis; 531 532 public static ExpandOp create(ExternalizedOp def) { 533 int axis = def.extractAttributeValue(ATTRIBUTE_AXIS, true, 534 v -> switch (v) { 535 case String s -> Integer.valueOf(s); 536 case Integer i -> i; 537 default -> throw new UnsupportedOperationException("Unsupported axis value:" + v); 538 }); 539 return new ExpandOp(def, axis); 540 } 541 542 ExpandOp(ExternalizedOp def, int axis) { 543 super(def); 544 545 this.axis = axis; 546 } 547 548 ExpandOp(ExpandOp that, CopyContext cc) { 549 super(that, cc); 550 551 this.axis = that.axis; 552 } 553 554 @Override 555 public ExpandOp transform(CopyContext cc, OpTransformer ot) { 556 return new ExpandOp(this, cc); 557 } 558 559 ExpandOp(int axis, TypeElement tensorType, Value v) { 560 super(NAME, tensorType, List.of(v)); 561 562 this.axis = axis; 563 } 564 565 @Override 566 public Map<String, Object> attributes() { 567 HashMap<String, Object> m = new HashMap<>(super.attributes()); 568 m.put("", axis); 569 return Collections.unmodifiableMap(m); 570 } 571 572 public int axis() { 573 return axis; 574 } 575 } 576 577 @OpFactory.OpDeclaration(SplatOp.NAME) 578 public static class SplatOp extends TritonOp implements Op.Pure { 579 public static final String NAME = "tt.splat"; 580 581 public SplatOp(ExternalizedOp def) { 582 super(def); 583 } 584 585 SplatOp(SplatOp that, CopyContext cc) { 586 super(that, cc); 587 } 588 589 @Override 590 public SplatOp transform(CopyContext cc, OpTransformer ot) { 591 return new SplatOp(this, cc); 592 } 593 594 SplatOp(TypeElement tensorType, Value v) { 595 super(NAME, tensorType, List.of(v)); 596 } 597 } 598 599 @OpFactory.OpDeclaration(BroadcastOp.NAME) 600 public static class BroadcastOp extends TritonOp implements Op.Pure { 601 public static final String NAME = "tt.broadcast"; 602 603 public BroadcastOp(ExternalizedOp def) { 604 super(def); 605 } 606 607 BroadcastOp(BroadcastOp that, CopyContext cc) { 608 super(that, cc); 609 } 610 611 @Override 612 public BroadcastOp transform(CopyContext cc, OpTransformer ot) { 613 return new BroadcastOp(this, cc); 614 } 615 616 BroadcastOp(TypeElement tensorType, Value v) { 617 super(NAME, tensorType, List.of(v)); 618 } 619 } 620 621 @OpFactory.OpDeclaration(AddPtrOp.NAME) 622 public static class AddPtrOp extends TritonOp implements Op.Pure { 623 public static final String NAME = "tt.addptr"; 624 625 public AddPtrOp(ExternalizedOp def) { 626 super(def); 627 } 628 629 AddPtrOp(AddPtrOp that, CopyContext cc) { 630 super(that, cc); 631 } 632 633 @Override 634 public AddPtrOp transform(CopyContext cc, OpTransformer ot) { 635 return new AddPtrOp(this, cc); 636 } 637 638 AddPtrOp(Value ptr, Value offset) { 639 super(NAME, ptr.type(), List.of(ptr, offset)); 640 } 641 } 642 643 @OpFactory.OpDeclaration(LoadOp.NAME) 644 public static class LoadOp extends TritonOp implements Op.Pure { 645 public static final String NAME = "tt.load"; 646 647 public LoadOp(ExternalizedOp def) { 648 super(def); 649 } 650 651 LoadOp(LoadOp that, CopyContext cc) { 652 super(that, cc); 653 } 654 655 @Override 656 public LoadOp transform(CopyContext cc, OpTransformer ot) { 657 return new LoadOp(this, cc); 658 } 659 660 LoadOp(TypeElement tensorType, Value ptr, Value mask) { 661 super(NAME, tensorType, List.of(ptr, mask)); 662 } 663 } 664 665 @OpFactory.OpDeclaration(StoreOp.NAME) 666 public static class StoreOp extends TritonOp { 667 public static final String NAME = "tt.store"; 668 669 public StoreOp(ExternalizedOp def) { 670 super(def); 671 } 672 673 StoreOp(StoreOp that, CopyContext cc) { 674 super(that, cc); 675 } 676 677 @Override 678 public StoreOp transform(CopyContext cc, OpTransformer ot) { 679 return new StoreOp(this, cc); 680 } 681 682 StoreOp(Value ptr, Value v, Value mask) { 683 super(NAME, JavaType.VOID, List.of(ptr, v, mask)); 684 } 685 } 686 687 @OpFactory.OpDeclaration(ReturnOp.NAME) 688 public static class ReturnOp extends TritonOp implements Op.Terminating { 689 public static final String NAME = "tt.return"; 690 691 public ReturnOp(ExternalizedOp def) { 692 super(def); 693 } 694 695 ReturnOp(ReturnOp that, CopyContext cc) { 696 super(that, cc); 697 } 698 699 @Override 700 public ReturnOp transform(CopyContext cc, OpTransformer ot) { 701 return new ReturnOp(this, cc); 702 } 703 704 ReturnOp() { 705 super(NAME, JavaType.VOID, List.of()); 706 } 707 708 ReturnOp(Value v) { 709 super(NAME, JavaType.VOID, List.of(v)); 710 } 711 } 712 713 @OpFactory.OpDeclaration(DotOp.NAME) 714 public static class DotOp extends TritonOp implements Op.Pure { 715 public static final String NAME = "tt.dot"; 716 717 public DotOp(ExternalizedOp def) { 718 super(def); 719 } 720 721 DotOp(DotOp that, CopyContext cc) { 722 super(that, cc); 723 } 724 725 @Override 726 public DotOp transform(CopyContext cc, OpTransformer ot) { 727 return new DotOp(this, cc); 728 } 729 730 DotOp(TypeElement tensorType, Value a, Value b) { 731 super(NAME, tensorType, List.of(a, b)); 732 } 733 } 734 735 736 public static ModuleOp module(FuncOp... functions) { 737 return module(List.of(functions)); 738 } 739 740 public static ModuleOp module(List<FuncOp> functions) { 741 return new ModuleOp(List.copyOf(functions)); 742 } 743 744 public static FuncOp.Builder func(String funcName, FunctionType funcType) { 745 return new FuncOp.Builder(null, funcName, funcType); 746 } 747 748 public static FuncOp func(String funcName, Body.Builder body) { 749 return new FuncOp(funcName, body); 750 } 751 752 public static CallOp call(FuncOp func, Value... args) { 753 return call(func, List.of(args)); 754 } 755 756 public static CallOp call(FuncOp func, List<Value> args) { 757 return new CallOp(func.funcName(), func.invokableType().returnType(), args); 758 } 759 760 public static ReduceOp.Builder reduce(Body.Builder ancestorBody, int axis, Value tensor, 761 FunctionType reduceType) { 762 return new ReduceOp.Builder(ancestorBody, axis, tensor, reduceType); 763 } 764 765 public static ReduceOp reduce(int axis, Value tensor, Body.Builder reducerBuilder) { 766 return new ReduceOp(axis, tensor, reducerBuilder); 767 } 768 769 public static ReduceReturnOp reduceReturn(Value r) { 770 return new ReduceReturnOp(r); 771 } 772 773 public static GetProgramIdOp getProgramId(int axis) { 774 // @@@ 1 <= axis <= 3 775 return new GetProgramIdOp(axis); 776 } 777 778 public static MakeRangeOp makeRange(int start, int end) { 779 // @@@ 0 <= start < end 780 return new MakeRangeOp(start, end); 781 } 782 783 public static ExpandOp expand(int axis, TypeElement tensorType, Value v) { 784 return new ExpandOp(axis, tensorType, v); 785 } 786 787 // v is scalar 788 public static SplatOp splat(TypeElement tensorType, Value v) { 789 return new SplatOp(tensorType, v); 790 } 791 792 // v is tensor 793 public static BroadcastOp broadcast(TypeElement tensorType, Value v) { 794 return new BroadcastOp(tensorType, v); 795 } 796 797 public static AddPtrOp addptr(Value ptr, Value offset) { 798 return new AddPtrOp(ptr, offset); 799 } 800 801 public static LoadOp load(TypeElement tensorType, Value ptr, Value mask) { 802 return new LoadOp(tensorType, ptr, mask); 803 } 804 805 public static StoreOp store(Value ptr, Value v, Value mask) { 806 return new StoreOp(ptr, v, mask); 807 } 808 809 public static ReturnOp return_() { 810 return new ReturnOp(); 811 } 812 813 public static ReturnOp return_(Value v) { 814 return new ReturnOp(v); 815 } 816 817 public static DotOp dot(TypeElement tensorType, Value a, Value b) { 818 return new DotOp(tensorType, a, b); 819 } 820 821 822 // Operation and type factories 823 824 public static final OpFactory FACTORY = OpFactory.OP_FACTORY.get(TritonOps.class); 825 826 static final TypeElementFactory TRITON_TYPE_FACTORY = new TypeElementFactory() { 827 @Override 828 public TypeElement constructType(TypeElement.ExternalizedTypeElement tree) { 829 return switch (tree.identifier()) { 830 case PtrType.NAME -> { 831 if (tree.arguments().size() != 1) { 832 throw new IllegalArgumentException(); 833 } 834 835 TypeElement v = TRITON_JAVA_TYPE_FACTORY.constructType(tree.arguments().getFirst()); 836 if (v == null) { 837 throw new IllegalArgumentException("Bad type: " + tree); 838 } 839 if (v instanceof JavaType || v instanceof TritonType) { 840 yield new PtrType(v); 841 } else { 842 throw new IllegalArgumentException("Bad type: " + tree); 843 } 844 } 845 case TensorType.NAME -> { 846 if (tree.arguments().size() < 2) { 847 throw new IllegalArgumentException("Bad type: " + tree); 848 } 849 850 List<Integer> shape = new ArrayList<>(); 851 for (int i = 0; i < tree.arguments().size() - 1; i++) { 852 TypeElement.ExternalizedTypeElement a = tree.arguments().get(i); 853 if (!a.identifier().startsWith("x")) { 854 throw new IllegalArgumentException("Bad type: " + tree); 855 } 856 int d; 857 try { 858 d = Integer.parseInt(a.identifier().substring(1)); 859 } catch (NumberFormatException e) { 860 throw new IllegalArgumentException("Bad type: " + tree, e); 861 } 862 shape.add(d); 863 } 864 865 TypeElement v = TRITON_JAVA_TYPE_FACTORY.constructType(tree.arguments().getLast()); 866 if (v == null) { 867 throw new IllegalArgumentException("Bad type: " + tree); 868 } 869 if (v instanceof JavaType || v instanceof TritonType) { 870 yield new TensorType(v, shape); 871 } else { 872 throw new IllegalArgumentException("Bad type: " + tree); 873 } 874 } 875 default -> null; 876 }; 877 } 878 }; 879 880 // Triton types then Java types 881 static final TypeElementFactory TRITON_JAVA_TYPE_FACTORY = 882 TRITON_TYPE_FACTORY.andThen(CoreTypeFactory.JAVA_TYPE_FACTORY); 883 884 // Triton types then Java types, combined with code model types 885 public static final TypeElementFactory TYPE_FACTORY = 886 CoreTypeFactory.codeModelTypeFactory(TRITON_JAVA_TYPE_FACTORY); 887 888 }