1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.triton;
 27 
 28 import jdk.incubator.code.*;
 29 import jdk.incubator.code.extern.ExternalizedOp;
 30 import jdk.incubator.code.extern.OpFactory;
 31 import jdk.incubator.code.dialect.java.JavaType;
 32 
 33 import java.util.HashMap;
 34 import java.util.List;
 35 import java.util.Map;
 36 
 37 public class ArithMathOps {
 38 
 39     static abstract class ArithMathOp extends Op {
 40         final TypeElement resultType;
 41 
 42         public ArithMathOp(ExternalizedOp def) {
 43             super(def.name(), def.operands());;
 44 
 45             this.resultType = def.resultType();
 46         }
 47 
 48         ArithMathOp(ArithMathOp that, CopyContext cc) {
 49             super(that, cc);
 50 
 51             this.resultType = that.resultType;
 52         }
 53 
 54         ArithMathOp(String name, TypeElement resultType, List<? extends Value> operands) {
 55             super(name, operands);
 56 
 57             this.resultType = resultType;
 58         }
 59 
 60         @Override
 61         public TypeElement resultType() {
 62             return resultType;
 63         }
 64     }
 65 
 66     @OpFactoryHelper.OpDeclaration(ConstantOp.NAME)
 67     public static class ConstantOp extends ArithMathOp implements Op.Pure {
 68         public static final String NAME = "arith.constant";
 69         public static final String ATTRIBUTE_CONSTANT_VALUE = "value";
 70 
 71         final Object value;
 72 
 73         public static ConstantOp create(ExternalizedOp def) {
 74             if (!def.operands().isEmpty()) {
 75                 throw new IllegalArgumentException("Operation must have zero operands");
 76             }
 77 
 78             Object value = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE,true,
 79                     v -> processConstantValue(def.resultType(), v));
 80             return new ConstantOp(def, value);
 81         }
 82 
 83         static Object processConstantValue(TypeElement t, Object value) {
 84             if (t.equals(JavaType.BOOLEAN) && value instanceof Boolean) {
 85                 return value;
 86             } else if (t.equals(JavaType.BYTE) && value instanceof Number n) {
 87                 return n.byteValue();
 88             } else if (t.equals(JavaType.SHORT) && value instanceof Number n) {
 89                 return n.shortValue();
 90             } else if (t.equals(JavaType.CHAR) && value instanceof Character) {
 91                 return value;
 92             } else if (t.equals(JavaType.INT) && value instanceof Number n) {
 93                 return n.intValue();
 94             } else if (t.equals(JavaType.LONG) && value instanceof Number n) {
 95                 return n.longValue();
 96             } else if (t.equals(JavaType.FLOAT) && value instanceof Number n) {
 97                 return n.floatValue();
 98             } else if (t.equals(Float16.FLOAT_16_TYPE) && value instanceof Number n) {
 99                 // represent as a float for now
100                 return n.floatValue();
101             } else if (t.equals(JavaType.DOUBLE) && value instanceof Number n) {
102                 return n.doubleValue();
103             } else if (t instanceof TensorType tt) {
104                 return processConstantValue(tt.eType(), value);
105             }
106 
107             throw new UnsupportedOperationException("Unsupported constant type and value: " + t + " " + value);
108         }
109 
110         ConstantOp(ExternalizedOp def, Object value) {
111             super(def);
112 
113             this.value = value;
114         }
115 
116         ConstantOp(ConstantOp that, CopyContext cc) {
117             super(that, cc);
118 
119             this.value = that.value;
120         }
121 
122         @Override
123         public ConstantOp transform(CopyContext cc, OpTransformer ot) {
124             return new ConstantOp(this, cc);
125         }
126 
127         ConstantOp(TypeElement type, Object value) {
128             super(NAME, type, List.of());
129 
130             this.value = value;
131         }
132 
133         @Override
134         public Map<String, Object> externalize() {
135             return Map.of(ATTRIBUTE_CONSTANT_VALUE, value);
136         }
137 
138         public Object value() {
139             return value;
140         }
141     }
142 
143     @OpFactoryHelper.OpDeclaration(AddOp.NAME)
144     public static class AddOp extends ArithMathOp implements Op.Pure {
145         public static final String NAME = "arith.add";
146 
147         public AddOp(ExternalizedOp def) {
148             super(def);
149         }
150 
151         AddOp(AddOp that, CopyContext cc) {
152             super(that, cc);
153         }
154 
155         @Override
156         public AddOp transform(CopyContext cc, OpTransformer ot) {
157             return new AddOp(this, cc);
158         }
159 
160         AddOp(Value a, Value b) {
161             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
162         }
163     }
164 
165     @OpFactoryHelper.OpDeclaration(SubOp.NAME)
166     public static class SubOp extends ArithMathOp implements Op.Pure {
167         public static final String NAME = "arith.sub";
168 
169         public SubOp(ExternalizedOp def) {
170             super(def);
171         }
172 
173         SubOp(SubOp that, CopyContext cc) {
174             super(that, cc);
175         }
176 
177         @Override
178         public SubOp transform(CopyContext cc, OpTransformer ot) {
179             return new SubOp(this, cc);
180         }
181 
182         SubOp(Value a, Value b) {
183             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
184         }
185     }
186 
187     @OpFactoryHelper.OpDeclaration(MulOp.NAME)
188     public static class MulOp extends ArithMathOp implements Op.Pure {
189         public static final String NAME = "arith.mul";
190 
191         public MulOp(ExternalizedOp def) {
192             super(def);
193         }
194 
195         MulOp(MulOp that, CopyContext cc) {
196             super(that, cc);
197         }
198 
199         @Override
200         public MulOp transform(CopyContext cc, OpTransformer ot) {
201             return new MulOp(this, cc);
202         }
203 
204         MulOp(Value a, Value b) {
205             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
206         }
207     }
208 
209     @OpFactoryHelper.OpDeclaration(DivOp.NAME)
210     public static class DivOp extends ArithMathOp implements Op.Pure {
211         public static final String NAME = "arith.div";
212 
213         public DivOp(ExternalizedOp def) {
214             super(def);
215         }
216 
217         DivOp(DivOp that, CopyContext cc) {
218             super(that, cc);
219         }
220 
221         @Override
222         public DivOp transform(CopyContext cc, OpTransformer ot) {
223             return new DivOp(this, cc);
224         }
225 
226         DivOp(Value a, Value b) {
227             super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
228         }
229     }
230 
231     @OpFactoryHelper.OpDeclaration(RemOp.NAME)
232     public static class RemOp extends ArithMathOp implements Op.Pure {
233         public static final String NAME = "arith.rem";
234 
235         public RemOp(ExternalizedOp def) {
236             super(def);
237         }
238 
239         RemOp(RemOp that, CopyContext cc) {
240             super(that, cc);
241         }
242 
243         @Override
244         public RemOp transform(CopyContext cc, OpTransformer ot) {
245             return new RemOp(this, cc);
246         }
247 
248         RemOp(Value a, Value b) {
249             super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
250         }
251     }
252 
253     @OpFactoryHelper.OpDeclaration(AndOp.NAME)
254     public static class AndOp extends ArithMathOp implements Op.Pure {
255         public static final String NAME = "arith.andi";
256 
257         public AndOp(ExternalizedOp def) {
258             super(def);
259         }
260 
261         AndOp(AndOp that, CopyContext cc) {
262             super(that, cc);
263         }
264 
265         @Override
266         public AndOp transform(CopyContext cc, OpTransformer ot) {
267             return new AndOp(this, cc);
268         }
269 
270         AndOp(Value a, Value b) {
271             super(NAME, a.type(), List.of(a, b));
272         }
273     }
274 
275     @OpFactoryHelper.OpDeclaration(MaxOp.NAME)
276     public static class MaxOp extends ArithMathOp implements Op.Pure {
277         public static final String NAME = "arith.max";
278 
279         public MaxOp(ExternalizedOp def) {
280             super(def);
281         }
282 
283         MaxOp(MaxOp that, CopyContext cc) {
284             super(that, cc);
285         }
286 
287         @Override
288         public MaxOp transform(CopyContext cc, OpTransformer ot) {
289             return new MaxOp(this, cc);
290         }
291 
292         MaxOp(Value a, Value b) {
293             super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
294                     a.type(), List.of(a, b));
295         }
296     }
297 
298     @OpFactoryHelper.OpDeclaration(MinOp.NAME)
299     public static class MinOp extends ArithMathOp implements Op.Pure {
300         public static final String NAME = "arith.min";
301 
302         public MinOp(ExternalizedOp def) {
303             super(def);
304         }
305 
306         MinOp(MinOp that, CopyContext cc) {
307             super(that, cc);
308         }
309 
310         @Override
311         public MinOp transform(CopyContext cc, OpTransformer ot) {
312             return new MinOp(this, cc);
313         }
314 
315         MinOp(Value a, Value b) {
316             super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
317                     a.type(), List.of(a, b));
318         }
319     }
320 
321     @OpFactoryHelper.OpDeclaration(ExpOp.NAME)
322     public static class TruncOp extends ArithMathOp implements Op.Pure {
323         public static final String NAME = "arith.trunc";
324 
325         public TruncOp(ExternalizedOp def) {
326             super(def);
327         }
328 
329         TruncOp(TruncOp that, CopyContext cc) {
330             super(that, cc);
331         }
332 
333         @Override
334         public TruncOp transform(CopyContext cc, OpTransformer ot) {
335             return new TruncOp(this, cc);
336         }
337 
338         TruncOp(TypeElement t, Value a) {
339             super(NAME + nameSuffixFromType(a.type(), false),
340                     t, List.of(a));
341         }
342     }
343 
344     @OpFactoryHelper.OpDeclaration(ExpOp.NAME)
345     public static class ExpOp extends ArithMathOp implements Op.Pure {
346         public static final String NAME = "math.exp";
347 
348         public ExpOp(ExternalizedOp def) {
349             super(def);
350         }
351 
352         ExpOp(ExpOp that, CopyContext cc) {
353             super(that, cc);
354         }
355 
356         @Override
357         public ExpOp transform(CopyContext cc, OpTransformer ot) {
358             return new ExpOp(this, cc);
359         }
360 
361         ExpOp(Value a) {
362             super(NAME, a.type(), List.of(a));
363         }
364     }
365 
366     @OpFactoryHelper.OpDeclaration(CompareOp.NAME)
367     public static class CompareOp extends ArithMathOp implements Op.Pure {
368         public static final String NAME = "arith.cmp";
369         public static final String ATTRIBUTE_PREDICATE = "predicate";
370 
371         // https://mlir.llvm.org/docs/Dialects/ArithOps/#cmpipredicate
372         // The ordinal values correspond to the MLIR symbol's values
373         // Need to refine when considering comparisons of floating point numbers which is in a different namespace
374         public enum CompareKind {
375             eq,
376             ne,
377             slt,
378             sle,
379             sgt,
380             sge,
381             ult,
382             ule,
383             ugt,
384             uge
385         }
386 
387         final CompareKind ck;
388 
389         public static CompareOp create(ExternalizedOp def) {
390             CompareKind ck = def.extractAttributeValue(ATTRIBUTE_PREDICATE, true,
391                     v -> switch (v) {
392                         case String s -> CompareKind.valueOf(s);
393                         case CompareKind k -> k;
394                         case null, default -> throw new UnsupportedOperationException("Unsupported start value:" + v);
395                     });
396             return new CompareOp(def, ck);
397         }
398 
399         CompareOp(ExternalizedOp def, CompareKind ck) {
400             super(def);
401 
402             this.ck = ck;
403         }
404 
405         CompareOp(CompareOp that, CopyContext cc) {
406             super(that, cc);
407 
408             this.ck = that.ck;
409         }
410 
411         @Override
412         public CompareOp transform(CopyContext cc, OpTransformer ot) {
413             return new CompareOp(this, cc);
414         }
415 
416         CompareOp(CompareKind ck, Value a, Value b) {
417             TypeElement t;
418             if (a.type() instanceof TensorType ot) {
419                 t = new TensorType(JavaType.BOOLEAN, ot.shape());
420             }
421             else {
422                 t = JavaType.BOOLEAN;
423             }
424             super(NAME + nameSuffixFromType(a.type(), false), t, List.of(a, b));
425 
426             this.ck = ck;
427         }
428 
429         @Override
430         public Map<String, Object> externalize() {
431             return Map.of(ATTRIBUTE_PREDICATE, Long.valueOf(ck.ordinal()));
432         }
433 
434         public CompareKind kind() {
435             return ck;
436         }
437     }
438 
439     static String maxMinSuffixFromType(TypeElement t) {
440         if (t instanceof TensorType tt) {
441             return maxMinSuffixFromType(tt.eType());
442         } else if (t instanceof PtrType pt) {
443             return maxMinSuffixFromType(pt.rType());
444         } else if (t.equals(JavaType.INT)) {
445             return "";
446         } else if (t.equals(JavaType.FLOAT)) {
447             return "imum";
448         } else {
449             throw new UnsupportedOperationException("Unsupported type: " + t);
450         }
451     }
452 
453     static String nameSuffixFromType(TypeElement t, boolean signed) {
454         if (t instanceof TensorType tt) {
455             return nameSuffixFromType(tt.eType(), signed);
456         } else if (t instanceof PtrType pt) {
457             return nameSuffixFromType(pt.rType(), signed);
458         } else if (t.equals(JavaType.INT) || t.equals(JavaType.LONG)) {
459             return (signed ? "s" : "") + "i";
460         } else if (t.equals(JavaType.FLOAT) || t.equals(JavaType.DOUBLE) ||
461                 Float16.FLOAT_16_TYPE.equals(t)) {
462             return "f";
463         } else {
464             throw new UnsupportedOperationException("Unsupported type: " + t);
465         }
466     }
467 
468     public static final OpFactory OP_FACTORY = def -> {
469         return switch (def.name()) {
470             case ConstantOp.NAME -> ConstantOp.create(def);
471             case ExpOp.NAME -> new ExpOp(def);
472             case AddOp.NAME + "i", AddOp.NAME + "f" -> new AddOp(def);
473             case SubOp.NAME + "i", SubOp.NAME + "f" -> new SubOp(def);
474             case MulOp.NAME + "i", MulOp.NAME + "f" -> new MulOp(def);
475             case DivOp.NAME + "si", DivOp.NAME + "f" -> new DivOp(def);
476             case RemOp.NAME + "si", RemOp.NAME + "f" -> new DivOp(def);
477             case AndOp.NAME -> new AndOp(def);
478             case MaxOp.NAME + "si", MaxOp.NAME + "imumf" -> new MaxOp(def);
479             case MinOp.NAME + "si", MinOp.NAME + "imumf" -> new MinOp(def);
480             case TruncOp.NAME + "i", TruncOp.NAME + "f" -> new TruncOp(def);
481             case CompareOp.NAME + "i", CompareOp.NAME + "f" -> CompareOp.create(def);
482             default -> null;
483         };
484     };
485 
486     // Arith
487 
488     public static ConstantOp constant(TypeElement type, Object value) {
489         return new ConstantOp(type, value);
490     }
491 
492     public static MulOp mul(Value a, Value b) {
493         return new MulOp(a, b);
494     }
495 
496     public static AddOp add(Value a, Value b) {
497         return new AddOp(a, b);
498     }
499 
500     public static SubOp sub(Value a, Value b) {
501         return new SubOp(a, b);
502     }
503 
504     public static DivOp div(Value a, Value b) {
505         return new DivOp(a, b);
506     }
507 
508     public static RemOp rem(Value a, Value b) {
509         return new RemOp(a, b);
510     }
511 
512     public static AndOp and(Value a, Value b) {
513         return new AndOp(a, b);
514     }
515 
516     public static MaxOp maximum(Value a, Value b) {
517         return new MaxOp(a, b);
518     }
519 
520     public static MinOp minimum(Value a, Value b) {
521         return new MinOp(a, b);
522     }
523 
524     public static CompareOp cmp(CompareOp.CompareKind ck, Value a, Value b) {
525         return new CompareOp(ck, a, b);
526     }
527 
528     public static TruncOp trunc(TypeElement type, Value a) {
529         return new TruncOp(type, a);
530     }
531 
532     // Math
533 
534     public static ExpOp exp(Value a) {
535         return new ExpOp(a);
536     }
537 }