1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.triton; 27 28 import jdk.incubator.code.*; 29 import jdk.incubator.code.extern.ExternalizedOp; 30 import jdk.incubator.code.extern.OpFactory; 31 import jdk.incubator.code.dialect.java.JavaType; 32 33 import java.util.HashMap; 34 import java.util.List; 35 import java.util.Map; 36 37 public class ArithMathOps { 38 39 static abstract class ArithMathOp extends Op { 40 final TypeElement resultType; 41 42 public ArithMathOp(ExternalizedOp def) { 43 super(def.name(), def.operands());; 44 45 this.resultType = def.resultType(); 46 } 47 48 ArithMathOp(ArithMathOp that, CopyContext cc) { 49 super(that, cc); 50 51 this.resultType = that.resultType; 52 } 53 54 ArithMathOp(String name, TypeElement resultType, List<? extends Value> operands) { 55 super(name, operands); 56 57 this.resultType = resultType; 58 } 59 60 @Override 61 public TypeElement resultType() { 62 return resultType; 63 } 64 } 65 66 @OpFactoryHelper.OpDeclaration(ConstantOp.NAME) 67 public static class ConstantOp extends ArithMathOp implements Op.Pure { 68 public static final String NAME = "arith.constant"; 69 public static final String ATTRIBUTE_CONSTANT_VALUE = "value"; 70 71 final Object value; 72 73 public static ConstantOp create(ExternalizedOp def) { 74 if (!def.operands().isEmpty()) { 75 throw new IllegalArgumentException("Operation must have zero operands"); 76 } 77 78 Object value = def.extractAttributeValue(ATTRIBUTE_CONSTANT_VALUE,true, 79 v -> processConstantValue(def.resultType(), v)); 80 return new ConstantOp(def, value); 81 } 82 83 static Object processConstantValue(TypeElement t, Object value) { 84 if (t.equals(JavaType.BOOLEAN) && value instanceof Boolean) { 85 return value; 86 } else if (t.equals(JavaType.BYTE) && value instanceof Number n) { 87 return n.byteValue(); 88 } else if (t.equals(JavaType.SHORT) && value instanceof Number n) { 89 return n.shortValue(); 90 } else if (t.equals(JavaType.CHAR) && value instanceof Character) { 91 return value; 92 } else if (t.equals(JavaType.INT) && value instanceof Number n) { 93 return n.intValue(); 94 } else if (t.equals(JavaType.LONG) && value instanceof Number n) { 95 return n.longValue(); 96 } else if (t.equals(JavaType.FLOAT) && value instanceof Number n) { 97 return n.floatValue(); 98 } else if (t.equals(Float16.FLOAT_16_TYPE) && value instanceof Number n) { 99 // represent as a float for now 100 return n.floatValue(); 101 } else if (t.equals(JavaType.DOUBLE) && value instanceof Number n) { 102 return n.doubleValue(); 103 } else if (t instanceof TensorType tt) { 104 return processConstantValue(tt.eType(), value); 105 } 106 107 throw new UnsupportedOperationException("Unsupported constant type and value: " + t + " " + value); 108 } 109 110 ConstantOp(ExternalizedOp def, Object value) { 111 super(def); 112 113 this.value = value; 114 } 115 116 ConstantOp(ConstantOp that, CopyContext cc) { 117 super(that, cc); 118 119 this.value = that.value; 120 } 121 122 @Override 123 public ConstantOp transform(CopyContext cc, OpTransformer ot) { 124 return new ConstantOp(this, cc); 125 } 126 127 ConstantOp(TypeElement type, Object value) { 128 super(NAME, type, List.of()); 129 130 this.value = value; 131 } 132 133 @Override 134 public Map<String, Object> externalize() { 135 return Map.of(ATTRIBUTE_CONSTANT_VALUE, value); 136 } 137 138 public Object value() { 139 return value; 140 } 141 } 142 143 @OpFactoryHelper.OpDeclaration(AddOp.NAME) 144 public static class AddOp extends ArithMathOp implements Op.Pure { 145 public static final String NAME = "arith.add"; 146 147 public AddOp(ExternalizedOp def) { 148 super(def); 149 } 150 151 AddOp(AddOp that, CopyContext cc) { 152 super(that, cc); 153 } 154 155 @Override 156 public AddOp transform(CopyContext cc, OpTransformer ot) { 157 return new AddOp(this, cc); 158 } 159 160 AddOp(Value a, Value b) { 161 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 162 } 163 } 164 165 @OpFactoryHelper.OpDeclaration(SubOp.NAME) 166 public static class SubOp extends ArithMathOp implements Op.Pure { 167 public static final String NAME = "arith.sub"; 168 169 public SubOp(ExternalizedOp def) { 170 super(def); 171 } 172 173 SubOp(SubOp that, CopyContext cc) { 174 super(that, cc); 175 } 176 177 @Override 178 public SubOp transform(CopyContext cc, OpTransformer ot) { 179 return new SubOp(this, cc); 180 } 181 182 SubOp(Value a, Value b) { 183 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 184 } 185 } 186 187 @OpFactoryHelper.OpDeclaration(MulOp.NAME) 188 public static class MulOp extends ArithMathOp implements Op.Pure { 189 public static final String NAME = "arith.mul"; 190 191 public MulOp(ExternalizedOp def) { 192 super(def); 193 } 194 195 MulOp(MulOp that, CopyContext cc) { 196 super(that, cc); 197 } 198 199 @Override 200 public MulOp transform(CopyContext cc, OpTransformer ot) { 201 return new MulOp(this, cc); 202 } 203 204 MulOp(Value a, Value b) { 205 super(NAME + nameSuffixFromType(a.type(), false), a.type(), List.of(a, b)); 206 } 207 } 208 209 @OpFactoryHelper.OpDeclaration(DivOp.NAME) 210 public static class DivOp extends ArithMathOp implements Op.Pure { 211 public static final String NAME = "arith.div"; 212 213 public DivOp(ExternalizedOp def) { 214 super(def); 215 } 216 217 DivOp(DivOp that, CopyContext cc) { 218 super(that, cc); 219 } 220 221 @Override 222 public DivOp transform(CopyContext cc, OpTransformer ot) { 223 return new DivOp(this, cc); 224 } 225 226 DivOp(Value a, Value b) { 227 super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b)); 228 } 229 } 230 231 @OpFactoryHelper.OpDeclaration(RemOp.NAME) 232 public static class RemOp extends ArithMathOp implements Op.Pure { 233 public static final String NAME = "arith.rem"; 234 235 public RemOp(ExternalizedOp def) { 236 super(def); 237 } 238 239 RemOp(RemOp that, CopyContext cc) { 240 super(that, cc); 241 } 242 243 @Override 244 public RemOp transform(CopyContext cc, OpTransformer ot) { 245 return new RemOp(this, cc); 246 } 247 248 RemOp(Value a, Value b) { 249 super(NAME + nameSuffixFromType(a.type(), true), a.type(), List.of(a, b)); 250 } 251 } 252 253 @OpFactoryHelper.OpDeclaration(AndOp.NAME) 254 public static class AndOp extends ArithMathOp implements Op.Pure { 255 public static final String NAME = "arith.andi"; 256 257 public AndOp(ExternalizedOp def) { 258 super(def); 259 } 260 261 AndOp(AndOp that, CopyContext cc) { 262 super(that, cc); 263 } 264 265 @Override 266 public AndOp transform(CopyContext cc, OpTransformer ot) { 267 return new AndOp(this, cc); 268 } 269 270 AndOp(Value a, Value b) { 271 super(NAME, a.type(), List.of(a, b)); 272 } 273 } 274 275 @OpFactoryHelper.OpDeclaration(MaxOp.NAME) 276 public static class MaxOp extends ArithMathOp implements Op.Pure { 277 public static final String NAME = "arith.max"; 278 279 public MaxOp(ExternalizedOp def) { 280 super(def); 281 } 282 283 MaxOp(MaxOp that, CopyContext cc) { 284 super(that, cc); 285 } 286 287 @Override 288 public MaxOp transform(CopyContext cc, OpTransformer ot) { 289 return new MaxOp(this, cc); 290 } 291 292 MaxOp(Value a, Value b) { 293 super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true), 294 a.type(), List.of(a, b)); 295 } 296 } 297 298 @OpFactoryHelper.OpDeclaration(MinOp.NAME) 299 public static class MinOp extends ArithMathOp implements Op.Pure { 300 public static final String NAME = "arith.min"; 301 302 public MinOp(ExternalizedOp def) { 303 super(def); 304 } 305 306 MinOp(MinOp that, CopyContext cc) { 307 super(that, cc); 308 } 309 310 @Override 311 public MinOp transform(CopyContext cc, OpTransformer ot) { 312 return new MinOp(this, cc); 313 } 314 315 MinOp(Value a, Value b) { 316 super(NAME + maxMinSuffixFromType(a.type()) + nameSuffixFromType(a.type(), true), 317 a.type(), List.of(a, b)); 318 } 319 } 320 321 @OpFactoryHelper.OpDeclaration(ExpOp.NAME) 322 public static class TruncOp extends ArithMathOp implements Op.Pure { 323 public static final String NAME = "arith.trunc"; 324 325 public TruncOp(ExternalizedOp def) { 326 super(def); 327 } 328 329 TruncOp(TruncOp that, CopyContext cc) { 330 super(that, cc); 331 } 332 333 @Override 334 public TruncOp transform(CopyContext cc, OpTransformer ot) { 335 return new TruncOp(this, cc); 336 } 337 338 TruncOp(TypeElement t, Value a) { 339 super(NAME + nameSuffixFromType(a.type(), false), 340 t, List.of(a)); 341 } 342 } 343 344 @OpFactoryHelper.OpDeclaration(ExpOp.NAME) 345 public static class ExpOp extends ArithMathOp implements Op.Pure { 346 public static final String NAME = "math.exp"; 347 348 public ExpOp(ExternalizedOp def) { 349 super(def); 350 } 351 352 ExpOp(ExpOp that, CopyContext cc) { 353 super(that, cc); 354 } 355 356 @Override 357 public ExpOp transform(CopyContext cc, OpTransformer ot) { 358 return new ExpOp(this, cc); 359 } 360 361 ExpOp(Value a) { 362 super(NAME, a.type(), List.of(a)); 363 } 364 } 365 366 @OpFactoryHelper.OpDeclaration(CompareOp.NAME) 367 public static class CompareOp extends ArithMathOp implements Op.Pure { 368 public static final String NAME = "arith.cmp"; 369 public static final String ATTRIBUTE_PREDICATE = "predicate"; 370 371 // https://mlir.llvm.org/docs/Dialects/ArithOps/#cmpipredicate 372 // The ordinal values correspond to the MLIR symbol's values 373 // Need to refine when considering comparisons of floating point numbers which is in a different namespace 374 public enum CompareKind { 375 eq, 376 ne, 377 slt, 378 sle, 379 sgt, 380 sge, 381 ult, 382 ule, 383 ugt, 384 uge 385 } 386 387 final CompareKind ck; 388 389 public static CompareOp create(ExternalizedOp def) { 390 CompareKind ck = def.extractAttributeValue(ATTRIBUTE_PREDICATE, true, 391 v -> switch (v) { 392 case String s -> CompareKind.valueOf(s); 393 case CompareKind k -> k; 394 case null, default -> throw new UnsupportedOperationException("Unsupported start value:" + v); 395 }); 396 return new CompareOp(def, ck); 397 } 398 399 CompareOp(ExternalizedOp def, CompareKind ck) { 400 super(def); 401 402 this.ck = ck; 403 } 404 405 CompareOp(CompareOp that, CopyContext cc) { 406 super(that, cc); 407 408 this.ck = that.ck; 409 } 410 411 @Override 412 public CompareOp transform(CopyContext cc, OpTransformer ot) { 413 return new CompareOp(this, cc); 414 } 415 416 CompareOp(CompareKind ck, Value a, Value b) { 417 TypeElement t; 418 if (a.type() instanceof TensorType ot) { 419 t = new TensorType(JavaType.BOOLEAN, ot.shape()); 420 } 421 else { 422 t = JavaType.BOOLEAN; 423 } 424 super(NAME + nameSuffixFromType(a.type(), false), t, List.of(a, b)); 425 426 this.ck = ck; 427 } 428 429 @Override 430 public Map<String, Object> externalize() { 431 return Map.of(ATTRIBUTE_PREDICATE, Long.valueOf(ck.ordinal())); 432 } 433 434 public CompareKind kind() { 435 return ck; 436 } 437 } 438 439 static String maxMinSuffixFromType(TypeElement t) { 440 if (t instanceof TensorType tt) { 441 return maxMinSuffixFromType(tt.eType()); 442 } else if (t instanceof PtrType pt) { 443 return maxMinSuffixFromType(pt.rType()); 444 } else if (t.equals(JavaType.INT)) { 445 return ""; 446 } else if (t.equals(JavaType.FLOAT)) { 447 return "imum"; 448 } else { 449 throw new UnsupportedOperationException("Unsupported type: " + t); 450 } 451 } 452 453 static String nameSuffixFromType(TypeElement t, boolean signed) { 454 if (t instanceof TensorType tt) { 455 return nameSuffixFromType(tt.eType(), signed); 456 } else if (t instanceof PtrType pt) { 457 return nameSuffixFromType(pt.rType(), signed); 458 } else if (t.equals(JavaType.INT) || t.equals(JavaType.LONG)) { 459 return (signed ? "s" : "") + "i"; 460 } else if (t.equals(JavaType.FLOAT) || t.equals(JavaType.DOUBLE) || 461 Float16.FLOAT_16_TYPE.equals(t)) { 462 return "f"; 463 } else { 464 throw new UnsupportedOperationException("Unsupported type: " + t); 465 } 466 } 467 468 public static final OpFactory OP_FACTORY = def -> { 469 return switch (def.name()) { 470 case ConstantOp.NAME -> ConstantOp.create(def); 471 case ExpOp.NAME -> new ExpOp(def); 472 case AddOp.NAME + "i", AddOp.NAME + "f" -> new AddOp(def); 473 case SubOp.NAME + "i", SubOp.NAME + "f" -> new SubOp(def); 474 case MulOp.NAME + "i", MulOp.NAME + "f" -> new MulOp(def); 475 case DivOp.NAME + "si", DivOp.NAME + "f" -> new DivOp(def); 476 case RemOp.NAME + "si", RemOp.NAME + "f" -> new DivOp(def); 477 case AndOp.NAME -> new AndOp(def); 478 case MaxOp.NAME + "si", MaxOp.NAME + "imumf" -> new MaxOp(def); 479 case MinOp.NAME + "si", MinOp.NAME + "imumf" -> new MinOp(def); 480 case TruncOp.NAME + "i", TruncOp.NAME + "f" -> new TruncOp(def); 481 case CompareOp.NAME + "i", CompareOp.NAME + "f" -> CompareOp.create(def); 482 default -> null; 483 }; 484 }; 485 486 // Arith 487 488 public static ConstantOp constant(TypeElement type, Object value) { 489 return new ConstantOp(type, value); 490 } 491 492 public static MulOp mul(Value a, Value b) { 493 return new MulOp(a, b); 494 } 495 496 public static AddOp add(Value a, Value b) { 497 return new AddOp(a, b); 498 } 499 500 public static SubOp sub(Value a, Value b) { 501 return new SubOp(a, b); 502 } 503 504 public static DivOp div(Value a, Value b) { 505 return new DivOp(a, b); 506 } 507 508 public static RemOp rem(Value a, Value b) { 509 return new RemOp(a, b); 510 } 511 512 public static AndOp and(Value a, Value b) { 513 return new AndOp(a, b); 514 } 515 516 public static MaxOp maximum(Value a, Value b) { 517 return new MaxOp(a, b); 518 } 519 520 public static MinOp minimum(Value a, Value b) { 521 return new MinOp(a, b); 522 } 523 524 public static CompareOp cmp(CompareOp.CompareKind ck, Value a, Value b) { 525 return new CompareOp(ck, a, b); 526 } 527 528 public static TruncOp trunc(TypeElement type, Value a) { 529 return new TruncOp(type, a); 530 } 531 532 // Math 533 534 public static ExpOp exp(Value a) { 535 return new ExpOp(a); 536 } 537 }