1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.triton; 27 28 import jdk.incubator.code.*; 29 import jdk.incubator.code.op.*; 30 import jdk.incubator.code.type.JavaType; 31 import java.util.HashMap; 32 import java.util.List; 33 import java.util.Map; 34 35 public class ArithMathOps { 36 37 static abstract class ArithMathOp extends ExternalizableOp { 38 final TypeElement resultType; 39 40 public ArithMathOp(ExternalizedOp def) { 41 super(def); 42 43 this.resultType = def.resultType(); 44 } 45 46 ArithMathOp(ArithMathOp that, CopyContext cc) { 47 super(that, cc); 48 49 this.resultType = that.resultType; 50 } 51 52 ArithMathOp(String name, TypeElement resultType, List<? extends Value> operands) { 53 super(name, operands); 54 55 this.resultType = resultType; 56 } 57 58 @Override 59 public TypeElement resultType() { 60 return resultType; 61 } 62 } 63 64 @OpFactory.OpDeclaration(ConstantOp.NAME) 65 public static class ConstantOp extends ArithMathOp implements Op.Pure { 66 public static final String NAME = "arith.constant"; 67 public static final String ATTRIBUTE_CONSTANT_VALUE = "value"; 68 69 final Object value; 70 71 public static ConstantOp create(ExternalizedOp def) { 72 if (!def.operands().isEmpty()) { 73 throw new IllegalArgumentException("Operation must have zero operands"); 74 } 75 76 Object value = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE,true, 77 v -> processConstantValue(def.resultType(), v)); 78 return new ConstantOp(def, value); 79 } 80 81 static Object processConstantValue(TypeElement t, Object value) { 82 if (t.equals(JavaType.BOOLEAN)) { 83 if (value instanceof String s) { 84 return Boolean.valueOf(s); 85 } else if (value instanceof Boolean) { 86 return value; 87 } 88 } else if (t.equals(JavaType.BYTE)) { 89 if (value instanceof String s) { 90 return Byte.valueOf(s); 91 } else if (value instanceof Number n) { 92 return n.byteValue(); 93 } 94 } else if (t.equals(JavaType.SHORT)) { 95 if (value instanceof String s) { 96 return Short.valueOf(s); 97 } else if (value instanceof Number n) { 98 return n.shortValue(); 99 } 100 } else if (t.equals(JavaType.CHAR)) { 101 if (value instanceof String s) { 102 return s.charAt(0); 103 } else if (value instanceof Character) { 104 return value; 105 } 106 } else if (t.equals(JavaType.INT)) { 107 if (value instanceof String s) { 108 return Integer.valueOf(s); 109 } else if (value instanceof Number n) { 110 return n.intValue(); 111 } 112 } else if (t.equals(JavaType.LONG)) { 113 if (value instanceof String s) { 114 return Long.valueOf(s); 115 } else if (value instanceof Number n) { 116 return n.longValue(); 117 } 118 } else if (t.equals(JavaType.FLOAT)) { 119 if (value instanceof String s) { 120 return Float.valueOf(s); 121 } else if (value instanceof Number n) { 122 return n.floatValue(); 123 } 124 } else if (t.equals(Float16.FLOAT_16_TYPE)) { 125 // represent as a float for now 126 if (value instanceof String s) { 127 return Float.valueOf(s); 128 } else if (value instanceof Number n) { 129 return n.floatValue(); 130 } 131 } else if (t.equals(JavaType.DOUBLE)) { 132 if (value instanceof String s) { 133 return Double.valueOf(s); 134 } else if (value instanceof Number n) { 135 return n.doubleValue(); 136 } 137 } else if (t instanceof TensorType tt) { 138 return processConstantValue(tt.eType(), value); 139 } 140 141 throw new UnsupportedOperationException("Unsupported constant type and value: " + t + " " + value); 142 } 143 144 ConstantOp(ExternalizedOp def, Object value) { 145 super(def); 146 147 this.value = value; 148 } 149 150 ConstantOp(ConstantOp that, CopyContext cc) { 151 super(that, cc); 152 153 this.value = that.value; 154 } 155 156 @Override 157 public ConstantOp transform(CopyContext cc, OpTransformer ot) { 158 return new ConstantOp(this, cc); 159 } 160 161 ConstantOp(TypeElement type, Object value) { 162 super(NAME, type, List.of()); 163 164 this.value = value; 165 } 166 167 @Override 168 public Map<String, Object> attributes() { 169 HashMap<String, Object> attrs = new HashMap<>(super.attributes()); 170 attrs.put(ATTRIBUTE_CONSTANT_VALUE, value); 171 return attrs; 172 } 173 174 public Object value() { 175 return value; 176 } 177 } 178 179 @OpFactory.OpDeclaration(AddOp.NAME) 180 public static class AddOp extends ArithMathOp implements Op.Pure { 181 public static final String NAME = "arith.add"; 182 183 public AddOp(ExternalizedOp def) { 184 super(def); 185 } 186 187 AddOp(AddOp that, CopyContext cc) { 188 super(that, cc); 189 } 190 191 @Override 192 public AddOp transform(CopyContext cc, OpTransformer ot) { 193 return new AddOp(this, cc); 194 } 195 196 AddOp(Value a, Value b) { 197 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 198 } 199 } 200 201 @OpFactory.OpDeclaration(SubOp.NAME) 202 public static class SubOp extends ArithMathOp implements Op.Pure { 203 public static final String NAME = "arith.sub"; 204 205 public SubOp(ExternalizedOp def) { 206 super(def); 207 } 208 209 SubOp(SubOp that, CopyContext cc) { 210 super(that, cc); 211 } 212 213 @Override 214 public SubOp transform(CopyContext cc, OpTransformer ot) { 215 return new SubOp(this, cc); 216 } 217 218 SubOp(Value a, Value b) { 219 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 220 } 221 } 222 223 @OpFactory.OpDeclaration(MulOp.NAME) 224 public static class MulOp extends ArithMathOp implements Op.Pure { 225 public static final String NAME = "arith.mul"; 226 227 public MulOp(ExternalizedOp def) { 228 super(def); 229 } 230 231 MulOp(MulOp that, CopyContext cc) { 232 super(that, cc); 233 } 234 235 @Override 236 public MulOp transform(CopyContext cc, OpTransformer ot) { 237 return new MulOp(this, cc); 238 } 239 240 MulOp(Value a, Value b) { 241 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 242 } 243 } 244 245 @OpFactory.OpDeclaration(DivOp.NAME) 246 public static class DivOp extends ArithMathOp implements Op.Pure { 247 public static final String NAME = "arith.div"; 248 249 public DivOp(ExternalizedOp def) { 250 super(def); 251 } 252 253 DivOp(DivOp that, CopyContext cc) { 254 super(that, cc); 255 } 256 257 @Override 258 public DivOp transform(CopyContext cc, OpTransformer ot) { 259 return new DivOp(this, cc); 260 } 261 262 DivOp(Value a, Value b) { 263 super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b)); 264 } 265 } 266 267 @OpFactory.OpDeclaration(RemOp.NAME) 268 public static class RemOp extends ArithMathOp implements Op.Pure { 269 public static final String NAME = "arith.rem"; 270 271 public RemOp(ExternalizedOp def) { 272 super(def); 273 } 274 275 RemOp(RemOp that, CopyContext cc) { 276 super(that, cc); 277 } 278 279 @Override 280 public RemOp transform(CopyContext cc, OpTransformer ot) { 281 return new RemOp(this, cc); 282 } 283 284 RemOp(Value a, Value b) { 285 super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b)); 286 } 287 } 288 289 @OpFactory.OpDeclaration(AndOp.NAME) 290 public static class AndOp extends ArithMathOp implements Op.Pure { 291 public static final String NAME = "arith.andi"; 292 293 public AndOp(ExternalizedOp def) { 294 super(def); 295 } 296 297 AndOp(AndOp that, CopyContext cc) { 298 super(that, cc); 299 } 300 301 @Override 302 public AndOp transform(CopyContext cc, OpTransformer ot) { 303 return new AndOp(this, cc); 304 } 305 306 AndOp(Value a, Value b) { 307 super(NAME, a.type(), List.of(a, b)); 308 } 309 } 310 311 @OpFactory.OpDeclaration(MaxOp.NAME) 312 public static class MaxOp extends ArithMathOp implements Op.Pure { 313 public static final String NAME = "arith.max"; 314 315 public MaxOp(ExternalizedOp def) { 316 super(def); 317 } 318 319 MaxOp(MaxOp that, CopyContext cc) { 320 super(that, cc); 321 } 322 323 @Override 324 public MaxOp transform(CopyContext cc, OpTransformer ot) { 325 return new MaxOp(this, cc); 326 } 327 328 MaxOp(Value a, Value b) { 329 super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true), 330 a.type(), List.of(a, b)); 331 } 332 } 333 334 @OpFactory.OpDeclaration(MinOp.NAME) 335 public static class MinOp extends ArithMathOp implements Op.Pure { 336 public static final String NAME = "arith.min"; 337 338 public MinOp(ExternalizedOp def) { 339 super(def); 340 } 341 342 MinOp(MinOp that, CopyContext cc) { 343 super(that, cc); 344 } 345 346 @Override 347 public MinOp transform(CopyContext cc, OpTransformer ot) { 348 return new MinOp(this, cc); 349 } 350 351 MinOp(Value a, Value b) { 352 super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true), 353 a.type(), List.of(a, b)); 354 } 355 } 356 357 @OpFactory.OpDeclaration(ExpOp.NAME) 358 public static class TruncOp extends ArithMathOp implements Op.Pure { 359 public static final String NAME = "arith.trunc"; 360 361 public TruncOp(ExternalizedOp def) { 362 super(def); 363 } 364 365 TruncOp(TruncOp that, CopyContext cc) { 366 super(that, cc); 367 } 368 369 @Override 370 public TruncOp transform(CopyContext cc, OpTransformer ot) { 371 return new TruncOp(this, cc); 372 } 373 374 TruncOp(TypeElement t, Value a) { 375 super(NAME + nameSuffixFromType(a.type(), false), 376 t, List.of(a)); 377 } 378 } 379 380 @OpFactory.OpDeclaration(ExpOp.NAME) 381 public static class ExpOp extends ArithMathOp implements Op.Pure { 382 public static final String NAME = "math.exp"; 383 384 public ExpOp(ExternalizedOp def) { 385 super(def); 386 } 387 388 ExpOp(ExpOp that, CopyContext cc) { 389 super(that, cc); 390 } 391 392 @Override 393 public ExpOp transform(CopyContext cc, OpTransformer ot) { 394 return new ExpOp(this, cc); 395 } 396 397 ExpOp(Value a) { 398 super(NAME, a.type(), List.of(a)); 399 } 400 } 401 402 @OpFactory.OpDeclaration(CompareOp.NAME) 403 public static class CompareOp extends ArithMathOp implements Op.Pure { 404 public static final String NAME = "arith.cmp"; 405 public static final String ATTRIBUTE_PREDICATE = "predicate"; 406 407 // https://mlir.llvm.org/docs/Dialects/ArithOps/#cmpipredicate 408 // The ordinal values correspond to the MLIR symbol's values 409 // Need to refine when considering comparisons of floating point numbers which is in a different namespace 410 public enum CompareKind { 411 eq, 412 ne, 413 slt, 414 sle, 415 sgt, 416 sge, 417 ult, 418 ule, 419 ugt, 420 uge 421 } 422 423 final CompareKind ck; 424 425 public static CompareOp create(ExternalizedOp def) { 426 CompareKind ck = def.extractAttributeValue(ATTRIBUTE_PREDICATE, true, 427 v -> switch (v) { 428 case String s -> CompareKind.valueOf(s); 429 case CompareKind k -> k; 430 case null, default -> throw new UnsupportedOperationException("Unsupported start value:" + v); 431 }); 432 return new CompareOp(def, ck); 433 } 434 435 CompareOp(ExternalizedOp def, CompareKind ck) { 436 super(def); 437 438 this.ck = ck; 439 } 440 441 CompareOp(CompareOp that, CopyContext cc) { 442 super(that, cc); 443 444 this.ck = that.ck; 445 } 446 447 @Override 448 public CompareOp transform(CopyContext cc, OpTransformer ot) { 449 return new CompareOp(this, cc); 450 } 451 452 CompareOp(CompareKind ck, Value a, Value b) { 453 TypeElement t; 454 if (a.type() instanceof TensorType ot) { 455 t = new TensorType(JavaType.BOOLEAN, ot.shape()); 456 } 457 else { 458 t = JavaType.BOOLEAN; 459 } 460 super(NAME + nameSuffixFromType(a.type(), false), t, List.of(a, b)); 461 462 this.ck = ck; 463 } 464 465 @Override 466 public Map<String, Object> attributes() { 467 HashMap<String, Object> attrs = new HashMap<>(super.attributes()); 468 attrs.put(ATTRIBUTE_PREDICATE, Long.valueOf(ck.ordinal())); 469 return attrs; 470 } 471 472 public CompareKind kind() { 473 return ck; 474 } 475 } 476 477 static String maxMinSuffixFromType(TypeElement t) { 478 if (t instanceof TensorType tt) { 479 return maxMinSuffixFromType(tt.eType()); 480 } else if (t instanceof PtrType pt) { 481 return maxMinSuffixFromType(pt.rType()); 482 } else if (t.equals(JavaType.INT)) { 483 return ""; 484 } else if (t.equals(JavaType.FLOAT)) { 485 return "imum"; 486 } else { 487 throw new UnsupportedOperationException("Unsupported type: " + t); 488 } 489 } 490 491 static String nameSuffixFromType(TypeElement t, boolean signed) { 492 if (t instanceof TensorType tt) { 493 return nameSuffixFromType(tt.eType(), signed); 494 } else if (t instanceof PtrType pt) { 495 return nameSuffixFromType(pt.rType(), signed); 496 } else if (t.equals(JavaType.INT) || t.equals(JavaType.LONG)) { 497 return (signed ? "s" : "") + "i"; 498 } else if (t.equals(JavaType.FLOAT) || t.equals(JavaType.DOUBLE) || 499 Float16.FLOAT_16_TYPE.equals(t)) { 500 return "f"; 501 } else { 502 throw new UnsupportedOperationException("Unsupported type: " + t); 503 } 504 } 505 506 public static final OpFactory FACTORY = def -> { 507 return switch (def.name()) { 508 case ConstantOp.NAME -> ConstantOp.create(def); 509 case ExpOp.NAME -> new ExpOp(def); 510 case AddOp.NAME + "i", AddOp.NAME + "f" -> new AddOp(def); 511 case SubOp.NAME + "i", SubOp.NAME + "f" -> new SubOp(def); 512 case MulOp.NAME + "i", MulOp.NAME + "f" -> new MulOp(def); 513 case DivOp.NAME + "si", DivOp.NAME + "f" -> new DivOp(def); 514 case RemOp.NAME + "si", RemOp.NAME + "f" -> new DivOp(def); 515 case AndOp.NAME -> new AndOp(def); 516 case MaxOp.NAME + "si", MaxOp.NAME + "imumf" -> new MaxOp(def); 517 case MinOp.NAME + "si", MinOp.NAME + "imumf" -> new MinOp(def); 518 case TruncOp.NAME + "i", TruncOp.NAME + "f" -> new TruncOp(def); 519 case CompareOp.NAME + "i", CompareOp.NAME + "f" -> CompareOp.create(def); 520 default -> null; 521 }; 522 }; 523 524 // Arith 525 526 public static ConstantOp constant(TypeElement type, Object value) { 527 return new ConstantOp(type, value); 528 } 529 530 public static MulOp mul(Value a, Value b) { 531 return new MulOp(a, b); 532 } 533 534 public static AddOp add(Value a, Value b) { 535 return new AddOp(a, b); 536 } 537 538 public static SubOp sub(Value a, Value b) { 539 return new SubOp(a, b); 540 } 541 542 public static DivOp div(Value a, Value b) { 543 return new DivOp(a, b); 544 } 545 546 public static RemOp rem(Value a, Value b) { 547 return new RemOp(a, b); 548 } 549 550 public static AndOp and(Value a, Value b) { 551 return new AndOp(a, b); 552 } 553 554 public static MaxOp maximum(Value a, Value b) { 555 return new MaxOp(a, b); 556 } 557 558 public static MinOp minimum(Value a, Value b) { 559 return new MinOp(a, b); 560 } 561 562 public static CompareOp cmp(CompareOp.CompareKind ck, Value a, Value b) { 563 return new CompareOp(ck, a, b); 564 } 565 566 public static TruncOp trunc(TypeElement type, Value a) { 567 return new TruncOp(type, a); 568 } 569 570 // Math 571 572 public static ExpOp exp(Value a) { 573 return new ExpOp(a); 574 } 575 }