1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.triton;
 27 
 28 import jdk.incubator.code.*;
 29 import jdk.incubator.code.op.*;
 30 import jdk.incubator.code.type.JavaType;
 31 import java.util.HashMap;
 32 import java.util.List;
 33 import java.util.Map;
 34 
 35 public class ArithMathOps {
 36 
 37     static abstract class ArithMathOp extends ExternalizableOp {
 38         final TypeElement resultType;
 39 
 40         public ArithMathOp(ExternalizedOp def) {
 41             super(def);
 42 
 43             this.resultType = def.resultType();
 44         }
 45 
 46         ArithMathOp(ArithMathOp that, CopyContext cc) {
 47             super(that, cc);
 48 
 49             this.resultType = that.resultType;
 50         }
 51 
 52         ArithMathOp(String name, TypeElement resultType, List<? extends Value> operands) {
 53             super(name, operands);
 54 
 55             this.resultType = resultType;
 56         }
 57 
 58         @Override
 59         public TypeElement resultType() {
 60             return resultType;
 61         }
 62     }
 63 
 64     @OpFactory.OpDeclaration(ConstantOp.NAME)
 65     public static class ConstantOp extends ArithMathOp implements Op.Pure {
 66         public static final String NAME = "arith.constant";
 67         public static final String ATTRIBUTE_CONSTANT_VALUE = "value";
 68 
 69         final Object value;
 70 
 71         public static ConstantOp create(ExternalizedOp def) {
 72             if (!def.operands().isEmpty()) {
 73                 throw new IllegalArgumentException("Operation must have zero operands");
 74             }
 75 
 76             Object value = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE,true,
 77                     v -> processConstantValue(def.resultType(), v));
 78             return new ConstantOp(def, value);
 79         }
 80 
 81         static Object processConstantValue(TypeElement t, Object value) {
 82             if (t.equals(JavaType.BOOLEAN)) {
 83                 if (value instanceof String s) {
 84                     return Boolean.valueOf(s);
 85                 } else if (value instanceof Boolean) {
 86                     return value;
 87                 }
 88             } else if (t.equals(JavaType.BYTE)) {
 89                 if (value instanceof String s) {
 90                     return Byte.valueOf(s);
 91                 } else if (value instanceof Number n) {
 92                     return n.byteValue();
 93                 }
 94             } else if (t.equals(JavaType.SHORT)) {
 95                 if (value instanceof String s) {
 96                     return Short.valueOf(s);
 97                 } else if (value instanceof Number n) {
 98                     return n.shortValue();
 99                 }
100             } else if (t.equals(JavaType.CHAR)) {
101                 if (value instanceof String s) {
102                     return s.charAt(0);
103                 } else if (value instanceof Character) {
104                     return value;
105                 }
106             } else if (t.equals(JavaType.INT)) {
107                 if (value instanceof String s) {
108                     return Integer.valueOf(s);
109                 } else if (value instanceof Number n) {
110                     return n.intValue();
111                 }
112             } else if (t.equals(JavaType.LONG)) {
113                 if (value instanceof String s) {
114                     return Long.valueOf(s);
115                 } else if (value instanceof Number n) {
116                     return n.longValue();
117                 }
118             } else if (t.equals(JavaType.FLOAT)) {
119                 if (value instanceof String s) {
120                     return Float.valueOf(s);
121                 } else if (value instanceof Number n) {
122                     return n.floatValue();
123                 }
124             } else if (t.equals(Float16.FLOAT_16_TYPE)) {
125                 // represent as a float for now
126                 if (value instanceof String s) {
127                     return Float.valueOf(s);
128                 } else if (value instanceof Number n) {
129                     return n.floatValue();
130                 }
131             } else if (t.equals(JavaType.DOUBLE)) {
132                 if (value instanceof String s) {
133                     return Double.valueOf(s);
134                 } else if (value instanceof Number n) {
135                     return n.doubleValue();
136                 }
137             } else if (t instanceof TensorType tt) {
138                 return processConstantValue(tt.eType(), value);
139             }
140 
141             throw new UnsupportedOperationException("Unsupported constant type and value: " + t + " " + value);
142         }
143 
144         ConstantOp(ExternalizedOp def, Object value) {
145             super(def);
146 
147             this.value = value;
148         }
149 
150         ConstantOp(ConstantOp that, CopyContext cc) {
151             super(that, cc);
152 
153             this.value = that.value;
154         }
155 
156         @Override
157         public ConstantOp transform(CopyContext cc, OpTransformer ot) {
158             return new ConstantOp(this, cc);
159         }
160 
161         ConstantOp(TypeElement type, Object value) {
162             super(NAME, type, List.of());
163 
164             this.value = value;
165         }
166 
167         @Override
168         public Map<String, Object> attributes() {
169             HashMap<String, Object> attrs = new HashMap<>(super.attributes());
170             attrs.put(ATTRIBUTE_CONSTANT_VALUE, value);
171             return attrs;
172         }
173 
174         public Object value() {
175             return value;
176         }
177     }
178 
179     @OpFactory.OpDeclaration(AddOp.NAME)
180     public static class AddOp extends ArithMathOp implements Op.Pure {
181         public static final String NAME = "arith.add";
182 
183         public AddOp(ExternalizedOp def) {
184             super(def);
185         }
186 
187         AddOp(AddOp that, CopyContext cc) {
188             super(that, cc);
189         }
190 
191         @Override
192         public AddOp transform(CopyContext cc, OpTransformer ot) {
193             return new AddOp(this, cc);
194         }
195 
196         AddOp(Value a, Value b) {
197             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
198         }
199     }
200 
201     @OpFactory.OpDeclaration(SubOp.NAME)
202     public static class SubOp extends ArithMathOp implements Op.Pure {
203         public static final String NAME = "arith.sub";
204 
205         public SubOp(ExternalizedOp def) {
206             super(def);
207         }
208 
209         SubOp(SubOp that, CopyContext cc) {
210             super(that, cc);
211         }
212 
213         @Override
214         public SubOp transform(CopyContext cc, OpTransformer ot) {
215             return new SubOp(this, cc);
216         }
217 
218         SubOp(Value a, Value b) {
219             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
220         }
221     }
222 
223     @OpFactory.OpDeclaration(MulOp.NAME)
224     public static class MulOp extends ArithMathOp implements Op.Pure {
225         public static final String NAME = "arith.mul";
226 
227         public MulOp(ExternalizedOp def) {
228             super(def);
229         }
230 
231         MulOp(MulOp that, CopyContext cc) {
232             super(that, cc);
233         }
234 
235         @Override
236         public MulOp transform(CopyContext cc, OpTransformer ot) {
237             return new MulOp(this, cc);
238         }
239 
240         MulOp(Value a, Value b) {
241             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
242         }
243     }
244 
245     @OpFactory.OpDeclaration(DivOp.NAME)
246     public static class DivOp extends ArithMathOp implements Op.Pure {
247         public static final String NAME = "arith.div";
248 
249         public DivOp(ExternalizedOp def) {
250             super(def);
251         }
252 
253         DivOp(DivOp that, CopyContext cc) {
254             super(that, cc);
255         }
256 
257         @Override
258         public DivOp transform(CopyContext cc, OpTransformer ot) {
259             return new DivOp(this, cc);
260         }
261 
262         DivOp(Value a, Value b) {
263             super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
264         }
265     }
266 
267     @OpFactory.OpDeclaration(RemOp.NAME)
268     public static class RemOp extends ArithMathOp implements Op.Pure {
269         public static final String NAME = "arith.rem";
270 
271         public RemOp(ExternalizedOp def) {
272             super(def);
273         }
274 
275         RemOp(RemOp that, CopyContext cc) {
276             super(that, cc);
277         }
278 
279         @Override
280         public RemOp transform(CopyContext cc, OpTransformer ot) {
281             return new RemOp(this, cc);
282         }
283 
284         RemOp(Value a, Value b) {
285             super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
286         }
287     }
288 
289     @OpFactory.OpDeclaration(AndOp.NAME)
290     public static class AndOp extends ArithMathOp implements Op.Pure {
291         public static final String NAME = "arith.andi";
292 
293         public AndOp(ExternalizedOp def) {
294             super(def);
295         }
296 
297         AndOp(AndOp that, CopyContext cc) {
298             super(that, cc);
299         }
300 
301         @Override
302         public AndOp transform(CopyContext cc, OpTransformer ot) {
303             return new AndOp(this, cc);
304         }
305 
306         AndOp(Value a, Value b) {
307             super(NAME, a.type(), List.of(a, b));
308         }
309     }
310 
311     @OpFactory.OpDeclaration(MaxOp.NAME)
312     public static class MaxOp extends ArithMathOp implements Op.Pure {
313         public static final String NAME = "arith.max";
314 
315         public MaxOp(ExternalizedOp def) {
316             super(def);
317         }
318 
319         MaxOp(MaxOp that, CopyContext cc) {
320             super(that, cc);
321         }
322 
323         @Override
324         public MaxOp transform(CopyContext cc, OpTransformer ot) {
325             return new MaxOp(this, cc);
326         }
327 
328         MaxOp(Value a, Value b) {
329             super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
330                     a.type(), List.of(a, b));
331         }
332     }
333 
334     @OpFactory.OpDeclaration(MinOp.NAME)
335     public static class MinOp extends ArithMathOp implements Op.Pure {
336         public static final String NAME = "arith.min";
337 
338         public MinOp(ExternalizedOp def) {
339             super(def);
340         }
341 
342         MinOp(MinOp that, CopyContext cc) {
343             super(that, cc);
344         }
345 
346         @Override
347         public MinOp transform(CopyContext cc, OpTransformer ot) {
348             return new MinOp(this, cc);
349         }
350 
351         MinOp(Value a, Value b) {
352             super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
353                     a.type(), List.of(a, b));
354         }
355     }
356 
357     @OpFactory.OpDeclaration(ExpOp.NAME)
358     public static class TruncOp extends ArithMathOp implements Op.Pure {
359         public static final String NAME = "arith.trunc";
360 
361         public TruncOp(ExternalizedOp def) {
362             super(def);
363         }
364 
365         TruncOp(TruncOp that, CopyContext cc) {
366             super(that, cc);
367         }
368 
369         @Override
370         public TruncOp transform(CopyContext cc, OpTransformer ot) {
371             return new TruncOp(this, cc);
372         }
373 
374         TruncOp(TypeElement t, Value a) {
375             super(NAME + nameSuffixFromType(a.type(), false),
376                     t, List.of(a));
377         }
378     }
379 
380     @OpFactory.OpDeclaration(ExpOp.NAME)
381     public static class ExpOp extends ArithMathOp implements Op.Pure {
382         public static final String NAME = "math.exp";
383 
384         public ExpOp(ExternalizedOp def) {
385             super(def);
386         }
387 
388         ExpOp(ExpOp that, CopyContext cc) {
389             super(that, cc);
390         }
391 
392         @Override
393         public ExpOp transform(CopyContext cc, OpTransformer ot) {
394             return new ExpOp(this, cc);
395         }
396 
397         ExpOp(Value a) {
398             super(NAME, a.type(), List.of(a));
399         }
400     }
401 
402     @OpFactory.OpDeclaration(CompareOp.NAME)
403     public static class CompareOp extends ArithMathOp implements Op.Pure {
404         public static final String NAME = "arith.cmp";
405         public static final String ATTRIBUTE_PREDICATE = "predicate";
406 
407         // https://mlir.llvm.org/docs/Dialects/ArithOps/#cmpipredicate
408         // The ordinal values correspond to the MLIR symbol's values
409         // Need to refine when considering comparisons of floating point numbers which is in a different namespace
410         public enum CompareKind {
411             eq,
412             ne,
413             slt,
414             sle,
415             sgt,
416             sge,
417             ult,
418             ule,
419             ugt,
420             uge
421         }
422 
423         final CompareKind ck;
424 
425         public static CompareOp create(ExternalizedOp def) {
426             CompareKind ck = def.extractAttributeValue(ATTRIBUTE_PREDICATE, true,
427                     v -> switch (v) {
428                         case String s -> CompareKind.valueOf(s);
429                         case CompareKind k -> k;
430                         case null, default -> throw new UnsupportedOperationException("Unsupported start value:" + v);
431                     });
432             return new CompareOp(def, ck);
433         }
434 
435         CompareOp(ExternalizedOp def, CompareKind ck) {
436             super(def);
437 
438             this.ck = ck;
439         }
440 
441         CompareOp(CompareOp that, CopyContext cc) {
442             super(that, cc);
443 
444             this.ck = that.ck;
445         }
446 
447         @Override
448         public CompareOp transform(CopyContext cc, OpTransformer ot) {
449             return new CompareOp(this, cc);
450         }
451 
452         CompareOp(CompareKind ck, Value a, Value b) {
453             TypeElement t;
454             if (a.type() instanceof TensorType ot) {
455                 t = new TensorType(JavaType.BOOLEAN, ot.shape());
456             }
457             else {
458                 t = JavaType.BOOLEAN;
459             }
460             super(NAME + nameSuffixFromType(a.type(), false), t, List.of(a, b));
461 
462             this.ck = ck;
463         }
464 
465         @Override
466         public Map<String, Object> attributes() {
467             HashMap<String, Object> attrs = new HashMap<>(super.attributes());
468             attrs.put(ATTRIBUTE_PREDICATE, Long.valueOf(ck.ordinal()));
469             return attrs;
470         }
471 
472         public CompareKind kind() {
473             return ck;
474         }
475     }
476 
477     static String maxMinSuffixFromType(TypeElement t) {
478         if (t instanceof TensorType tt) {
479             return maxMinSuffixFromType(tt.eType());
480         } else if (t instanceof PtrType pt) {
481             return maxMinSuffixFromType(pt.rType());
482         } else if (t.equals(JavaType.INT)) {
483             return "";
484         } else if (t.equals(JavaType.FLOAT)) {
485             return "imum";
486         } else {
487             throw new UnsupportedOperationException("Unsupported type: " + t);
488         }
489     }
490 
491     static String nameSuffixFromType(TypeElement t, boolean signed) {
492         if (t instanceof TensorType tt) {
493             return nameSuffixFromType(tt.eType(), signed);
494         } else if (t instanceof PtrType pt) {
495             return nameSuffixFromType(pt.rType(), signed);
496         } else if (t.equals(JavaType.INT) || t.equals(JavaType.LONG)) {
497             return (signed ? "s" : "") + "i";
498         } else if (t.equals(JavaType.FLOAT) || t.equals(JavaType.DOUBLE) ||
499                 Float16.FLOAT_16_TYPE.equals(t)) {
500             return "f";
501         } else {
502             throw new UnsupportedOperationException("Unsupported type: " + t);
503         }
504     }
505 
506     public static final OpFactory FACTORY = def -> {
507         return switch (def.name()) {
508             case ConstantOp.NAME -> ConstantOp.create(def);
509             case ExpOp.NAME -> new ExpOp(def);
510             case AddOp.NAME + "i", AddOp.NAME + "f" -> new AddOp(def);
511             case SubOp.NAME + "i", SubOp.NAME + "f" -> new SubOp(def);
512             case MulOp.NAME + "i", MulOp.NAME + "f" -> new MulOp(def);
513             case DivOp.NAME + "si", DivOp.NAME + "f" -> new DivOp(def);
514             case RemOp.NAME + "si", RemOp.NAME + "f" -> new DivOp(def);
515             case AndOp.NAME -> new AndOp(def);
516             case MaxOp.NAME + "si", MaxOp.NAME + "imumf" -> new MaxOp(def);
517             case MinOp.NAME + "si", MinOp.NAME + "imumf" -> new MinOp(def);
518             case TruncOp.NAME + "i", TruncOp.NAME + "f" -> new TruncOp(def);
519             case CompareOp.NAME + "i", CompareOp.NAME + "f" -> CompareOp.create(def);
520             default -> null;
521         };
522     };
523 
524     // Arith
525 
526     public static ConstantOp constant(TypeElement type, Object value) {
527         return new ConstantOp(type, value);
528     }
529 
530     public static MulOp mul(Value a, Value b) {
531         return new MulOp(a, b);
532     }
533 
534     public static AddOp add(Value a, Value b) {
535         return new AddOp(a, b);
536     }
537 
538     public static SubOp sub(Value a, Value b) {
539         return new SubOp(a, b);
540     }
541 
542     public static DivOp div(Value a, Value b) {
543         return new DivOp(a, b);
544     }
545 
546     public static RemOp rem(Value a, Value b) {
547         return new RemOp(a, b);
548     }
549 
550     public static AndOp and(Value a, Value b) {
551         return new AndOp(a, b);
552     }
553 
554     public static MaxOp maximum(Value a, Value b) {
555         return new MaxOp(a, b);
556     }
557 
558     public static MinOp minimum(Value a, Value b) {
559         return new MinOp(a, b);
560     }
561 
562     public static CompareOp cmp(CompareOp.CompareKind ck, Value a, Value b) {
563         return new CompareOp(ck, a, b);
564     }
565 
566     public static TruncOp trunc(TypeElement type, Value a) {
567         return new TruncOp(type, a);
568     }
569 
570     // Math
571 
572     public static ExpOp exp(Value a) {
573         return new ExpOp(a);
574     }
575 }