1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.triton;
 27 
 28 import java.lang.reflect.code.*;
 29 import java.lang.reflect.code.op.*;
 30 import java.lang.reflect.code.type.JavaType;
 31 import java.util.HashMap;
 32 import java.util.List;
 33 import java.util.Map;
 34 
 35 public class ArithMathOps {
 36 
 37     static abstract class ArithMathOp extends ExternalizableOp {
 38         final TypeElement resultType;
 39 
 40         public ArithMathOp(ExternalizedOp def) {
 41             super(def);
 42 
 43             this.resultType = def.resultType();
 44         }
 45 
 46         ArithMathOp(ArithMathOp that, CopyContext cc) {
 47             super(that, cc);
 48 
 49             this.resultType = that.resultType;
 50         }
 51 
 52         ArithMathOp(String name, TypeElement resultType, List<? extends Value> operands) {
 53             super(name, operands);
 54 
 55             this.resultType = resultType;
 56         }
 57 
 58         @Override
 59         public TypeElement resultType() {
 60             return resultType;
 61         }
 62     }
 63 
 64     @OpFactory.OpDeclaration(ConstantOp.NAME)
 65     public static class ConstantOp extends ArithMathOp implements Op.Pure {
 66         public static final String NAME = "arith.constant";
 67         public static final String ATTRIBUTE_CONSTANT_VALUE = NAME + ".value";
 68 
 69         final Object value;
 70 
 71         public static ConstantOp create(ExternalizedOp def) {
 72             if (!def.operands().isEmpty()) {
 73                 throw new IllegalArgumentException("Operation must have zero operands");
 74             }
 75 
 76             Object value = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE,true,
 77                     v -> processConstantValue(def.resultType(), v));
 78             return new ConstantOp(def, value);
 79         }
 80 
 81         static Object processConstantValue(TypeElement t, Object value) {
 82             if (t.equals(JavaType.BOOLEAN)) {
 83                 if (value instanceof String s) {
 84                     return Boolean.valueOf(s);
 85                 } else if (value instanceof Boolean) {
 86                     return value;
 87                 }
 88             } else if (t.equals(JavaType.BYTE)) {
 89                 if (value instanceof String s) {
 90                     return Byte.valueOf(s);
 91                 } else if (value instanceof Number n) {
 92                     return n.byteValue();
 93                 }
 94             } else if (t.equals(JavaType.SHORT)) {
 95                 if (value instanceof String s) {
 96                     return Short.valueOf(s);
 97                 } else if (value instanceof Number n) {
 98                     return n.shortValue();
 99                 }
100             } else if (t.equals(JavaType.CHAR)) {
101                 if (value instanceof String s) {
102                     return s.charAt(0);
103                 } else if (value instanceof Character) {
104                     return value;
105                 }
106             } else if (t.equals(JavaType.INT)) {
107                 if (value instanceof String s) {
108                     return Integer.valueOf(s);
109                 } else if (value instanceof Number n) {
110                     return n.intValue();
111                 }
112             } else if (t.equals(JavaType.LONG)) {
113                 if (value instanceof String s) {
114                     return Long.valueOf(s);
115                 } else if (value instanceof Number n) {
116                     return n.longValue();
117                 }
118             } else if (t.equals(JavaType.FLOAT)) {
119                 if (value instanceof String s) {
120                     return Float.valueOf(s);
121                 } else if (value instanceof Number n) {
122                     return n.floatValue();
123                 }
124             } else if (t.equals(JavaType.DOUBLE)) {
125                 if (value instanceof String s) {
126                     return Double.valueOf(s);
127                 } else if (value instanceof Number n) {
128                     return n.doubleValue();
129                 }
130             } else if (t instanceof TensorType tt) {
131                 return processConstantValue(tt.eType(), value);
132             }
133 
134             throw new UnsupportedOperationException("Unsupported constant type and value: " + t + " " + value);
135         }
136 
137         ConstantOp(ExternalizedOp def, Object value) {
138             super(def);
139 
140             this.value = value;
141         }
142 
143         ConstantOp(ConstantOp that, CopyContext cc) {
144             super(that, cc);
145 
146             this.value = that.value;
147         }
148 
149         @Override
150         public ConstantOp transform(CopyContext cc, OpTransformer ot) {
151             return new ConstantOp(this, cc);
152         }
153 
154         ConstantOp(TypeElement type, Object value) {
155             super(NAME, type, List.of());
156 
157             this.value = value;
158         }
159 
160         @Override
161         public Map<String, Object> attributes() {
162             HashMap<String, Object> attrs = new HashMap<>(super.attributes());
163             attrs.put("", value);
164             return attrs;
165         }
166 
167         public Object value() {
168             return value;
169         }
170     }
171 
172     @OpFactory.OpDeclaration(AddOp.NAME)
173     public static class AddOp extends ArithMathOp implements Op.Pure {
174         public static final String NAME = "arith.add";
175 
176         public AddOp(ExternalizedOp def) {
177             super(def);
178         }
179 
180         AddOp(AddOp that, CopyContext cc) {
181             super(that, cc);
182         }
183 
184         @Override
185         public AddOp transform(CopyContext cc, OpTransformer ot) {
186             return new AddOp(this, cc);
187         }
188 
189         AddOp(Value a, Value b) {
190             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
191         }
192     }
193 
194     @OpFactory.OpDeclaration(SubOp.NAME)
195     public static class SubOp extends ArithMathOp implements Op.Pure {
196         public static final String NAME = "arith.sub";
197 
198         public SubOp(ExternalizedOp def) {
199             super(def);
200         }
201 
202         SubOp(SubOp that, CopyContext cc) {
203             super(that, cc);
204         }
205 
206         @Override
207         public SubOp transform(CopyContext cc, OpTransformer ot) {
208             return new SubOp(this, cc);
209         }
210 
211         SubOp(Value a, Value b) {
212             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
213         }
214     }
215 
216     @OpFactory.OpDeclaration(MulOp.NAME)
217     public static class MulOp extends ArithMathOp implements Op.Pure {
218         public static final String NAME = "arith.mul";
219 
220         public MulOp(ExternalizedOp def) {
221             super(def);
222         }
223 
224         MulOp(MulOp that, CopyContext cc) {
225             super(that, cc);
226         }
227 
228         @Override
229         public MulOp transform(CopyContext cc, OpTransformer ot) {
230             return new MulOp(this, cc);
231         }
232 
233         MulOp(Value a, Value b) {
234             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
235         }
236     }
237 
238     @OpFactory.OpDeclaration(DivOp.NAME)
239     public static class DivOp extends ArithMathOp implements Op.Pure {
240         public static final String NAME = "arith.div";
241 
242         public DivOp(ExternalizedOp def) {
243             super(def);
244         }
245 
246         DivOp(DivOp that, CopyContext cc) {
247             super(that, cc);
248         }
249 
250         @Override
251         public DivOp transform(CopyContext cc, OpTransformer ot) {
252             return new DivOp(this, cc);
253         }
254 
255         DivOp(Value a, Value b) {
256             super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
257         }
258     }
259 
260     @OpFactory.OpDeclaration(RemOp.NAME)
261     public static class RemOp extends ArithMathOp implements Op.Pure {
262         public static final String NAME = "arith.rem";
263 
264         public RemOp(ExternalizedOp def) {
265             super(def);
266         }
267 
268         RemOp(RemOp that, CopyContext cc) {
269             super(that, cc);
270         }
271 
272         @Override
273         public RemOp transform(CopyContext cc, OpTransformer ot) {
274             return new RemOp(this, cc);
275         }
276 
277         RemOp(Value a, Value b) {
278             super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
279         }
280     }
281 
282     @OpFactory.OpDeclaration(AndOp.NAME)
283     public static class AndOp extends ArithMathOp implements Op.Pure {
284         public static final String NAME = "arith.andi";
285 
286         public AndOp(ExternalizedOp def) {
287             super(def);
288         }
289 
290         AndOp(AndOp that, CopyContext cc) {
291             super(that, cc);
292         }
293 
294         @Override
295         public AndOp transform(CopyContext cc, OpTransformer ot) {
296             return new AndOp(this, cc);
297         }
298 
299         AndOp(Value a, Value b) {
300             super(NAME, a.type(), List.of(a, b));
301         }
302     }
303 
304     @OpFactory.OpDeclaration(MaxOp.NAME)
305     public static class MaxOp extends ArithMathOp implements Op.Pure {
306         public static final String NAME = "arith.max";
307 
308         public MaxOp(ExternalizedOp def) {
309             super(def);
310         }
311 
312         MaxOp(MaxOp that, CopyContext cc) {
313             super(that, cc);
314         }
315 
316         @Override
317         public MaxOp transform(CopyContext cc, OpTransformer ot) {
318             return new MaxOp(this, cc);
319         }
320 
321         MaxOp(Value a, Value b) {
322             super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
323                     a.type(), List.of(a, b));
324         }
325     }
326 
327     @OpFactory.OpDeclaration(MinOp.NAME)
328     public static class MinOp extends ArithMathOp implements Op.Pure {
329         public static final String NAME = "arith.min";
330 
331         public MinOp(ExternalizedOp def) {
332             super(def);
333         }
334 
335         MinOp(MinOp that, CopyContext cc) {
336             super(that, cc);
337         }
338 
339         @Override
340         public MinOp transform(CopyContext cc, OpTransformer ot) {
341             return new MinOp(this, cc);
342         }
343 
344         MinOp(Value a, Value b) {
345             super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
346                     a.type(), List.of(a, b));
347         }
348     }
349 
350     @OpFactory.OpDeclaration(ExpOp.NAME)
351     public static class TruncOp extends ArithMathOp implements Op.Pure {
352         public static final String NAME = "arith.trunc";
353 
354         public TruncOp(ExternalizedOp def) {
355             super(def);
356         }
357 
358         TruncOp(TruncOp that, CopyContext cc) {
359             super(that, cc);
360         }
361 
362         @Override
363         public TruncOp transform(CopyContext cc, OpTransformer ot) {
364             return new TruncOp(this, cc);
365         }
366 
367         TruncOp(TypeElement t, Value a) {
368             super(NAME + nameSuffixFromType(a.type(), false),
369                     t, List.of(a));
370         }
371     }
372 
373     @OpFactory.OpDeclaration(ExpOp.NAME)
374     public static class ExpOp extends ArithMathOp implements Op.Pure {
375         public static final String NAME = "math.exp";
376 
377         public ExpOp(ExternalizedOp def) {
378             super(def);
379         }
380 
381         ExpOp(ExpOp that, CopyContext cc) {
382             super(that, cc);
383         }
384 
385         @Override
386         public ExpOp transform(CopyContext cc, OpTransformer ot) {
387             return new ExpOp(this, cc);
388         }
389 
390         ExpOp(Value a) {
391             super(NAME, a.type(), List.of(a));
392         }
393     }
394 
395     @OpFactory.OpDeclaration(CompareOp.NAME)
396     public static class CompareOp extends ArithMathOp implements Op.Pure {
397         public static final String NAME = "arith.cmp";
398         public static final String ATTRIBUTE_CONSTANT_VALUE = NAME + ".compare";
399 
400         public enum CompareKind {
401             slt
402         }
403 
404         final CompareKind ck;
405 
406         public static CompareOp create(ExternalizedOp def) {
407             CompareKind ck = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE, true,
408                     v -> switch (v) {
409                         case String s -> CompareKind.valueOf(s);
410                         case CompareKind k -> k;
411                         default -> throw new UnsupportedOperationException("Unsupported start value:" + v);
412                     });
413             return new CompareOp(def, ck);
414         }
415 
416         CompareOp(ExternalizedOp def, CompareKind ck) {
417             super(def);
418 
419             this.ck = ck;
420         }
421 
422         CompareOp(CompareOp that, CopyContext cc) {
423             super(that, cc);
424 
425             this.ck = that.ck;
426         }
427 
428         @Override
429         public CompareOp transform(CopyContext cc, OpTransformer ot) {
430             return new CompareOp(this, cc);
431         }
432 
433         CompareOp(CompareKind ck, Value a, Value b) {
434             super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
435 
436             this.ck = ck;
437         }
438 
439         @Override
440         public Map<String, Object> attributes() {
441             HashMap<String, Object> attrs = new HashMap<>(super.attributes());
442             attrs.put("", ck);
443             return attrs;
444         }
445 
446         public CompareKind kind() {
447             return ck;
448         }
449     }
450 
451     static String maxMinSuffixFromType(TypeElement t) {
452         if (t instanceof TensorType tt) {
453             return maxMinSuffixFromType(tt.eType());
454         } else if (t instanceof PtrType pt) {
455             return maxMinSuffixFromType(pt.rType());
456         } else if (t.equals(JavaType.INT)) {
457             return "";
458         } else if (t.equals(JavaType.FLOAT)) {
459             return "imum";
460         } else {
461             throw new UnsupportedOperationException("Unsupported type: " + t);
462         }
463     }
464 
465     static String nameSuffixFromType(TypeElement t, boolean signed) {
466         if (t instanceof TensorType tt) {
467             return nameSuffixFromType(tt.eType(), signed);
468         } else if (t instanceof PtrType pt) {
469             return nameSuffixFromType(pt.rType(), signed);
470         } else if (t.equals(JavaType.INT) || t.equals(JavaType.LONG)) {
471             return (signed ? "s" : "") + "i";
472         } else if (t.equals(JavaType.FLOAT) || t.equals(JavaType.DOUBLE) ||
473                 Float16.FLOAT_16_TYPE.equals(t)) {
474             return "f";
475         } else {
476             throw new UnsupportedOperationException("Unsupported type: " + t);
477         }
478     }
479 
480     public static final OpFactory FACTORY = def -> {
481         return switch (def.name()) {
482             case ConstantOp.NAME -> ConstantOp.create(def);
483             case ExpOp.NAME -> new ExpOp(def);
484             case AddOp.NAME + "i", AddOp.NAME + "f" -> new AddOp(def);
485             case SubOp.NAME + "i", SubOp.NAME + "f" -> new SubOp(def);
486             case MulOp.NAME + "i", MulOp.NAME + "f" -> new MulOp(def);
487             case DivOp.NAME + "si", DivOp.NAME + "f" -> new DivOp(def);
488             case RemOp.NAME + "si", RemOp.NAME + "f" -> new DivOp(def);
489             case AndOp.NAME -> new AndOp(def);
490             case MaxOp.NAME + "si", MaxOp.NAME + "imumf" -> new MaxOp(def);
491             case MinOp.NAME + "si", MinOp.NAME + "imumf" -> new MinOp(def);
492             case TruncOp.NAME + "i", TruncOp.NAME + "f" -> new TruncOp(def);
493             case CompareOp.NAME + "i", CompareOp.NAME + "f" -> CompareOp.create(def);
494             default -> null;
495         };
496     };
497 
498     // Arith
499 
500     public static ConstantOp constant(TypeElement type, Object value) {
501         return new ConstantOp(type, value);
502     }
503 
504     public static MulOp mul(Value a, Value b) {
505         return new MulOp(a, b);
506     }
507 
508     public static AddOp add(Value a, Value b) {
509         return new AddOp(a, b);
510     }
511 
512     public static SubOp sub(Value a, Value b) {
513         return new SubOp(a, b);
514     }
515 
516     public static DivOp div(Value a, Value b) {
517         return new DivOp(a, b);
518     }
519 
520     public static RemOp rem(Value a, Value b) {
521         return new RemOp(a, b);
522     }
523 
524     public static AndOp and(Value a, Value b) {
525         return new AndOp(a, b);
526     }
527 
528     public static MaxOp maximum(Value a, Value b) {
529         return new MaxOp(a, b);
530     }
531 
532     public static MinOp minimum(Value a, Value b) {
533         return new MinOp(a, b);
534     }
535 
536     public static CompareOp cmp(CompareOp.CompareKind ck, Value a, Value b) {
537         return new CompareOp(ck, a, b);
538     }
539 
540     public static TruncOp trunc(TypeElement type, Value a) {
541         return new TruncOp(type, a);
542     }
543 
544     // Math
545 
546     public static ExpOp exp(Value a) {
547         return new ExpOp(a);
548     }
549 }