1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.triton; 27 28 import java.lang.reflect.code.*; 29 import java.lang.reflect.code.op.*; 30 import java.lang.reflect.code.type.JavaType; 31 import java.util.HashMap; 32 import java.util.List; 33 import java.util.Map; 34 35 public class ArithMathOps { 36 37 static abstract class ArithMathOp extends ExternalizableOp { 38 final TypeElement resultType; 39 40 public ArithMathOp(ExternalizedOp def) { 41 super(def); 42 43 this.resultType = def.resultType(); 44 } 45 46 ArithMathOp(ArithMathOp that, CopyContext cc) { 47 super(that, cc); 48 49 this.resultType = that.resultType; 50 } 51 52 ArithMathOp(String name, TypeElement resultType, List<? extends Value> operands) { 53 super(name, operands); 54 55 this.resultType = resultType; 56 } 57 58 @Override 59 public TypeElement resultType() { 60 return resultType; 61 } 62 } 63 64 @OpFactory.OpDeclaration(ConstantOp.NAME) 65 public static class ConstantOp extends ArithMathOp implements Op.Pure { 66 public static final String NAME = "arith.constant"; 67 public static final String ATTRIBUTE_CONSTANT_VALUE = NAME + ".value"; 68 69 final Object value; 70 71 public static ConstantOp create(ExternalizedOp def) { 72 if (!def.operands().isEmpty()) { 73 throw new IllegalArgumentException("Operation must have zero operands"); 74 } 75 76 Object value = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE,true, 77 v -> processConstantValue(def.resultType(), v)); 78 return new ConstantOp(def, value); 79 } 80 81 static Object processConstantValue(TypeElement t, Object value) { 82 if (t.equals(JavaType.BOOLEAN)) { 83 if (value instanceof String s) { 84 return Boolean.valueOf(s); 85 } else if (value instanceof Boolean) { 86 return value; 87 } 88 } else if (t.equals(JavaType.BYTE)) { 89 if (value instanceof String s) { 90 return Byte.valueOf(s); 91 } else if (value instanceof Number n) { 92 return n.byteValue(); 93 } 94 } else if (t.equals(JavaType.SHORT)) { 95 if (value instanceof String s) { 96 return Short.valueOf(s); 97 } else if (value instanceof Number n) { 98 return n.shortValue(); 99 } 100 } else if (t.equals(JavaType.CHAR)) { 101 if (value instanceof String s) { 102 return s.charAt(0); 103 } else if (value instanceof Character) { 104 return value; 105 } 106 } else if (t.equals(JavaType.INT)) { 107 if (value instanceof String s) { 108 return Integer.valueOf(s); 109 } else if (value instanceof Number n) { 110 return n.intValue(); 111 } 112 } else if (t.equals(JavaType.LONG)) { 113 if (value instanceof String s) { 114 return Long.valueOf(s); 115 } else if (value instanceof Number n) { 116 return n.longValue(); 117 } 118 } else if (t.equals(JavaType.FLOAT)) { 119 if (value instanceof String s) { 120 return Float.valueOf(s); 121 } else if (value instanceof Number n) { 122 return n.floatValue(); 123 } 124 } else if (t.equals(JavaType.DOUBLE)) { 125 if (value instanceof String s) { 126 return Double.valueOf(s); 127 } else if (value instanceof Number n) { 128 return n.doubleValue(); 129 } 130 } else if (t instanceof TensorType tt) { 131 return processConstantValue(tt.eType(), value); 132 } 133 134 throw new UnsupportedOperationException("Unsupported constant type and value: " + t + " " + value); 135 } 136 137 ConstantOp(ExternalizedOp def, Object value) { 138 super(def); 139 140 this.value = value; 141 } 142 143 ConstantOp(ConstantOp that, CopyContext cc) { 144 super(that, cc); 145 146 this.value = that.value; 147 } 148 149 @Override 150 public ConstantOp transform(CopyContext cc, OpTransformer ot) { 151 return new ConstantOp(this, cc); 152 } 153 154 ConstantOp(TypeElement type, Object value) { 155 super(NAME, type, List.of()); 156 157 this.value = value; 158 } 159 160 @Override 161 public Map<String, Object> attributes() { 162 HashMap<String, Object> attrs = new HashMap<>(super.attributes()); 163 attrs.put("", value); 164 return attrs; 165 } 166 167 public Object value() { 168 return value; 169 } 170 } 171 172 @OpFactory.OpDeclaration(AddOp.NAME) 173 public static class AddOp extends ArithMathOp implements Op.Pure { 174 public static final String NAME = "arith.add"; 175 176 public AddOp(ExternalizedOp def) { 177 super(def); 178 } 179 180 AddOp(AddOp that, CopyContext cc) { 181 super(that, cc); 182 } 183 184 @Override 185 public AddOp transform(CopyContext cc, OpTransformer ot) { 186 return new AddOp(this, cc); 187 } 188 189 AddOp(Value a, Value b) { 190 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 191 } 192 } 193 194 @OpFactory.OpDeclaration(SubOp.NAME) 195 public static class SubOp extends ArithMathOp implements Op.Pure { 196 public static final String NAME = "arith.sub"; 197 198 public SubOp(ExternalizedOp def) { 199 super(def); 200 } 201 202 SubOp(SubOp that, CopyContext cc) { 203 super(that, cc); 204 } 205 206 @Override 207 public SubOp transform(CopyContext cc, OpTransformer ot) { 208 return new SubOp(this, cc); 209 } 210 211 SubOp(Value a, Value b) { 212 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 213 } 214 } 215 216 @OpFactory.OpDeclaration(MulOp.NAME) 217 public static class MulOp extends ArithMathOp implements Op.Pure { 218 public static final String NAME = "arith.mul"; 219 220 public MulOp(ExternalizedOp def) { 221 super(def); 222 } 223 224 MulOp(MulOp that, CopyContext cc) { 225 super(that, cc); 226 } 227 228 @Override 229 public MulOp transform(CopyContext cc, OpTransformer ot) { 230 return new MulOp(this, cc); 231 } 232 233 MulOp(Value a, Value b) { 234 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 235 } 236 } 237 238 @OpFactory.OpDeclaration(DivOp.NAME) 239 public static class DivOp extends ArithMathOp implements Op.Pure { 240 public static final String NAME = "arith.div"; 241 242 public DivOp(ExternalizedOp def) { 243 super(def); 244 } 245 246 DivOp(DivOp that, CopyContext cc) { 247 super(that, cc); 248 } 249 250 @Override 251 public DivOp transform(CopyContext cc, OpTransformer ot) { 252 return new DivOp(this, cc); 253 } 254 255 DivOp(Value a, Value b) { 256 super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b)); 257 } 258 } 259 260 @OpFactory.OpDeclaration(RemOp.NAME) 261 public static class RemOp extends ArithMathOp implements Op.Pure { 262 public static final String NAME = "arith.rem"; 263 264 public RemOp(ExternalizedOp def) { 265 super(def); 266 } 267 268 RemOp(RemOp that, CopyContext cc) { 269 super(that, cc); 270 } 271 272 @Override 273 public RemOp transform(CopyContext cc, OpTransformer ot) { 274 return new RemOp(this, cc); 275 } 276 277 RemOp(Value a, Value b) { 278 super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b)); 279 } 280 } 281 282 @OpFactory.OpDeclaration(AndOp.NAME) 283 public static class AndOp extends ArithMathOp implements Op.Pure { 284 public static final String NAME = "arith.andi"; 285 286 public AndOp(ExternalizedOp def) { 287 super(def); 288 } 289 290 AndOp(AndOp that, CopyContext cc) { 291 super(that, cc); 292 } 293 294 @Override 295 public AndOp transform(CopyContext cc, OpTransformer ot) { 296 return new AndOp(this, cc); 297 } 298 299 AndOp(Value a, Value b) { 300 super(NAME, a.type(), List.of(a, b)); 301 } 302 } 303 304 @OpFactory.OpDeclaration(MaxOp.NAME) 305 public static class MaxOp extends ArithMathOp implements Op.Pure { 306 public static final String NAME = "arith.max"; 307 308 public MaxOp(ExternalizedOp def) { 309 super(def); 310 } 311 312 MaxOp(MaxOp that, CopyContext cc) { 313 super(that, cc); 314 } 315 316 @Override 317 public MaxOp transform(CopyContext cc, OpTransformer ot) { 318 return new MaxOp(this, cc); 319 } 320 321 MaxOp(Value a, Value b) { 322 super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true), 323 a.type(), List.of(a, b)); 324 } 325 } 326 327 @OpFactory.OpDeclaration(MinOp.NAME) 328 public static class MinOp extends ArithMathOp implements Op.Pure { 329 public static final String NAME = "arith.min"; 330 331 public MinOp(ExternalizedOp def) { 332 super(def); 333 } 334 335 MinOp(MinOp that, CopyContext cc) { 336 super(that, cc); 337 } 338 339 @Override 340 public MinOp transform(CopyContext cc, OpTransformer ot) { 341 return new MinOp(this, cc); 342 } 343 344 MinOp(Value a, Value b) { 345 super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true), 346 a.type(), List.of(a, b)); 347 } 348 } 349 350 @OpFactory.OpDeclaration(ExpOp.NAME) 351 public static class TruncOp extends ArithMathOp implements Op.Pure { 352 public static final String NAME = "arith.trunc"; 353 354 public TruncOp(ExternalizedOp def) { 355 super(def); 356 } 357 358 TruncOp(TruncOp that, CopyContext cc) { 359 super(that, cc); 360 } 361 362 @Override 363 public TruncOp transform(CopyContext cc, OpTransformer ot) { 364 return new TruncOp(this, cc); 365 } 366 367 TruncOp(TypeElement t, Value a) { 368 super(NAME + nameSuffixFromType(a.type(), false), 369 t, List.of(a)); 370 } 371 } 372 373 @OpFactory.OpDeclaration(ExpOp.NAME) 374 public static class ExpOp extends ArithMathOp implements Op.Pure { 375 public static final String NAME = "math.exp"; 376 377 public ExpOp(ExternalizedOp def) { 378 super(def); 379 } 380 381 ExpOp(ExpOp that, CopyContext cc) { 382 super(that, cc); 383 } 384 385 @Override 386 public ExpOp transform(CopyContext cc, OpTransformer ot) { 387 return new ExpOp(this, cc); 388 } 389 390 ExpOp(Value a) { 391 super(NAME, a.type(), List.of(a)); 392 } 393 } 394 395 @OpFactory.OpDeclaration(CompareOp.NAME) 396 public static class CompareOp extends ArithMathOp implements Op.Pure { 397 public static final String NAME = "arith.cmp"; 398 public static final String ATTRIBUTE_CONSTANT_VALUE = NAME + ".compare"; 399 400 public enum CompareKind { 401 slt 402 } 403 404 final CompareKind ck; 405 406 public static CompareOp create(ExternalizedOp def) { 407 CompareKind ck = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE, true, 408 v -> switch (v) { 409 case String s -> CompareKind.valueOf(s); 410 case CompareKind k -> k; 411 default -> throw new UnsupportedOperationException("Unsupported start value:" + v); 412 }); 413 return new CompareOp(def, ck); 414 } 415 416 CompareOp(ExternalizedOp def, CompareKind ck) { 417 super(def); 418 419 this.ck = ck; 420 } 421 422 CompareOp(CompareOp that, CopyContext cc) { 423 super(that, cc); 424 425 this.ck = that.ck; 426 } 427 428 @Override 429 public CompareOp transform(CopyContext cc, OpTransformer ot) { 430 return new CompareOp(this, cc); 431 } 432 433 CompareOp(CompareKind ck, Value a, Value b) { 434 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 435 436 this.ck = ck; 437 } 438 439 @Override 440 public Map<String, Object> attributes() { 441 HashMap<String, Object> attrs = new HashMap<>(super.attributes()); 442 attrs.put("", ck); 443 return attrs; 444 } 445 446 public CompareKind kind() { 447 return ck; 448 } 449 } 450 451 static String maxMinSuffixFromType(TypeElement t) { 452 if (t instanceof TensorType tt) { 453 return maxMinSuffixFromType(tt.eType()); 454 } else if (t instanceof PtrType pt) { 455 return maxMinSuffixFromType(pt.rType()); 456 } else if (t.equals(JavaType.INT)) { 457 return ""; 458 } else if (t.equals(JavaType.FLOAT)) { 459 return "imum"; 460 } else { 461 throw new UnsupportedOperationException("Unsupported type: " + t); 462 } 463 } 464 465 static String nameSuffixFromType(TypeElement t, boolean signed) { 466 if (t instanceof TensorType tt) { 467 return nameSuffixFromType(tt.eType(), signed); 468 } else if (t instanceof PtrType pt) { 469 return nameSuffixFromType(pt.rType(), signed); 470 } else if (t.equals(JavaType.INT) || t.equals(JavaType.LONG)) { 471 return (signed ? "s" : "") + "i"; 472 } else if (t.equals(JavaType.FLOAT) || t.equals(JavaType.DOUBLE) || 473 Float16.FLOAT_16_TYPE.equals(t)) { 474 return "f"; 475 } else { 476 throw new UnsupportedOperationException("Unsupported type: " + t); 477 } 478 } 479 480 public static final OpFactory FACTORY = def -> { 481 return switch (def.name()) { 482 case ConstantOp.NAME -> ConstantOp.create(def); 483 case ExpOp.NAME -> new ExpOp(def); 484 case AddOp.NAME + "i", AddOp.NAME + "f" -> new AddOp(def); 485 case SubOp.NAME + "i", SubOp.NAME + "f" -> new SubOp(def); 486 case MulOp.NAME + "i", MulOp.NAME + "f" -> new MulOp(def); 487 case DivOp.NAME + "si", DivOp.NAME + "f" -> new DivOp(def); 488 case RemOp.NAME + "si", RemOp.NAME + "f" -> new DivOp(def); 489 case AndOp.NAME -> new AndOp(def); 490 case MaxOp.NAME + "si", MaxOp.NAME + "imumf" -> new MaxOp(def); 491 case MinOp.NAME + "si", MinOp.NAME + "imumf" -> new MinOp(def); 492 case TruncOp.NAME + "i", TruncOp.NAME + "f" -> new TruncOp(def); 493 case CompareOp.NAME + "i", CompareOp.NAME + "f" -> CompareOp.create(def); 494 default -> null; 495 }; 496 }; 497 498 // Arith 499 500 public static ConstantOp constant(TypeElement type, Object value) { 501 return new ConstantOp(type, value); 502 } 503 504 public static MulOp mul(Value a, Value b) { 505 return new MulOp(a, b); 506 } 507 508 public static AddOp add(Value a, Value b) { 509 return new AddOp(a, b); 510 } 511 512 public static SubOp sub(Value a, Value b) { 513 return new SubOp(a, b); 514 } 515 516 public static DivOp div(Value a, Value b) { 517 return new DivOp(a, b); 518 } 519 520 public static RemOp rem(Value a, Value b) { 521 return new RemOp(a, b); 522 } 523 524 public static AndOp and(Value a, Value b) { 525 return new AndOp(a, b); 526 } 527 528 public static MaxOp maximum(Value a, Value b) { 529 return new MaxOp(a, b); 530 } 531 532 public static MinOp minimum(Value a, Value b) { 533 return new MinOp(a, b); 534 } 535 536 public static CompareOp cmp(CompareOp.CompareKind ck, Value a, Value b) { 537 return new CompareOp(ck, a, b); 538 } 539 540 public static TruncOp trunc(TypeElement type, Value a) { 541 return new TruncOp(type, a); 542 } 543 544 // Math 545 546 public static ExpOp exp(Value a) { 547 return new ExpOp(a); 548 } 549 }