1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.triton;
 27 
 28 import java.lang.reflect.code.*;
 29 import java.lang.reflect.code.op.*;
 30 import java.lang.reflect.code.type.*;
 31 import java.util.*;
 32 import java.util.function.Consumer;
 33 
 34 public class TritonOps {
 35 
 36     static abstract class TritonOp extends ExternalizableOp {
 37         final TypeElement resultType;
 38 
 39         public TritonOp(ExternalizedOp def) {
 40             super(def);
 41 
 42             this.resultType = def.resultType();
 43         }
 44 
 45         TritonOp(TritonOp that, CopyContext cc) {
 46             super(that, cc);
 47 
 48             this.resultType = that.resultType;
 49         }
 50 
 51         TritonOp(String name, TypeElement resultType, List<? extends Value> operands) {
 52             super(name, operands);
 53 
 54             this.resultType = resultType;
 55         }
 56 
 57         @Override
 58         public TypeElement resultType() {
 59             return resultType;
 60         }
 61     }
 62 
 63     @OpFactory.OpDeclaration(ModuleOp.NAME)
 64     public static final class ModuleOp extends TritonOp implements Op.Isolated {
 65         public static final String NAME = "module";
 66 
 67         final Map<String, FuncOp> table;
 68         final Body body;
 69 
 70         public ModuleOp(ExternalizedOp def) {
 71             super(def);
 72 
 73             this.body = def.bodyDefinitions().get(0).build(this);
 74             this.table = createTable(body);
 75         }
 76 
 77         ModuleOp(ModuleOp that, CopyContext cc, OpTransformer ot) {
 78             super(that, cc);
 79 
 80             this.body = that.body.transform(cc, ot).build(this);
 81             this.table = createTable(body);
 82         }
 83 
 84         static Map<String, FuncOp> createTable(Body body) {
 85             Map<String, FuncOp> table = new HashMap<>();
 86             for (var op : body.entryBlock().ops()) {
 87                 if (op instanceof FuncOp fop) {
 88                     table.put(fop.funcName(), fop);
 89                 } else if (op instanceof CoreOp.UnreachableOp _) {
 90                     // no operation
 91                 } else {
 92                     throw new IllegalArgumentException("Bad operation in module: " + op);
 93                 }
 94             }
 95             return Collections.unmodifiableMap(table);
 96         }
 97 
 98         @Override
 99         public ModuleOp transform(CopyContext cc, OpTransformer ot) {
100             return new ModuleOp(this, cc, ot);
101         }
102 
103         public ModuleOp transform(OpTransformer ot) {
104             return new ModuleOp(this, CopyContext.create(), ot);
105         }
106 
107         ModuleOp(List<FuncOp> functions) {
108             super(NAME, JavaType.VOID,
109                     List.of());
110 
111             Body.Builder bodyC = Body.Builder.of(null, FunctionType.VOID);
112             Block.Builder entryBlock = bodyC.entryBlock();
113             Map<String, FuncOp> table = new HashMap<>();
114             for (FuncOp f : functions) {
115                 entryBlock.op(f);
116                 table.put(f.funcName(), f);
117             }
118             entryBlock.op(CoreOp.unreachable());
119             this.table = Collections.unmodifiableMap(table);
120             this.body = bodyC.build(this);
121         }
122 
123         @Override
124         public List<Body> bodies() {
125             return List.of(body);
126         }
127 
128         public Map<String, FuncOp> functionTable() {
129             return table;
130         }
131     }
132 
133     @OpFactory.OpDeclaration(FuncOp.NAME)
134     public static final class FuncOp extends TritonOp implements Op.Invokable, Op.Isolated, Op.Lowerable {
135 
136         public static class Builder {
137             final Body.Builder ancestorBody;
138             final String funcName;
139             final FunctionType funcType;
140 
141             Builder(Body.Builder ancestorBody, String funcName, FunctionType funcType) {
142                 this.ancestorBody = ancestorBody;
143                 this.funcName = funcName;
144                 this.funcType = funcType;
145             }
146 
147             public FuncOp body(Consumer<Block.Builder> c) {
148                 Body.Builder body = Body.Builder.of(ancestorBody, funcType);
149                 c.accept(body.entryBlock());
150                 return new FuncOp(funcName, body);
151             }
152         }
153 
154         public static final String NAME = "tt.func";
155         public static final String ATTRIBUTE_FUNC_NAME = NAME + ".name";
156 
157         final String funcName;
158         final Body body;
159 
160         public static FuncOp create(ExternalizedOp def) {
161             if (!def.operands().isEmpty()) {
162                 throw new IllegalStateException("Bad op " + def.name());
163             }
164 
165             String funcName = def.extractAttributeValue(ATTRIBUTE_FUNC_NAME, true,
166                     v -> switch (v) {
167                         case String s -> s;
168                         default -> throw new UnsupportedOperationException("Unsupported func name value:" + v);
169                     });
170             return new FuncOp(def, funcName);
171         }
172 
173         FuncOp(ExternalizedOp def, String funcName) {
174             super(def);
175 
176             this.funcName = funcName;
177             this.body = def.bodyDefinitions().get(0).build(this);
178         }
179 
180         FuncOp(FuncOp that, CopyContext cc, OpTransformer oa) {
181             this(that, that.funcName, cc, oa);
182         }
183 
184         FuncOp(FuncOp that, String funcName, CopyContext cc, OpTransformer ot) {
185             super(that, cc);
186 
187             this.funcName = funcName;
188             this.body = that.body.transform(cc, ot).build(this);
189         }
190 
191         @Override
192         public FuncOp transform(CopyContext cc, OpTransformer ot) {
193             return new FuncOp(this, cc, ot);
194         }
195 
196         public FuncOp transform(OpTransformer ot) {
197             return new FuncOp(this, CopyContext.create(), ot);
198         }
199 
200         public FuncOp transform(String funcName, OpTransformer ot) {
201             return new FuncOp(this, funcName, CopyContext.create(), ot);
202         }
203 
204         FuncOp(String funcName, Body.Builder bodyBuilder) {
205             super(NAME, JavaType.VOID,
206                     List.of());
207 
208             this.funcName = funcName;
209             this.body = bodyBuilder.build(this);
210         }
211 
212         @Override
213         public List<Body> bodies() {
214             return List.of(body);
215         }
216 
217         @Override
218         public Map<String, Object> attributes() {
219             HashMap<String, Object> m = new HashMap<>(super.attributes());
220             m.put("", funcName);
221             return Collections.unmodifiableMap(m);
222         }
223 
224         @Override
225         public FunctionType invokableType() {
226             return body.bodyType();
227         }
228 
229         public String funcName() {
230             return funcName;
231         }
232 
233         @Override
234         public Body body() {
235             return body;
236         }
237 
238         @Override
239         public Block.Builder lower(Block.Builder b, OpTransformer _ignore) {
240             // Isolate body with respect to ancestor transformations
241             // and copy directly without lowering descendant operations
242             b.op(this, OpTransformer.COPYING_TRANSFORMER);
243             return b;
244         }
245     }
246 
247     @OpFactory.OpDeclaration(CallOp.NAME)
248     public static final class CallOp extends TritonOp {
249         public static final String NAME = "tt.call";
250         public static final String ATTRIBUTE_FUNC_NAME = NAME + ".name";
251 
252         final String funcName;
253 
254         public static CallOp create(ExternalizedOp def) {
255             String funcName = def.extractAttributeValue(ATTRIBUTE_FUNC_NAME, true,
256                     v -> switch (v) {
257                         case String s -> s;
258                         default -> throw new UnsupportedOperationException("Unsupported func name value:" + v);
259                     });
260 
261             return new CallOp(def, funcName);
262         }
263 
264         CallOp(ExternalizedOp def, String funcName) {
265             super(def);
266 
267             this.funcName = funcName;
268         }
269 
270         CallOp(CallOp that, CopyContext cc) {
271             super(that, cc);
272 
273             this.funcName = that.funcName;
274         }
275 
276         @Override
277         public CallOp transform(CopyContext cc, OpTransformer ot) {
278             return new CallOp(this, cc);
279         }
280 
281         CallOp(String funcName, TypeElement resultType, List<Value> args) {
282             super(NAME, resultType, args);
283 
284             this.funcName = funcName;
285         }
286 
287         @Override
288         public Map<String, Object> attributes() {
289             HashMap<String, Object> m = new HashMap<>(super.attributes());
290             m.put("", funcName);
291             return Collections.unmodifiableMap(m);
292         }
293 
294         public String funcName() {
295             return funcName;
296         }
297     }
298 
299     @OpFactory.OpDeclaration(ReduceOp.NAME)
300     public static final class ReduceOp extends TritonOp {
301         // @@@ SSA transformation does not work with nested ops
302         // implements Op.Nested {
303 
304         public static class Builder {
305             final Body.Builder ancestorBody;
306             final int axis;
307             final Value v;
308             final FunctionType reduceType;
309 
310             Builder(Body.Builder ancestorBody, int axis, Value v, FunctionType reduceType) {
311                 this.ancestorBody = ancestorBody;
312                 this.axis = axis;
313                 this.v = v;
314                 this.reduceType = reduceType;
315             }
316 
317             public ReduceOp body(Consumer<Block.Builder> c) {
318                 Body.Builder body = Body.Builder.of(ancestorBody, reduceType);
319                 c.accept(body.entryBlock());
320                 return new ReduceOp(axis, v, body);
321             }
322         }
323 
324         public static final String NAME = "tt.reduce";
325         public static final String ATTRIBUTE_AXIS = "axis";
326 
327         final int axis;
328         final Body reducer;
329 
330         public static ReduceOp create(ExternalizedOp def) {
331             int axis = def.extractAttributeValue(ATTRIBUTE_AXIS, true,
332                     v -> switch (v) {
333                         case String s -> Integer.valueOf(s);
334                         case Integer i -> i;
335                         default -> throw new UnsupportedOperationException("Unsupported axis value:" + v);
336                     });
337             return new ReduceOp(def, axis);
338         }
339 
340         ReduceOp(ExternalizedOp def, int axis) {
341             super(def);
342 
343             this.axis = axis;
344             this.reducer = def.bodyDefinitions().get(0).build(this);
345         }
346 
347         ReduceOp(ReduceOp that, CopyContext cc, OpTransformer ot) {
348             super(that, cc);
349 
350             this.axis = that.axis;
351             this.reducer = that.reducer.transform(cc, ot).build(this);
352         }
353 
354         @Override
355         public ReduceOp transform(CopyContext cc, OpTransformer ot) {
356             return new ReduceOp(this, cc, ot);
357         }
358 
359         ReduceOp(int axis, Value tensor, Body.Builder reducerBuilder) {
360             super(NAME, reducerBuilder.bodyType().returnType(), List.of(tensor));
361 
362             this.axis = axis;
363             this.reducer = reducerBuilder.build(this);
364         }
365 
366         @Override
367         public List<Body> bodies() {
368             return List.of(reducer);
369         }
370 
371         @Override
372         public Map<String, Object> attributes() {
373             HashMap<String, Object> m = new HashMap<>(super.attributes());
374             m.put(ATTRIBUTE_AXIS, axis);
375             return Collections.unmodifiableMap(m);
376         }
377 
378         public int axis() {
379             return axis;
380         }
381 
382         public Body reducer() {
383             return reducer;
384         }
385     }
386 
387     @OpFactory.OpDeclaration(ReduceReturnOp.NAME)
388     public static class ReduceReturnOp extends TritonOp implements Op.Terminating {
389         public static final String NAME = "tt.reduce.return";
390 
391         public ReduceReturnOp(ExternalizedOp def) {
392             super(def);
393         }
394 
395         ReduceReturnOp(ReduceReturnOp that, CopyContext cc) {
396             super(that, cc);
397         }
398 
399         @Override
400         public ReduceReturnOp transform(CopyContext cc, OpTransformer ot) {
401             return new ReduceReturnOp(this, cc);
402         }
403 
404         ReduceReturnOp(Value r) {
405             super(NAME, JavaType.VOID, List.of(r));
406         }
407     }
408 
409     @OpFactory.OpDeclaration(GetProgramIdOp.NAME)
410     public static class GetProgramIdOp extends TritonOp implements Op.Pure {
411         public static final String NAME = "tt.get_program_id";
412         public static final String ATTRIBUTE_AXIS = "axis";
413 
414         final int axis;
415 
416         public static GetProgramIdOp create(ExternalizedOp def) {
417             int axis = def.extractAttributeValue(ATTRIBUTE_AXIS, true,
418                     v -> switch (v) {
419                         case String s -> Integer.valueOf(s);
420                         case Integer i -> i;
421                         default -> throw new UnsupportedOperationException("Unsupported axis value:" + v);
422                     });
423             return new GetProgramIdOp(def, axis);
424         }
425 
426         GetProgramIdOp(ExternalizedOp def, int axis) {
427             super(def);
428 
429             this.axis = axis;
430         }
431 
432         GetProgramIdOp(GetProgramIdOp that, CopyContext cc) {
433             super(that, cc);
434 
435             this.axis = that.axis;
436         }
437 
438         @Override
439         public GetProgramIdOp transform(CopyContext cc, OpTransformer ot) {
440             return new GetProgramIdOp(this, cc);
441         }
442 
443         GetProgramIdOp(int axis) {
444             super(NAME, JavaType.INT, List.of());
445 
446             this.axis = axis;
447         }
448 
449         @Override
450         public Map<String, Object> attributes() {
451             HashMap<String, Object> m = new HashMap<>(super.attributes());
452             m.put("", axis);
453             return Collections.unmodifiableMap(m);
454         }
455 
456         public int axis() {
457             return axis;
458         }
459     }
460 
461     @OpFactory.OpDeclaration(MakeRangeOp.NAME)
462     public static class MakeRangeOp extends TritonOp implements Op.Pure {
463         public static final String NAME = "tt.make_range";
464         public static final String ATTRIBUTE_START = "start";
465         public static final String ATTRIBUTE_END = "end";
466 
467         final int start;
468         final int end;
469 
470         public static MakeRangeOp create(ExternalizedOp def) {
471             int start = def.extractAttributeValue(ATTRIBUTE_START, false,
472                     v -> switch (v) {
473                         case String s -> Integer.valueOf(s);
474                         case Integer i -> i;
475                         default -> throw new UnsupportedOperationException("Unsupported start value:" + v);
476                     });
477             int end = def.extractAttributeValue(ATTRIBUTE_END, false,
478                     v -> switch (v) {
479                         case String s -> Integer.valueOf(s);
480                         case Integer i -> i;
481                         default -> throw new UnsupportedOperationException("Unsupported end value:" + v);
482                     });
483             return new MakeRangeOp(def, start, end);
484         }
485 
486         MakeRangeOp(ExternalizedOp def, int start, int end) {
487             super(def);
488 
489             this.start = start;
490             this.end = end;
491         }
492 
493         MakeRangeOp(MakeRangeOp that, CopyContext cc) {
494             super(that, cc);
495 
496             this.start = that.start;
497             this.end = that.end;
498         }
499 
500         @Override
501         public MakeRangeOp transform(CopyContext cc, OpTransformer ot) {
502             return new MakeRangeOp(this, cc);
503         }
504 
505         MakeRangeOp(int start, int end) {
506             super(NAME, tensorType(start, end), List.of());
507 
508             this.start = start;
509             this.end = end;
510         }
511 
512         static TensorType tensorType(int start, int end) {
513             return new TensorType(JavaType.INT, List.of(end - start));
514         }
515 
516         @Override
517         public Map<String, Object> attributes() {
518             HashMap<String, Object> m = new HashMap<>(super.attributes());
519             m.put(ATTRIBUTE_START, start);
520             m.put(ATTRIBUTE_END, end);
521             return Collections.unmodifiableMap(m);
522         }
523     }
524 
525     @OpFactory.OpDeclaration(ExpandOp.NAME)
526     public static class ExpandOp extends TritonOp implements Op.Pure {
527         public static final String NAME = "tt.expand_dims";
528         public static final String ATTRIBUTE_AXIS = "axis";
529 
530         final int axis;
531 
532         public static ExpandOp create(ExternalizedOp def) {
533             int axis = def.extractAttributeValue(ATTRIBUTE_AXIS, true,
534                     v -> switch (v) {
535                         case String s -> Integer.valueOf(s);
536                         case Integer i -> i;
537                         default -> throw new UnsupportedOperationException("Unsupported axis value:" + v);
538                     });
539             return new ExpandOp(def, axis);
540         }
541 
542         ExpandOp(ExternalizedOp def, int axis) {
543             super(def);
544 
545             this.axis = axis;
546         }
547 
548         ExpandOp(ExpandOp that, CopyContext cc) {
549             super(that, cc);
550 
551             this.axis = that.axis;
552         }
553 
554         @Override
555         public ExpandOp transform(CopyContext cc, OpTransformer ot) {
556             return new ExpandOp(this, cc);
557         }
558 
559         ExpandOp(int axis, TypeElement tensorType, Value v) {
560             super(NAME, tensorType, List.of(v));
561 
562             this.axis = axis;
563         }
564 
565         @Override
566         public Map<String, Object> attributes() {
567             HashMap<String, Object> m = new HashMap<>(super.attributes());
568             m.put("", axis);
569             return Collections.unmodifiableMap(m);
570         }
571 
572         public int axis() {
573             return axis;
574         }
575     }
576 
577     @OpFactory.OpDeclaration(SplatOp.NAME)
578     public static class SplatOp extends TritonOp implements Op.Pure {
579         public static final String NAME = "tt.splat";
580 
581         public SplatOp(ExternalizedOp def) {
582             super(def);
583         }
584 
585         SplatOp(SplatOp that, CopyContext cc) {
586             super(that, cc);
587         }
588 
589         @Override
590         public SplatOp transform(CopyContext cc, OpTransformer ot) {
591             return new SplatOp(this, cc);
592         }
593 
594         SplatOp(TypeElement tensorType, Value v) {
595             super(NAME, tensorType, List.of(v));
596         }
597     }
598 
599     @OpFactory.OpDeclaration(BroadcastOp.NAME)
600     public static class BroadcastOp extends TritonOp implements Op.Pure {
601         public static final String NAME = "tt.broadcast";
602 
603         public BroadcastOp(ExternalizedOp def) {
604             super(def);
605         }
606 
607         BroadcastOp(BroadcastOp that, CopyContext cc) {
608             super(that, cc);
609         }
610 
611         @Override
612         public BroadcastOp transform(CopyContext cc, OpTransformer ot) {
613             return new BroadcastOp(this, cc);
614         }
615 
616         BroadcastOp(TypeElement tensorType, Value v) {
617             super(NAME, tensorType, List.of(v));
618         }
619     }
620 
621     @OpFactory.OpDeclaration(AddPtrOp.NAME)
622     public static class AddPtrOp extends TritonOp implements Op.Pure {
623         public static final String NAME = "tt.addptr";
624 
625         public AddPtrOp(ExternalizedOp def) {
626             super(def);
627         }
628 
629         AddPtrOp(AddPtrOp that, CopyContext cc) {
630             super(that, cc);
631         }
632 
633         @Override
634         public AddPtrOp transform(CopyContext cc, OpTransformer ot) {
635             return new AddPtrOp(this, cc);
636         }
637 
638         AddPtrOp(Value ptr, Value offset) {
639             super(NAME, ptr.type(), List.of(ptr, offset));
640         }
641     }
642 
643     @OpFactory.OpDeclaration(LoadOp.NAME)
644     public static class LoadOp extends TritonOp implements Op.Pure {
645         public static final String NAME = "tt.load";
646 
647         public LoadOp(ExternalizedOp def) {
648             super(def);
649         }
650 
651         LoadOp(LoadOp that, CopyContext cc) {
652             super(that, cc);
653         }
654 
655         @Override
656         public LoadOp transform(CopyContext cc, OpTransformer ot) {
657             return new LoadOp(this, cc);
658         }
659 
660         LoadOp(TypeElement tensorType, Value ptr, Value mask) {
661             super(NAME, tensorType, List.of(ptr, mask));
662         }
663     }
664 
665     @OpFactory.OpDeclaration(StoreOp.NAME)
666     public static class StoreOp extends TritonOp {
667         public static final String NAME = "tt.store";
668 
669         public StoreOp(ExternalizedOp def) {
670             super(def);
671         }
672 
673         StoreOp(StoreOp that, CopyContext cc) {
674             super(that, cc);
675         }
676 
677         @Override
678         public StoreOp transform(CopyContext cc, OpTransformer ot) {
679             return new StoreOp(this, cc);
680         }
681 
682         StoreOp(Value ptr, Value v, Value mask) {
683             super(NAME, JavaType.VOID, List.of(ptr, v, mask));
684         }
685     }
686 
687     @OpFactory.OpDeclaration(ReturnOp.NAME)
688     public static class ReturnOp extends TritonOp implements Op.Terminating {
689         public static final String NAME = "tt.return";
690 
691         public ReturnOp(ExternalizedOp def) {
692             super(def);
693         }
694 
695         ReturnOp(ReturnOp that, CopyContext cc) {
696             super(that, cc);
697         }
698 
699         @Override
700         public ReturnOp transform(CopyContext cc, OpTransformer ot) {
701             return new ReturnOp(this, cc);
702         }
703 
704         ReturnOp() {
705             super(NAME, JavaType.VOID, List.of());
706         }
707 
708         ReturnOp(Value v) {
709             super(NAME, JavaType.VOID, List.of(v));
710         }
711     }
712 
713     @OpFactory.OpDeclaration(DotOp.NAME)
714     public static class DotOp extends TritonOp implements Op.Pure {
715         public static final String NAME = "tt.dot";
716 
717         public DotOp(ExternalizedOp def) {
718             super(def);
719         }
720 
721         DotOp(DotOp that, CopyContext cc) {
722             super(that, cc);
723         }
724 
725         @Override
726         public DotOp transform(CopyContext cc, OpTransformer ot) {
727             return new DotOp(this, cc);
728         }
729 
730         DotOp(TypeElement tensorType, Value a, Value b) {
731             super(NAME, tensorType, List.of(a, b));
732         }
733     }
734 
735 
736     public static ModuleOp module(FuncOp... functions) {
737         return module(List.of(functions));
738     }
739 
740     public static ModuleOp module(List<FuncOp> functions) {
741         return new ModuleOp(List.copyOf(functions));
742     }
743 
744     public static FuncOp.Builder func(String funcName, FunctionType funcType) {
745         return new FuncOp.Builder(null, funcName, funcType);
746     }
747 
748     public static FuncOp func(String funcName, Body.Builder body) {
749         return new FuncOp(funcName, body);
750     }
751 
752     public static CallOp call(FuncOp func, Value... args) {
753         return call(func, List.of(args));
754     }
755 
756     public static CallOp call(FuncOp func, List<Value> args) {
757         return new CallOp(func.funcName(), func.invokableType().returnType(), args);
758     }
759 
760     public static ReduceOp.Builder reduce(Body.Builder ancestorBody, int axis, Value tensor,
761                                           FunctionType reduceType) {
762         return new ReduceOp.Builder(ancestorBody, axis, tensor, reduceType);
763     }
764 
765     public static ReduceOp reduce(int axis, Value tensor, Body.Builder reducerBuilder) {
766         return new ReduceOp(axis, tensor, reducerBuilder);
767     }
768 
769     public static ReduceReturnOp reduceReturn(Value r) {
770         return new ReduceReturnOp(r);
771     }
772 
773     public static GetProgramIdOp getProgramId(int axis) {
774         // @@@ 1 <= axis <= 3
775         return new GetProgramIdOp(axis);
776     }
777 
778     public static MakeRangeOp makeRange(int start, int end) {
779         // @@@ 0 <= start < end
780         return new MakeRangeOp(start, end);
781     }
782 
783     public static ExpandOp expand(int axis, TypeElement tensorType, Value v) {
784         return new ExpandOp(axis, tensorType, v);
785     }
786 
787     // v is scalar
788     public static SplatOp splat(TypeElement tensorType, Value v) {
789         return new SplatOp(tensorType, v);
790     }
791 
792     // v is tensor
793     public static BroadcastOp broadcast(TypeElement tensorType, Value v) {
794         return new BroadcastOp(tensorType, v);
795     }
796 
797     public static AddPtrOp addptr(Value ptr, Value offset) {
798         return new AddPtrOp(ptr, offset);
799     }
800 
801     public static LoadOp load(TypeElement tensorType, Value ptr, Value mask) {
802         return new LoadOp(tensorType, ptr, mask);
803     }
804 
805     public static StoreOp store(Value ptr, Value v, Value mask) {
806         return new StoreOp(ptr, v, mask);
807     }
808 
809     public static ReturnOp return_() {
810         return new ReturnOp();
811     }
812 
813     public static ReturnOp return_(Value v) {
814         return new ReturnOp(v);
815     }
816 
817     public static DotOp dot(TypeElement tensorType, Value a, Value b) {
818         return new DotOp(tensorType, a, b);
819     }
820 
821 
822     // Operation and type factories
823 
    /**
     * Operation factory obtained from {@code OpFactory.OP_FACTORY} for this class.
     * NOTE(review): presumably it resolves the nested classes annotated with
     * {@code @OpFactory.OpDeclaration} — confirm against OpFactory.
     */
    public static final OpFactory FACTORY = OpFactory.OP_FACTORY.get(TritonOps.class);
825 
826     static final TypeElementFactory TRITON_TYPE_FACTORY = new TypeElementFactory() {
827         @Override
828         public TypeElement constructType(TypeElement.ExternalizedTypeElement tree) {
829             return switch (tree.identifier()) {
830                 case PtrType.NAME -> {
831                     if (tree.arguments().size() != 1) {
832                         throw new IllegalArgumentException();
833                     }
834 
835                     TypeElement v = TRITON_JAVA_TYPE_FACTORY.constructType(tree.arguments().getFirst());
836                     if (v == null) {
837                         throw new IllegalArgumentException("Bad type: " + tree);
838                     }
839                     if (v instanceof JavaType || v instanceof TritonType) {
840                         yield new PtrType(v);
841                     } else {
842                         throw new IllegalArgumentException("Bad type: " + tree);
843                     }
844                 }
845                 case TensorType.NAME -> {
846                     if (tree.arguments().size() < 2) {
847                         throw new IllegalArgumentException("Bad type: " + tree);
848                     }
849 
850                     List<Integer> shape = new ArrayList<>();
851                     for (int i = 0; i < tree.arguments().size() - 1; i++) {
852                         TypeElement.ExternalizedTypeElement a = tree.arguments().get(i);
853                         if (!a.identifier().startsWith("x")) {
854                             throw new IllegalArgumentException("Bad type: " + tree);
855                         }
856                         int d;
857                         try {
858                             d = Integer.parseInt(a.identifier().substring(1));
859                         } catch (NumberFormatException e) {
860                             throw new IllegalArgumentException("Bad type: " + tree, e);
861                         }
862                         shape.add(d);
863                     }
864 
865                     TypeElement v = TRITON_JAVA_TYPE_FACTORY.constructType(tree.arguments().getLast());
866                     if (v == null) {
867                         throw new IllegalArgumentException("Bad type: " + tree);
868                     }
869                     if (v instanceof JavaType || v instanceof TritonType) {
870                         yield new TensorType(v, shape);
871                     } else {
872                         throw new IllegalArgumentException("Bad type: " + tree);
873                     }
874                 }
875                 default -> null;
876             };
877         }
878     };
879 
    // Triton types then Java types: TRITON_TYPE_FACTORY composed, via andThen,
    // with CoreTypeFactory.JAVA_TYPE_FACTORY.
    // NOTE(review): presumably the Java factory is consulted only when the Triton
    // factory returns null — confirm against TypeElementFactory.andThen.
    static final TypeElementFactory TRITON_JAVA_TYPE_FACTORY =
            TRITON_TYPE_FACTORY.andThen(CoreTypeFactory.JAVA_TYPE_FACTORY);
883 
    // Triton types then Java types, combined with code model types:
    // the result of passing TRITON_JAVA_TYPE_FACTORY to
    // CoreTypeFactory.codeModelTypeFactory. This is the public type factory of this class.
    public static final TypeElementFactory TYPE_FACTORY =
            CoreTypeFactory.codeModelTypeFactory(TRITON_JAVA_TYPE_FACTORY);
887 
888 }