/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package oracle.code.triton;

import jdk.incubator.code.*;
import jdk.incubator.code.extern.ExternalizableOp;
import jdk.incubator.code.extern.OpFactory;
import jdk.incubator.code.dialect.java.JavaType;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Operations whose names live in the {@code arith.} and {@code math.}
 * namespaces, together with static factory methods for building them and an
 * {@link OpFactory} for re-creating them from their externalized form.
 *
 * <p>For most binary operations the concrete operation name is the class's
 * {@code NAME} constant plus a suffix derived from the type of the first
 * operand (see {@link #nameSuffixFromType} and {@link #maxMinSuffixFromType}).
 */
public class ArithMathOps {

    /**
     * Common base of all operations in this file. Stores the result type
     * explicitly and returns it from {@link #resultType()}.
     */
    static abstract class ArithMathOp extends ExternalizableOp {
        final TypeElement resultType;

        /**
         * Re-creates an operation from its externalized definition; the result
         * type is taken from {@code def}.
         */
        public ArithMathOp(ExternalizedOp def) {
            super(def);

            this.resultType = def.resultType();
        }

        /** Copy constructor used by {@code transform}; keeps the result type of {@code that}. */
        ArithMathOp(ArithMathOp that, CopyContext cc) {
            super(that, cc);

            this.resultType = that.resultType;
        }

        /** Builds a new operation with the given name, result type and operands. */
        ArithMathOp(String name, TypeElement resultType, List<? extends Value> operands) {
            super(name, operands);

            this.resultType = resultType;
        }

        @Override
        public TypeElement resultType() {
            return resultType;
        }
    }

    /** Constant operation ({@code arith.constant}); carries its value in the {@code value} attribute. */
    @OpFactory.OpDeclaration(ConstantOp.NAME)
    public static class ConstantOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.constant";
        public static final String ATTRIBUTE_CONSTANT_VALUE = "value";

        final Object value;

        /**
         * Re-creates a constant from its externalized definition.
         *
         * @throws IllegalArgumentException if {@code def} has any operands
         * @throws UnsupportedOperationException if the attribute value does not
         *         fit the result type (see {@link #processConstantValue})
         */
        public static ConstantOp create(ExternalizedOp def) {
            if (!def.operands().isEmpty()) {
                throw new IllegalArgumentException("Operation must have zero operands");
            }

            Object value = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE, true,
                    v -> processConstantValue(def.resultType(), v));
            return new ConstantOp(def, value);
        }

        /**
         * Normalizes {@code value} to the boxed Java representation matching
         * type {@code t}. Tensor types are unwrapped to their element type.
         *
         * @throws UnsupportedOperationException for an unsupported type/value pair
         */
        static Object processConstantValue(TypeElement t, Object value) {
            if (t.equals(JavaType.BOOLEAN) && value instanceof Boolean) {
                return value;
            } else if (t.equals(JavaType.BYTE) && value instanceof Number n) {
                return n.byteValue();
            } else if (t.equals(JavaType.SHORT) && value instanceof Number n) {
                return n.shortValue();
            } else if (t.equals(JavaType.CHAR) && value instanceof Character) {
                return value;
            } else if (t.equals(JavaType.INT) && value instanceof Number n) {
                return n.intValue();
            } else if (t.equals(JavaType.LONG) && value instanceof Number n) {
                return n.longValue();
            } else if (t.equals(JavaType.FLOAT) && value instanceof Number n) {
                return n.floatValue();
            } else if (t.equals(Float16.FLOAT_16_TYPE) && value instanceof Number n) {
                // represent as a float for now
                return n.floatValue();
            } else if (t.equals(JavaType.DOUBLE) && value instanceof Number n) {
                return n.doubleValue();
            } else if (t instanceof TensorType tt) {
                return processConstantValue(tt.eType(), value);
            }

            throw new UnsupportedOperationException("Unsupported constant type and value: " + t + " " + value);
        }

        ConstantOp(ExternalizedOp def, Object value) {
            super(def);

            this.value = value;
        }

        ConstantOp(ConstantOp that, CopyContext cc) {
            super(that, cc);

            this.value = that.value;
        }

        @Override
        public ConstantOp transform(CopyContext cc, OpTransformer ot) {
            return new ConstantOp(this, cc);
        }

        ConstantOp(TypeElement type, Object value) {
            super(NAME, type, List.of());

            this.value = value;
        }

        /** Returns the inherited attributes plus the constant under {@link #ATTRIBUTE_CONSTANT_VALUE}. */
        @Override
        public Map<String, Object> attributes() {
            HashMap<String, Object> attrs = new HashMap<>(super.attributes());
            attrs.put(ATTRIBUTE_CONSTANT_VALUE, value);
            return attrs;
        }

        public Object value() {
            return value;
        }
    }

    /** Addition; named {@code arith.addi} or {@code arith.addf} depending on the first operand's type. */
    @OpFactory.OpDeclaration(AddOp.NAME)
    public static class AddOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.add";

        public AddOp(ExternalizedOp def) {
            super(def);
        }

        AddOp(AddOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public AddOp transform(CopyContext cc, OpTransformer ot) {
            return new AddOp(this, cc);
        }

        AddOp(Value a, Value b) {
            super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
        }
    }

    /** Subtraction; named {@code arith.subi} or {@code arith.subf} depending on the first operand's type. */
    @OpFactory.OpDeclaration(SubOp.NAME)
    public static class SubOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.sub";

        public SubOp(ExternalizedOp def) {
            super(def);
        }

        SubOp(SubOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public SubOp transform(CopyContext cc, OpTransformer ot) {
            return new SubOp(this, cc);
        }

        SubOp(Value a, Value b) {
            super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
        }
    }

    /** Multiplication; named {@code arith.muli} or {@code arith.mulf} depending on the first operand's type. */
    @OpFactory.OpDeclaration(MulOp.NAME)
    public static class MulOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.mul";

        public MulOp(ExternalizedOp def) {
            super(def);
        }

        MulOp(MulOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public MulOp transform(CopyContext cc, OpTransformer ot) {
            return new MulOp(this, cc);
        }

        MulOp(Value a, Value b) {
            super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b));
        }
    }

    /** Division; named {@code arith.divsi} or {@code arith.divf} depending on the first operand's type. */
    @OpFactory.OpDeclaration(DivOp.NAME)
    public static class DivOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.div";

        public DivOp(ExternalizedOp def) {
            super(def);
        }

        DivOp(DivOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public DivOp transform(CopyContext cc, OpTransformer ot) {
            return new DivOp(this, cc);
        }

        DivOp(Value a, Value b) {
            super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
        }
    }

    /** Remainder; named {@code arith.remsi} or {@code arith.remf} depending on the first operand's type. */
    @OpFactory.OpDeclaration(RemOp.NAME)
    public static class RemOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.rem";

        public RemOp(ExternalizedOp def) {
            super(def);
        }

        RemOp(RemOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public RemOp transform(CopyContext cc, OpTransformer ot) {
            return new RemOp(this, cc);
        }

        RemOp(Value a, Value b) {
            super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b));
        }
    }

    /** And operation; always named {@code arith.andi} (no type-derived suffix). */
    @OpFactory.OpDeclaration(AndOp.NAME)
    public static class AndOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.andi";

        public AndOp(ExternalizedOp def) {
            super(def);
        }

        AndOp(AndOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public AndOp transform(CopyContext cc, OpTransformer ot) {
            return new AndOp(this, cc);
        }

        AndOp(Value a, Value b) {
            super(NAME, a.type(), List.of(a, b));
        }
    }

    /** Maximum; named {@code arith.maxsi} for integer types and {@code arith.maximumf} for floating-point types. */
    @OpFactory.OpDeclaration(MaxOp.NAME)
    public static class MaxOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.max";

        public MaxOp(ExternalizedOp def) {
            super(def);
        }

        MaxOp(MaxOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public MaxOp transform(CopyContext cc, OpTransformer ot) {
            return new MaxOp(this, cc);
        }

        MaxOp(Value a, Value b) {
            super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
                    a.type(), List.of(a, b));
        }
    }

    /** Minimum; named {@code arith.minsi} for integer types and {@code arith.minimumf} for floating-point types. */
    @OpFactory.OpDeclaration(MinOp.NAME)
    public static class MinOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.min";

        public MinOp(ExternalizedOp def) {
            super(def);
        }

        MinOp(MinOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public MinOp transform(CopyContext cc, OpTransformer ot) {
            return new MinOp(this, cc);
        }

        MinOp(Value a, Value b) {
            super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true),
                    a.type(), List.of(a, b));
        }
    }

    /**
     * Truncation to an explicitly supplied result type; named
     * {@code arith.trunci} or {@code arith.truncf} depending on the operand's type.
     */
    // Declared under its own name; it was previously declared under ExpOp.NAME,
    // which collides with the declaration of ExpOp.
    @OpFactory.OpDeclaration(TruncOp.NAME)
    public static class TruncOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.trunc";

        public TruncOp(ExternalizedOp def) {
            super(def);
        }

        TruncOp(TruncOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public TruncOp transform(CopyContext cc, OpTransformer ot) {
            return new TruncOp(this, cc);
        }

        TruncOp(TypeElement t, Value a) {
            super(NAME + nameSuffixFromType(a.type(), false),
                    t, List.of(a));
        }
    }

    /** Exponential ({@code math.exp}); the result type is the operand's type. */
    @OpFactory.OpDeclaration(ExpOp.NAME)
    public static class ExpOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "math.exp";

        public ExpOp(ExternalizedOp def) {
            super(def);
        }

        ExpOp(ExpOp that, CopyContext cc) {
            super(that, cc);
        }

        @Override
        public ExpOp transform(CopyContext cc, OpTransformer ot) {
            return new ExpOp(this, cc);
        }

        ExpOp(Value a) {
            super(NAME, a.type(), List.of(a));
        }
    }

    /**
     * Comparison; named {@code arith.cmpi} or {@code arith.cmpf} depending on
     * the first operand's type. The result type is {@code boolean}, or a tensor
     * of {@code boolean} with the first operand's shape when that operand is a tensor.
     */
    @OpFactory.OpDeclaration(CompareOp.NAME)
    public static class CompareOp extends ArithMathOp implements Op.Pure {
        public static final String NAME = "arith.cmp";
        public static final String ATTRIBUTE_PREDICATE = "predicate";

        // https://mlir.llvm.org/docs/Dialects/ArithOps/#cmpipredicate
        // The ordinal values correspond to the MLIR symbol's values
        // Need to refine when considering comparisons of floating point numbers which is in a different namespace
        public enum CompareKind {
            eq,
            ne,
            slt,
            sle,
            sgt,
            sge,
            ult,
            ule,
            ugt,
            uge
        }

        final CompareKind ck;

        /**
         * Re-creates a comparison from its externalized definition.
         *
         * @throws UnsupportedOperationException if the predicate attribute has
         *         an unsupported value (see {@link #processPredicate})
         */
        public static CompareOp create(ExternalizedOp def) {
            CompareKind ck = def.extractAttributeValue(ATTRIBUTE_PREDICATE, true,
                    v -> processPredicate(v));
            return new CompareOp(def, ck);
        }

        /**
         * Converts a predicate attribute value to a {@link CompareKind}.
         *
         * <p>Accepted forms are a {@code CompareKind}, the name of one as a
         * {@code String}, or a {@code Number} holding its ordinal. The numeric
         * form is what {@link #attributes()} produces, so accepting it lets an
         * operation be re-created from its own attribute map.
         *
         * @throws UnsupportedOperationException for {@code null}, an
         *         out-of-range ordinal, or any other kind of value
         * @throws IllegalArgumentException if a {@code String} does not name a
         *         {@code CompareKind} constant
         */
        static CompareKind processPredicate(Object v) {
            return switch (v) {
                case String s -> CompareKind.valueOf(s);
                case CompareKind k -> k;
                case Number n -> {
                    CompareKind[] kinds = CompareKind.values();
                    long ordinal = n.longValue();
                    if (ordinal < 0 || ordinal >= kinds.length) {
                        throw new UnsupportedOperationException("Unsupported predicate value:" + v);
                    }
                    yield kinds[(int) ordinal];
                }
                case null, default -> throw new UnsupportedOperationException("Unsupported predicate value:" + v);
            };
        }

        CompareOp(ExternalizedOp def, CompareKind ck) {
            super(def);

            this.ck = ck;
        }

        CompareOp(CompareOp that, CopyContext cc) {
            super(that, cc);

            this.ck = that.ck;
        }

        @Override
        public CompareOp transform(CopyContext cc, OpTransformer ot) {
            return new CompareOp(this, cc);
        }

        CompareOp(CompareKind ck, Value a, Value b) {
            // The result type must be computed before delegating to the super constructor.
            TypeElement t;
            if (a.type() instanceof TensorType ot) {
                t = new TensorType(JavaType.BOOLEAN, ot.shape());
            }
            else {
                t = JavaType.BOOLEAN;
            }
            super(NAME + nameSuffixFromType(a.type(), false), t, List.of(a, b));

            this.ck = ck;
        }

        /**
         * Returns the inherited attributes plus the predicate, encoded as the
         * {@code Long} ordinal of the {@link CompareKind}.
         */
        @Override
        public Map<String, Object> attributes() {
            HashMap<String, Object> attrs = new HashMap<>(super.attributes());
            attrs.put(ATTRIBUTE_PREDICATE, Long.valueOf(ck.ordinal()));
            return attrs;
        }

        public CompareKind kind() {
            return ck;
        }
    }

    /**
     * Returns the infix inserted between {@code arith.max}/{@code arith.min}
     * and the type suffix: empty for integer types, {@code "imum"} for
     * floating-point types. Tensor and pointer types are unwrapped first.
     *
     * <p>The set of accepted scalar types matches {@link #nameSuffixFromType},
     * whose result is appended directly after this one.
     *
     * @throws UnsupportedOperationException for any other type
     */
    static String maxMinSuffixFromType(TypeElement t) {
        if (t instanceof TensorType tt) {
            return maxMinSuffixFromType(tt.eType());
        } else if (t instanceof PtrType pt) {
            return maxMinSuffixFromType(pt.rType());
        } else if (t.equals(JavaType.INT) || t.equals(JavaType.LONG)) {
            return "";
        } else if (t.equals(JavaType.FLOAT) || t.equals(JavaType.DOUBLE) ||
                Float16.FLOAT_16_TYPE.equals(t)) {
            return "imum";
        } else {
            throw new UnsupportedOperationException("Unsupported type: " + t);
        }
    }

    /**
     * Returns the operation-name suffix for type {@code t}: {@code "i"} (or
     * {@code "si"} when {@code signed} is true) for {@code int}/{@code long},
     * and {@code "f"} for {@code float}/{@code double}/float16. Tensor and
     * pointer types are unwrapped first.
     *
     * @throws UnsupportedOperationException for any other type
     */
    static String nameSuffixFromType(TypeElement t, boolean signed) {
        if (t instanceof TensorType tt) {
            return nameSuffixFromType(tt.eType(), signed);
        } else if (t instanceof PtrType pt) {
            return nameSuffixFromType(pt.rType(), signed);
        } else if (t.equals(JavaType.INT) || t.equals(JavaType.LONG)) {
            return (signed ? "s" : "") + "i";
        } else if (t.equals(JavaType.FLOAT) || t.equals(JavaType.DOUBLE) ||
                Float16.FLOAT_16_TYPE.equals(t)) {
            return "f";
        } else {
            throw new UnsupportedOperationException("Unsupported type: " + t);
        }
    }

    /**
     * Factory that re-creates the operations of this file from an externalized
     * definition, dispatching on the full (suffixed) operation name. Yields
     * {@code null} for names it does not recognize.
     */
    public static final OpFactory OP_FACTORY = def -> {
        return switch (def.name()) {
            case ConstantOp.NAME -> ConstantOp.create(def);
            case ExpOp.NAME -> new ExpOp(def);
            case AddOp.NAME + "i", AddOp.NAME + "f" -> new AddOp(def);
            case SubOp.NAME + "i", SubOp.NAME + "f" -> new SubOp(def);
            case MulOp.NAME + "i", MulOp.NAME + "f" -> new MulOp(def);
            case DivOp.NAME + "si", DivOp.NAME + "f" -> new DivOp(def);
            // Remainder names must produce a RemOp (previously built a DivOp).
            case RemOp.NAME + "si", RemOp.NAME + "f" -> new RemOp(def);
            case AndOp.NAME -> new AndOp(def);
            case MaxOp.NAME + "si", MaxOp.NAME + "imumf" -> new MaxOp(def);
            case MinOp.NAME + "si", MinOp.NAME + "imumf" -> new MinOp(def);
            case TruncOp.NAME + "i", TruncOp.NAME + "f" -> new TruncOp(def);
            case CompareOp.NAME + "i", CompareOp.NAME + "f" -> CompareOp.create(def);
            default -> null;
        };
    };

    // Arith

    /** Creates a constant of the given type; {@code value} is stored as given. */
    public static ConstantOp constant(TypeElement type, Object value) {
        return new ConstantOp(type, value);
    }

    /** Creates a multiplication of {@code a} and {@code b}, typed as {@code a}. */
    public static MulOp mul(Value a, Value b) {
        return new MulOp(a, b);
    }

    /** Creates an addition of {@code a} and {@code b}, typed as {@code a}. */
    public static AddOp add(Value a, Value b) {
        return new AddOp(a, b);
    }

    /** Creates a subtraction of {@code a} and {@code b}, typed as {@code a}. */
    public static SubOp sub(Value a, Value b) {
        return new SubOp(a, b);
    }

    /** Creates a division of {@code a} and {@code b}, typed as {@code a}. */
    public static DivOp div(Value a, Value b) {
        return new DivOp(a, b);
    }

    /** Creates a remainder of {@code a} and {@code b}, typed as {@code a}. */
    public static RemOp rem(Value a, Value b) {
        return new RemOp(a, b);
    }

    /** Creates an and operation of {@code a} and {@code b}, typed as {@code a}. */
    public static AndOp and(Value a, Value b) {
        return new AndOp(a, b);
    }

    /** Creates a maximum of {@code a} and {@code b}, typed as {@code a}. */
    public static MaxOp maximum(Value a, Value b) {
        return new MaxOp(a, b);
    }

    /** Creates a minimum of {@code a} and {@code b}, typed as {@code a}. */
    public static MinOp minimum(Value a, Value b) {
        return new MinOp(a, b);
    }

    /** Creates a comparison of {@code a} and {@code b} using predicate {@code ck}. */
    public static CompareOp cmp(CompareOp.CompareKind ck, Value a, Value b) {
        return new CompareOp(ck, a, b);
    }

    /** Creates a truncation of {@code a} with result type {@code type}. */
    public static TruncOp trunc(TypeElement type, Value a) {
        return new TruncOp(type, a);
    }

    // Math

    /** Creates an exponential of {@code a}, typed as {@code a}. */
    public static ExpOp exp(Value a) {
        return new ExpOp(a);
    }
}