1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.triton;
 27 
 28 import jdk.incubator.code.*;
 29 import jdk.incubator.code.extern.ExternalizableOp;
 30 import jdk.incubator.code.extern.OpFactory;
 31 import jdk.incubator.code.dialect.java.JavaType;
 32 
 33 import java.util.HashMap;
 34 import java.util.List;
 35 import java.util.Map;
 36 
 37 public class ArithMathOps {
 38 
 39     static abstract class ArithMathOp extends ExternalizableOp {
 40         final TypeElement resultType;
 41 
 42         public ArithMathOp(ExternalizedOp def) {
 43             super(def);
 44 
 45             this.resultType = def.resultType();
 46         }
 47 
 48         ArithMathOp(ArithMathOp that, CopyContext cc) {
 49             super(that, cc);
 50 
 51             this.resultType = that.resultType;
 52         }
 53 
 54         ArithMathOp(String name, TypeElement resultType, List<? extends Value> operands) {
 55             super(name, operands);
 56 
 57             this.resultType = resultType;
 58         }
 59 
 60         @Override
 61         public TypeElement resultType() {
 62             return resultType;
 63         }
 64     }
 65 
 66     @OpFactory.OpDeclaration(ConstantOp.NAME)
 67     public static class ConstantOp extends ArithMathOp implements Op.Pure {
 68         public static final String NAME = "arith.constant";
 69         public static final String ATTRIBUTE_CONSTANT_VALUE = "value";
 70 
 71         final Object value;
 72 
 73         public static ConstantOp create(ExternalizedOp def) {
 74             if (!def.operands().isEmpty()) {
 75                 throw new IllegalArgumentException("Operation must have zero operands");
 76             }
 77 
 78             Object value = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE,true,
 79                     v -> processConstantValue(def.resultType(), v));
 80             return new ConstantOp(def, value);
 81         }
 82 
 83         static Object processConstantValue(TypeElement t, Object value) {
 84             if (t.equals(JavaType.BOOLEAN) && value instanceof Boolean) {
 85                 return value;
 86             } else if (t.equals(JavaType.BYTE) && value instanceof Number n) {
 87                 return n.byteValue();
 88             } else if (t.equals(JavaType.SHORT) && value instanceof Number n) {
 89                 return n.shortValue();
 90             } else if (t.equals(JavaType.CHAR) && value instanceof Character) {
 91                 return value;
 92             } else if (t.equals(JavaType.INT) && value instanceof Number n) {
 93                 return n.intValue();
 94             } else if (t.equals(JavaType.LONG) && value instanceof Number n) {
 95                 return n.longValue();
 96             } else if (t.equals(JavaType.FLOAT) && value instanceof Number n) {
 97                 return n.floatValue();
 98             } else if (t.equals(Float16.FLOAT_16_TYPE) && value instanceof Number n) {
 99                 // represent as a float for now
100                 return n.floatValue();
101             } else if (t.equals(JavaType.DOUBLE) && value instanceof Number n) {
102                 return n.doubleValue();
103             } else if (t instanceof TensorType tt) {
104                 return processConstantValue(tt.eType(), value);
105             }
106 
107             throw new UnsupportedOperationException("Unsupported constant type and value: " + t + " " + value);
108         }
109 
110         ConstantOp(ExternalizedOp def, Object value) {
111             super(def);
112 
113             this.value = value;
114         }
115 
116         ConstantOp(ConstantOp that, CopyContext cc) {
117             super(that, cc);
118 
119             this.value = that.value;
120         }
121 
122         @Override
123         public ConstantOp transform(CopyContext cc, OpTransformer ot) {
124             return new ConstantOp(this, cc);
125         }
126 
127         ConstantOp(TypeElement type, Object value) {
128             super(NAME, type, List.of());
129 
130             this.value = value;
131         }
132 
133         @Override
134         public Map<String, Object> attributes() {
135             HashMap<String, Object> attrs = new HashMap<>(super.attributes());
136             attrs.put(ATTRIBUTE_CONSTANT_VALUE, value);
137             return attrs;
138         }
139 
140         public Object value() {
141             return value;
142         }
143     }
144 
145     @OpFactory.OpDeclaration(AddOp.NAME)
146     public static class AddOp extends ArithMathOp implements Op.Pure {
147         public static final String NAME = "arith.add";
148 
149         public AddOp(ExternalizedOp def) {
150             super(def);
151         }
152 
153         AddOp(AddOp that, CopyContext cc) {
154             super(that, cc);
155         }
156 
157         @Override
158         public AddOp transform(CopyContext cc, OpTransformer ot) {
159             return new AddOp(this, cc);
160         }
161 
162         AddOp(Value a, Value b) {
163             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
164         }
165     }
166 
167     @OpFactory.OpDeclaration(SubOp.NAME)
168     public static class SubOp extends ArithMathOp implements Op.Pure {
169         public static final String NAME = "arith.sub";
170 
171         public SubOp(ExternalizedOp def) {
172             super(def);
173         }
174 
175         SubOp(SubOp that, CopyContext cc) {
176             super(that, cc);
177         }
178 
179         @Override
180         public SubOp transform(CopyContext cc, OpTransformer ot) {
181             return new SubOp(this, cc);
182         }
183 
184         SubOp(Value a, Value b) {
185             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
186         }
187     }
188 
189     @OpFactory.OpDeclaration(MulOp.NAME)
190     public static class MulOp extends ArithMathOp implements Op.Pure {
191         public static final String NAME = "arith.mul";
192 
193         public MulOp(ExternalizedOp def) {
194             super(def);
195         }
196 
197         MulOp(MulOp that, CopyContext cc) {
198             super(that, cc);
199         }
200 
201         @Override
202         public MulOp transform(CopyContext cc, OpTransformer ot) {
203             return new MulOp(this, cc);
204         }
205 
206         MulOp(Value a, Value b) {
207             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
208         }
209     }
210 
211     @OpFactory.OpDeclaration(DivOp.NAME)
212     public static class DivOp extends ArithMathOp implements Op.Pure {
213         public static final String NAME = "arith.div";
214 
215         public DivOp(ExternalizedOp def) {
216             super(def);
217         }
218 
219         DivOp(DivOp that, CopyContext cc) {
220             super(that, cc);
221         }
222 
223         @Override
224         public DivOp transform(CopyContext cc, OpTransformer ot) {
225             return new DivOp(this, cc);
226         }
227 
228         DivOp(Value a, Value b) {
229             super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
230         }
231     }
232 
233     @OpFactory.OpDeclaration(RemOp.NAME)
234     public static class RemOp extends ArithMathOp implements Op.Pure {
235         public static final String NAME = "arith.rem";
236 
237         public RemOp(ExternalizedOp def) {
238             super(def);
239         }
240 
241         RemOp(RemOp that, CopyContext cc) {
242             super(that, cc);
243         }
244 
245         @Override
246         public RemOp transform(CopyContext cc, OpTransformer ot) {
247             return new RemOp(this, cc);
248         }
249 
250         RemOp(Value a, Value b) {
251             super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
252         }
253     }
254 
255     @OpFactory.OpDeclaration(AndOp.NAME)
256     public static class AndOp extends ArithMathOp implements Op.Pure {
257         public static final String NAME = "arith.andi";
258 
259         public AndOp(ExternalizedOp def) {
260             super(def);
261         }
262 
263         AndOp(AndOp that, CopyContext cc) {
264             super(that, cc);
265         }
266 
267         @Override
268         public AndOp transform(CopyContext cc, OpTransformer ot) {
269             return new AndOp(this, cc);
270         }
271 
272         AndOp(Value a, Value b) {
273             super(NAME, a.type(), List.of(a, b));
274         }
275     }
276 
277     @OpFactory.OpDeclaration(MaxOp.NAME)
278     public static class MaxOp extends ArithMathOp implements Op.Pure {
279         public static final String NAME = "arith.max";
280 
281         public MaxOp(ExternalizedOp def) {
282             super(def);
283         }
284 
285         MaxOp(MaxOp that, CopyContext cc) {
286             super(that, cc);
287         }
288 
289         @Override
290         public MaxOp transform(CopyContext cc, OpTransformer ot) {
291             return new MaxOp(this, cc);
292         }
293 
294         MaxOp(Value a, Value b) {
295             super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
296                     a.type(), List.of(a, b));
297         }
298     }
299 
300     @OpFactory.OpDeclaration(MinOp.NAME)
301     public static class MinOp extends ArithMathOp implements Op.Pure {
302         public static final String NAME = "arith.min";
303 
304         public MinOp(ExternalizedOp def) {
305             super(def);
306         }
307 
308         MinOp(MinOp that, CopyContext cc) {
309             super(that, cc);
310         }
311 
312         @Override
313         public MinOp transform(CopyContext cc, OpTransformer ot) {
314             return new MinOp(this, cc);
315         }
316 
317         MinOp(Value a, Value b) {
318             super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
319                     a.type(), List.of(a, b));
320         }
321     }
322 
323     @OpFactory.OpDeclaration(ExpOp.NAME)
324     public static class TruncOp extends ArithMathOp implements Op.Pure {
325         public static final String NAME = "arith.trunc";
326 
327         public TruncOp(ExternalizedOp def) {
328             super(def);
329         }
330 
331         TruncOp(TruncOp that, CopyContext cc) {
332             super(that, cc);
333         }
334 
335         @Override
336         public TruncOp transform(CopyContext cc, OpTransformer ot) {
337             return new TruncOp(this, cc);
338         }
339 
340         TruncOp(TypeElement t, Value a) {
341             super(NAME + nameSuffixFromType(a.type(), false),
342                     t, List.of(a));
343         }
344     }
345 
346     @OpFactory.OpDeclaration(ExpOp.NAME)
347     public static class ExpOp extends ArithMathOp implements Op.Pure {
348         public static final String NAME = "math.exp";
349 
350         public ExpOp(ExternalizedOp def) {
351             super(def);
352         }
353 
354         ExpOp(ExpOp that, CopyContext cc) {
355             super(that, cc);
356         }
357 
358         @Override
359         public ExpOp transform(CopyContext cc, OpTransformer ot) {
360             return new ExpOp(this, cc);
361         }
362 
363         ExpOp(Value a) {
364             super(NAME, a.type(), List.of(a));
365         }
366     }
367 
368     @OpFactory.OpDeclaration(CompareOp.NAME)
369     public static class CompareOp extends ArithMathOp implements Op.Pure {
370         public static final String NAME = "arith.cmp";
371         public static final String ATTRIBUTE_PREDICATE = "predicate";
372 
373         // https://mlir.llvm.org/docs/Dialects/ArithOps/#cmpipredicate
374         // The ordinal values correspond to the MLIR symbol's values
375         // Need to refine when considering comparisons of floating point numbers which is in a different namespace
376         public enum CompareKind {
377             eq,
378             ne,
379             slt,
380             sle,
381             sgt,
382             sge,
383             ult,
384             ule,
385             ugt,
386             uge
387         }
388 
389         final CompareKind ck;
390 
391         public static CompareOp create(ExternalizedOp def) {
392             CompareKind ck = def.extractAttributeValue(ATTRIBUTE_PREDICATE, true,
393                     v -> switch (v) {
394                         case String s -> CompareKind.valueOf(s);
395                         case CompareKind k -> k;
396                         case null, default -> throw new UnsupportedOperationException("Unsupported start value:" + v);
397                     });
398             return new CompareOp(def, ck);
399         }
400 
401         CompareOp(ExternalizedOp def, CompareKind ck) {
402             super(def);
403 
404             this.ck = ck;
405         }
406 
407         CompareOp(CompareOp that, CopyContext cc) {
408             super(that, cc);
409 
410             this.ck = that.ck;
411         }
412 
413         @Override
414         public CompareOp transform(CopyContext cc, OpTransformer ot) {
415             return new CompareOp(this, cc);
416         }
417 
418         CompareOp(CompareKind ck, Value a, Value b) {
419             TypeElement t;
420             if (a.type() instanceof TensorType ot) {
421                 t = new TensorType(JavaType.BOOLEAN, ot.shape());
422             }
423             else {
424                 t = JavaType.BOOLEAN;
425             }
426             super(NAME + nameSuffixFromType(a.type(), false), t, List.of(a, b));
427 
428             this.ck = ck;
429         }
430 
431         @Override
432         public Map<String, Object> attributes() {
433             HashMap<String, Object> attrs = new HashMap<>(super.attributes());
434             attrs.put(ATTRIBUTE_PREDICATE, Long.valueOf(ck.ordinal()));
435             return attrs;
436         }
437 
438         public CompareKind kind() {
439             return ck;
440         }
441     }
442 
443     static String maxMinSuffixFromType(TypeElement t) {
444         if (t instanceof TensorType tt) {
445             return maxMinSuffixFromType(tt.eType());
446         } else if (t instanceof PtrType pt) {
447             return maxMinSuffixFromType(pt.rType());
448         } else if (t.equals(JavaType.INT)) {
449             return "";
450         } else if (t.equals(JavaType.FLOAT)) {
451             return "imum";
452         } else {
453             throw new UnsupportedOperationException("Unsupported type: " + t);
454         }
455     }
456 
457     static String nameSuffixFromType(TypeElement t, boolean signed) {
458         if (t instanceof TensorType tt) {
459             return nameSuffixFromType(tt.eType(), signed);
460         } else if (t instanceof PtrType pt) {
461             return nameSuffixFromType(pt.rType(), signed);
462         } else if (t.equals(JavaType.INT) || t.equals(JavaType.LONG)) {
463             return (signed ? "s" : "") + "i";
464         } else if (t.equals(JavaType.FLOAT) || t.equals(JavaType.DOUBLE) ||
465                 Float16.FLOAT_16_TYPE.equals(t)) {
466             return "f";
467         } else {
468             throw new UnsupportedOperationException("Unsupported type: " + t);
469         }
470     }
471 
472     public static final OpFactory OP_FACTORY = def -> {
473         return switch (def.name()) {
474             case ConstantOp.NAME -> ConstantOp.create(def);
475             case ExpOp.NAME -> new ExpOp(def);
476             case AddOp.NAME + "i", AddOp.NAME + "f" -> new AddOp(def);
477             case SubOp.NAME + "i", SubOp.NAME + "f" -> new SubOp(def);
478             case MulOp.NAME + "i", MulOp.NAME + "f" -> new MulOp(def);
479             case DivOp.NAME + "si", DivOp.NAME + "f" -> new DivOp(def);
480             case RemOp.NAME + "si", RemOp.NAME + "f" -> new DivOp(def);
481             case AndOp.NAME -> new AndOp(def);
482             case MaxOp.NAME + "si", MaxOp.NAME + "imumf" -> new MaxOp(def);
483             case MinOp.NAME + "si", MinOp.NAME + "imumf" -> new MinOp(def);
484             case TruncOp.NAME + "i", TruncOp.NAME + "f" -> new TruncOp(def);
485             case CompareOp.NAME + "i", CompareOp.NAME + "f" -> CompareOp.create(def);
486             default -> null;
487         };
488     };
489 
490     // Arith
491 
492     public static ConstantOp constant(TypeElement type, Object value) {
493         return new ConstantOp(type, value);
494     }
495 
496     public static MulOp mul(Value a, Value b) {
497         return new MulOp(a, b);
498     }
499 
500     public static AddOp add(Value a, Value b) {
501         return new AddOp(a, b);
502     }
503 
504     public static SubOp sub(Value a, Value b) {
505         return new SubOp(a, b);
506     }
507 
508     public static DivOp div(Value a, Value b) {
509         return new DivOp(a, b);
510     }
511 
512     public static RemOp rem(Value a, Value b) {
513         return new RemOp(a, b);
514     }
515 
516     public static AndOp and(Value a, Value b) {
517         return new AndOp(a, b);
518     }
519 
520     public static MaxOp maximum(Value a, Value b) {
521         return new MaxOp(a, b);
522     }
523 
524     public static MinOp minimum(Value a, Value b) {
525         return new MinOp(a, b);
526     }
527 
528     public static CompareOp cmp(CompareOp.CompareKind ck, Value a, Value b) {
529         return new CompareOp(ck, a, b);
530     }
531 
532     public static TruncOp trunc(TypeElement type, Value a) {
533         return new TruncOp(type, a);
534     }
535 
536     // Math
537 
538     public static ExpOp exp(Value a) {
539         return new ExpOp(a);
540     }
541 }