1 /*
   2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package oracle.code.triton;
  27 
  28 import java.lang.constant.ClassDesc;
  29 import java.lang.invoke.MethodHandle;
  30 import java.lang.invoke.MethodHandles;
  31 import java.lang.reflect.Field;
  32 import java.lang.reflect.Method;
  33 import jdk.incubator.code.*;
  34 import jdk.incubator.code.analysis.SSA;
  35 import jdk.incubator.code.dialect.core.CoreOp;
  36 import jdk.incubator.code.dialect.java.JavaOp;
  37 import jdk.incubator.code.dialect.java.JavaType;
  38 import jdk.incubator.code.dialect.core.VarType;
  39 import java.util.*;
  40 import java.util.concurrent.atomic.AtomicInteger;
  41 import java.util.stream.Stream;
  42 
  43 import static jdk.incubator.code.dialect.core.CoreOp.*;
  44 import static jdk.incubator.code.dialect.core.CoreType.functionType;
  45 import static jdk.incubator.code.dialect.java.JavaOp.*;
  46 
  47 public final class TritonTransformer {
    // Not instantiable; this class only exposes static methods
    private TritonTransformer() {}

    // Type of the Triton API class. Invocations whose descriptor refers to this type are
    // type checked and transformed by name
    static final JavaType TYPE_Triton = JavaType.type(Triton.class);

    // Type of the TritonTest class, described by name rather than by class literal.
    // Invocations on it are handled in the same way as invocations on Triton
    static final JavaType TYPE_Triton_Test = JavaType.type(ClassDesc.of("oracle.code.triton.TritonTest"));

    // Type of the Tensor class, used to recognize invocations on Tensor
    static final JavaType TYPE_Tensor = JavaType.type(Tensor.class);

    // Type of java.lang.Math, used to recognize invocations on Math
    static final JavaType TYPE_J_L_MATH = JavaType.type(Math.class);
  57 
  58     public static <O extends Op & Op.Invokable>
  59     TritonOps.ModuleOp tritonModule(O kernel,
  60                                     TypeElement rType,
  61                                     List<? extends TypeElement> argTypes) {
  62         Map<String, TritonOps.FuncOp> fsymTable = new LinkedHashMap<>();
  63         tritonFunction(kernel, rType, argTypes, fsymTable);
  64         return TritonOps.module(fsymTable.values().stream().toList());
  65     }
  66 
  67     public static <O extends Op & Op.Invokable>
  68     TritonOps.FuncOp tritonFunction(O javaKernel,
  69                                     TypeElement rType,
  70                                     List<? extends TypeElement> argTypes,
  71                                     Map<String, TritonOps.FuncOp> fsymTable) {
  72         String name = (javaKernel instanceof FuncOp f) ? f.funcName() : "kernel";
  73         String signature = signature(name, rType, argTypes);
  74         if (fsymTable.containsKey(signature)) {
  75             return fsymTable.get(signature);
  76         }
  77 
  78         System.out.println(javaKernel.toText());
  79 
  80         Map<Value, TypeElement> valueTypeMap = new HashMap<>();
  81         Map<Op, Object> opData = new HashMap<>();
  82         TritonTransformer.typeCheckKernel(javaKernel, argTypes, valueTypeMap, opData);
  83         TritonTransformer.printTypeMap(javaKernel, valueTypeMap);
  84 
  85         return TritonTransformer.transformToTritonFunction(javaKernel, signature,
  86                 rType, valueTypeMap, opData,
  87                 fsymTable);
  88     }
  89 
  90     static String signature(String name, TypeElement rType, List<? extends TypeElement> argTypes) {
  91         StringBuilder sb = new StringBuilder(name);
  92 
  93         for (TypeElement argType : argTypes) {
  94             sb.append("_");
  95             if (argType instanceof ConstantType ct) {
  96                 sb.append(ct.value());
  97             } else {
  98                 sb.append(argType);
  99             }
 100         }
 101         sb.append("_");
 102         sb.append(rType);
 103         return sb.toString();
 104     }
 105 
 106     public static <O extends Op & Op.Invokable> void typeCheckKernel(
 107             O kernel, List<? extends TypeElement> argTypes,
 108             Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData) {
 109         kernel.traverse(null, CodeElement.opVisitor((o, op) -> {
 110             switch (op) {
 111                 case Op.Invokable fop -> {
 112                     List<Block.Parameter> parameters = fop.body().entryBlock().parameters();
 113                     for (int i = 0; i < parameters.size(); i++) {
 114                         valueTypeMap.put(parameters.get(i), argTypes.get(i));
 115                     }
 116                 }
 117                 case VarOp _, VarAccessOp.VarLoadOp _ -> {
 118                     Value init = op.operands().get(0);
 119                     valueTypeMap.put(op.result(), valueTypeMap.get(init));
 120                 }
 121                 case VarAccessOp.VarStoreOp _ -> {
 122                     Value var = op.operands().get(0);
 123                     TypeElement varType = valueTypeMap.get(var);
 124                     Value v = op.operands().get(1);
 125                     TypeElement vType = valueTypeMap.get(v);
 126                     if (!varType.equals(vType)) {
 127                         throw new IllegalStateException("Storing to variable with different type: "
 128                                 + varType + " <- " + vType);
 129                     }
 130 
 131                     valueTypeMap.put(op.result(), valueTypeMap.get(var));
 132                 }
 133                 case ConstantOp cop -> {
 134                     valueTypeMap.put(op.result(), new ConstantType(op.result().type(), cop.value()));
 135                 }
 136                 case ArithmeticOperation _ -> {
 137                     TypeElement t = checkWithTypeInterpreter(op, op.externalizeOpName(), valueTypeMap);
 138                     valueTypeMap.put(op.result(), t);
 139                 }
 140                 case FieldAccessOp.FieldLoadOp flop -> {
 141                     if (!flop.operands().isEmpty()) {
 142                         throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor());
 143                     }
 144 
 145                     Field f;
 146                     try {
 147                         f = flop.fieldDescriptor().resolveToField(MethodHandles.lookup());
 148                     } catch (ReflectiveOperationException e) {
 149                         throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor(), e);
 150                     }
 151                     Object value;
 152                     try {
 153                         value = f.get(null);
 154                     } catch (IllegalAccessException e) {
 155                         throw new IllegalStateException("Unsupported field load: " + f, e);
 156                     }
 157                     valueTypeMap.put(op.result(), new ConstantType(JavaType.type(f.getType()), value));
 158                 }
 159                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
 160                     // Box
 161                     if (iop.invokeDescriptor().name().equals("valueOf")) {
 162                         Value a = op.operands().get(0);
 163                         valueTypeMap.put(op.result(), valueTypeMap.get(a));
 164                     } else {
 165                         throw new UnsupportedOperationException("Unsupported invocation on Integer: " + iop.invokeDescriptor());
 166                     }
 167                 }
 168                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
 169                     String name = iop.invokeDescriptor().name();
 170                     if (name.equals("max") || name.equals("min")) {
 171                         Value a = op.operands().get(0);
 172                         valueTypeMap.put(op.result(), valueTypeMap.get(a));
 173                     } else {
 174                         throw new UnsupportedOperationException("Unsupported invocation on Math: " + iop.invokeDescriptor());
 175                     }
 176                 }
 177                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
 178                     if (iop.invokeDescriptor().name().equals("type")) {
 179                         Value a = op.operands().get(0);
 180                         valueTypeMap.put(op.result(), valueTypeMap.get(a));
 181                     } else {
 182                         throw new UnsupportedOperationException("Unsupported invocation on Tensor: " + iop.invokeDescriptor());
 183                     }
 184                 }
 185                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
 186                     TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
 187                     valueTypeMap.put(op.result(), t);
 188                 }
 189                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
 190                     TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
 191                     valueTypeMap.put(op.result(), t);
 192                 }
 193                 case JavaOp.ForOp fop -> {
 194                     SimpleCountedForLoopInfo li = new SimpleCountedForLoopInfo(fop);
 195                     opData.put(fop, li);
 196 
 197                     TypeElement type = fop.init().yieldType();
 198                     if (type instanceof VarType vt && vt.valueType().equals(JavaType.INT)) {
 199                         for (Body b : List.of(fop.cond(), fop.update(), fop.loopBody())) {
 200                             valueTypeMap.put(b.entryBlock().parameters().get(0), JavaType.INT);
 201                         }
 202                     } else {
 203                         throw new IllegalStateException();
 204                     }
 205                 }
 206                 case TestOperation _ -> {
 207                 }
 208                 case JavaOp.ContinueOp _ -> {
 209                 }
 210                 case CoreOp.YieldOp _ -> {
 211                 }
 212                 case ReturnOp _ -> {
 213                 }
 214                 default -> throw new UnsupportedOperationException("Unsupported operation: " + op);
 215             }
 216 
 217             return null;
 218         }));
 219     }
 220 
 221     static TypeElement checkWithTypeInterpreter(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
 222         // Obtain associated type-based method
 223         MethodHandle mh;
 224         try {
 225             Optional<Method> om = Stream.of(TritonTypeInterpreter.class.getDeclaredMethods())
 226                     .filter(m -> m.getName().equals(name))
 227                     .filter(m -> m.isVarArgs() ? m.getParameterCount() <= op.operands().size() : m.getParameterCount() == op.operands().size())
 228                     .findFirst();
 229             mh = MethodHandles.lookup().unreflect(
 230                     om.orElseThrow(() -> new NoSuchMethodException(name)));
 231         } catch (ReflectiveOperationException e) {
 232             throw new IllegalStateException(name, e);
 233         }
 234 
 235         // Invoke with the values' types
 236         List<TypeElement> operandTypes = op.operands().stream().map(valueTypeMap::get).toList();
 237         try {
 238             return (TypeElement) mh.invokeWithArguments(operandTypes.toArray(Object[]::new));
 239         } catch (Throwable e) {
 240             throw new IllegalStateException(mh.toString(), e);
 241         }
 242     }
 243 
 244     // @@@ type check tensor shapes
 245     static class TritonTypeInterpreter {
 246         private TritonTypeInterpreter() {
 247         }
 248 
 249         //                 int programId(@Constant int axis) {
 250         public static JavaType programId(ConstantType axis) {
 251             assert axis.cType().equals(JavaType.INT);
 252             int axisValue = (int) axis.value();
 253             if (axisValue < 0 || axisValue > 3) {
 254                 throw new IllegalStateException();
 255             }
 256 
 257             return JavaType.INT;
 258         }
 259 
 260         //                Tensor arange(@Constant int start, @Constant int end)
 261         public static TensorType arange(ConstantType start, ConstantType end) {
 262             assert start.cType().equals(JavaType.INT);
 263             assert end.cType().equals(JavaType.INT);
 264 
 265             int startValue = (int) start.value();
 266             int endValue = (int) end.value();
 267 
 268             return new TensorType(JavaType.INT, List.of(endValue - startValue));
 269         }
 270 
 271         //                Tensor expand(Tensor a, int axis) {
 272         public static TensorType expand(TensorType a, ConstantType axis) {
 273             assert axis.cType().equals(JavaType.INT);
 274             int axisValue = (int) axis.value();
 275 
 276             List<Integer> s = new ArrayList<>(a.shape());
 277             if (axisValue < s.size()) {
 278                 s.add(axisValue, 1);
 279             } else {
 280                 for (int i = 0; i <= (axisValue - s.size()); i++) {
 281                     s.add(1);
 282                 }
 283             }
 284             return new TensorType(a.eType(), s);
 285         }
 286 
 287         //                Tensor load(Tensor ptr, Tensor mask)
 288         public static TensorType load(TensorType ptr, TensorType mask) {
 289             checkTensorShape(ptr, mask);
 290             if (ptr.eType() instanceof PtrType eptr) {
 291                 return new TensorType(eptr.rType(), ptr.shape());
 292             }
 293 
 294             throw new IllegalStateException();
 295         }
 296 
 297         //                Tensor load(Tensor ptr, Tensor mask, ConstantType other) {
 298         public static TensorType load(TensorType ptr, TensorType mask, ConstantType other) {
 299             checkTensorShape(ptr, mask);
 300             if (ptr.eType() instanceof PtrType eptr) {
 301                 return new TensorType(eptr.rType(), ptr.shape());
 302             }
 303 
 304             throw new IllegalStateException();
 305         }
 306 
 307         //            void store(Tensor ptr, Tensor value, Tensor mask)
 308         public static void store(TensorType ptr, TensorType value, TensorType mask) {
 309             if (!(ptr.eType() instanceof PtrType)) {
 310                 throw new IllegalStateException();
 311             }
 312         }
 313 
 314         //                Tensor zeros(TensorType type)
 315         public static TensorType zeros(ConstantType eType, ConstantType... cShape) {
 316             List<Integer> shape = Stream.of(cShape).map(s -> (int) s.value()).toList();
 317             return new TensorType((TypeElement) eType.value(), shape);
 318         }
 319 
 320         //                Tensor broadcast(Object o, TensorType type)
 321         public static TensorType broadcast(TypeElement o, TensorType type) {
 322             if (o instanceof TensorType ot) {
 323                 // @@@
 324                 if (ot.shape().size() != type.shape().size()) {
 325                     throw new IllegalStateException();
 326                 }
 327                 o = ot.eType();
 328             } if (o instanceof ConstantType oc) {
 329                 o = oc.cType();
 330             }
 331             return new TensorType(o, type.shape());
 332         }
 333 
 334         public static TensorType joinShape(TensorType a, TensorType b) {
 335             return checkTensorTypes(a, b);
 336         }
 337 
 338         //          Tensor add(Number a, Number b)
 339         //             Ptr add(Ptr a, int offset)
 340         public static TypeElement add(TypeElement a, TypeElement b) {
 341             // @@@ Pass additional argument for checking ptr
 342             return binary(a, b);
 343         }
 344 
 345         public static TypeElement sub(TypeElement a, TypeElement b) {
 346             return binary(a, b);
 347         }
 348 
 349         public static TypeElement mul(TypeElement a, TypeElement b) {
 350             return binary(a, b);
 351         }
 352 
 353         public static TypeElement div(TypeElement a, TypeElement b) {
 354             return binary(a, b);
 355         }
 356 
 357         public static TypeElement mod(TypeElement a, TypeElement b) {
 358             return binary(a, b);
 359         }
 360 
 361         public static TypeElement and(TypeElement a, TypeElement b) {
 362             return binary(a, b);
 363         }
 364 
 365         public static TypeElement cdiv(TypeElement a, TypeElement b) {
 366             a = reduceScalarType(a);
 367             b = reduceScalarType(b);
 368             if (!a.equals(JavaType.INT) && !b.equals(JavaType.INT)) {
 369                 throw new IllegalStateException();
 370             }
 371             return a;
 372         }
 373 
 374         //          Number conv(Type t, Number a) {
 375         public static TypeElement conv(ConstantType eType, TypeElement a) {
 376             return convTypes(eType, a);
 377         }
 378 
 379         public static TypeElement convTypes(ConstantType eType, TypeElement a) {
 380             if (a instanceof TensorType tb) {
 381                 TypeElement e = convScalarTypes(eType, tb.eType());
 382                 return new TensorType(e, tb.shape());
 383             } else {
 384                 return convScalarTypes(eType, a);
 385             }
 386         }
 387 
 388         public static TypeElement convScalarTypes(ConstantType eType, TypeElement a) {
 389             TypeElement t = (TypeElement) eType.value();
 390             if (t.equals(Float16.FLOAT_16_TYPE) && a.equals(JavaType.FLOAT)) {
 391                 return Float16.FLOAT_16_TYPE;
 392             } else if (t.equals(a)) {
 393                 return t;
 394             } else {
 395                 // @@@ Conversions;
 396                 throw new IllegalStateException();
 397             }
 398         }
 399 
 400         //          Tensor exp(Tensor a)
 401         public static TypeElement exp(TypeElement a) {
 402             return unary(a);
 403         }
 404 
 405         static TypeElement unary(TypeElement a) {
 406             return a;
 407         }
 408 
 409         //                Tensor compare(Number a, Number b, @Constant CompareKind ck) {
 410         public static TypeElement compare(TypeElement a, TypeElement b, ConstantType kind) {
 411             assert kind.cType().equals(JavaType.type(Triton.CompareKind.class));
 412 
 413             TypeElement t = binary(a, b);
 414             if (t instanceof TensorType tt) {
 415                 return new TensorType(JavaType.BOOLEAN, tt.shape());
 416             } else {
 417                 return t;
 418             }
 419         }
 420 
 421         //                Tensor dot(Tensor a, Tensor b)
 422         public static TensorType dot(TensorType a, TensorType b) {
 423             if (a.shape().size() != 2 || b.shape().size() != 2) {
 424                 throw new IllegalStateException();
 425             }
 426 
 427             if (!a.shape().get(1).equals(b.shape().get(0))) {
 428                 throw new IllegalStateException();
 429             }
 430 
 431             if (a.eType() != b.eType()) {
 432                 // @@@ Conversion, type checking
 433                 throw new IllegalStateException();
 434             }
 435 
 436             // Computed result is tensor of floats, regardless of inputs
 437             return new TensorType(JavaType.FLOAT, List.of(a.shape().get(0), b.shape().get(1)));
 438         }
 439 
 440 
 441         //                Tensor max(Tensor a, @Constant int axis) {
 442         public static TypeElement max(TensorType a, ConstantType axis) {
 443             return reduce(a, axis);
 444         }
 445 
 446         //                Tensor sum(Tensor a, @Constant int axis) {
 447         public static TypeElement sum(TensorType a, ConstantType axis) {
 448             return reduce(a, axis);
 449         }
 450 
 451         static TypeElement reduce(TensorType a, ConstantType axis) {
 452             assert axis.cType().equals(JavaType.INT);
 453             int axisValue = (int) axis.value();
 454             if (axisValue < 0 || axisValue > 3) {
 455                 throw new IllegalStateException();
 456             }
 457 
 458             List<Integer> reduceShape = new ArrayList<>();
 459             for (int i = 0; i < a.shape().size(); i++) {
 460                 if (i != axisValue) {
 461                     reduceShape.add(a.shape().get(i));
 462                 } else {
 463                     reduceShape.add(1);
 464                 }
 465             }
 466 
 467             if (reduceShape.size() == 1 && reduceShape.getFirst() == 1) {
 468                 return a.eType();
 469             } else {
 470                 return new TensorType(a.eType(), reduceShape);
 471             }
 472         }
 473 
 474         // @@@ Test
 475         public static void consume(TypeElement a) {
 476         }
 477 
 478 
 479         static TypeElement binary(TypeElement a, TypeElement b) {
 480             if (a instanceof TensorType ta && b instanceof TensorType tb) {
 481                 return checkTensorTypes(ta, tb);
 482             } else if (a instanceof TensorType ta) {
 483                 return new TensorType(checkScalarTypes(ta.eType(), b), ta.shape());
 484             } else if (b instanceof TensorType tb) {
 485                 return new TensorType(checkScalarTypes(a, tb.eType()), tb.shape());
 486             } else {
 487                 return checkScalarTypes(a, b);
 488             }
 489         }
 490 
 491         static TensorType checkTensorTypes(TensorType a, TensorType b) {
 492             List<Integer> s = checkTensorShape(a, b);
 493             TypeElement e = checkScalarTypes(a.eType(), b.eType());
 494             return new TensorType(e, s);
 495         }
 496 
 497         static List<Integer> checkTensorShape(TensorType a, TensorType b) {
 498             if (a.shape().size() != b.shape().size()) {
 499                 // Shape mismatch
 500                 throw new IllegalStateException();
 501             }
 502 
 503             List<Integer> s = new ArrayList<>();
 504             for (int i = 0; i < a.shape().size(); i++) {
 505                 int ad = a.shape().get(i);
 506                 int bd = b.shape().get(i);
 507 
 508                 // Expand dimensions
 509                 int d;
 510                 if (ad == bd) {
 511                     d = ad;
 512                 } else {
 513                     if (ad != 1 && bd == 1) {
 514                         d = ad;
 515                     } else if (ad == 1) {
 516                         d = bd;
 517                     } else {
 518                         // Shape mismatch
 519                         throw new IllegalStateException();
 520                     }
 521                 }
 522 
 523                 s.add(d);
 524             }
 525 
 526             return s;
 527         }
 528 
 529         static TypeElement checkScalarTypes(TypeElement a, TypeElement b) {
 530             // @@@ Optional ptr checking
 531             if (a instanceof PtrType) {
 532                 if (!b.equals(JavaType.INT)) {
 533                     throw new IllegalStateException();
 534                 }
 535             } else if (b instanceof PtrType) {
 536                 // Pointer must be first argument
 537                 throw new IllegalStateException();
 538             } else if (a instanceof ConstantType || b instanceof ConstantType) {
 539                 return checkScalarTypes(reduceScalarType(a), reduceScalarType(b));
 540             } else if (!a.equals(b)) {
 541                 // @@@ Conversion
 542                 throw new IllegalStateException();
 543             }
 544             return a;
 545         }
 546 
 547         static TypeElement reduceScalarType(TypeElement a) {
 548             return a instanceof ConstantType ct ? ct.cType() : a;
 549         }
 550     }
 551 
 552     public static <O extends Op & Op.Invokable> TritonOps.FuncOp transformToTritonFunction(
 553             O kernel,
 554             String signature,
 555             TypeElement rType,
 556             Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
 557             Map<String, TritonOps.FuncOp> fsymTable) {
 558         TritonOps.FuncOp ttKernel = TritonOps.func(signature, functionType(rType))
 559                 .body(fblock -> {
 560                     // Process kernel parameters
 561                     List<Value> args = new ArrayList<>();
 562                     for (Block.Parameter kp : kernel.body().entryBlock().parameters()) {
 563                         TypeElement type = valueTypeMap.get(kp);
 564                         if (type instanceof ConstantType ct) {
 565                             // Constant
 566                             Op.Result cr = fblock.op(ArithMathOps.constant(
 567                                     ct.cType(), ct.value()));
 568                             args.add(cr);
 569                         } else {
 570                             args.add(fblock.parameter(type));
 571                         }
 572                     }
 573 
 574                     // Transform kernel body
 575                     fblock.body(kernel.body(), args, (kblock, op) -> {
 576                         return transformToTritonOperation(kblock, op, valueTypeMap, opData, fsymTable);
 577                     });
 578                 });
 579 
 580         ttKernel = cleanup(ttKernel);
 581         fsymTable.put(ttKernel.funcName(), ttKernel);
 582         return ttKernel;
 583     }
 584 
 585     static Block.Builder transformToTritonOperation(Block.Builder kblock, Op op,
 586                                                     Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
 587                                                     Map<String, TritonOps.FuncOp> fsymTable) {
 588         // @@@ Avoid constructing for each operation -- block builder passed as argument or a scoped value
 589         TritonBuilderInterpreter tbi = new TritonBuilderInterpreter(fsymTable, kblock);
 590         CopyContext cc = kblock.context();
 591         switch (op) {
 592             case VarOp varOp -> {
 593                 // @@@ Cannot copy op because the result type
 594                 //     is derived from init type
 595                 Value init = cc.getValue(op.operands().get(0));
 596                 Op.Result r = kblock.op(var(varOp.varName(), init));
 597                 cc.mapValue(op.result(), r);
 598             }
 599             case ConstantOp cop -> {
 600                 TypeElement t = valueTypeMap.get(cop.result());
 601                 if (t instanceof ConstantType ct) {
 602                     Op.Result r = kblock.op(ArithMathOps.constant(
 603                             ct.cType(), ct.value()));
 604                     cc.mapValue(op.result(), r);
 605                 } else {
 606                     kblock.op(op);
 607                 }
 608             }
 609             case ArithmeticOperation _ -> {
 610                 Value result = tbi.build(op, op.externalizeOpName(), valueTypeMap);
 611                 if (result != null) {
 612                     cc.mapValue(op.result(), result);
 613                 }
 614             }
 615             case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
 616                 // Replace box with its value
 617                 Value a = cc.getValue(op.operands().get(0));
 618                 cc.mapValue(op.result(), a);
 619             }
 620             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
 621                 String name = iop.invokeDescriptor().name();
 622                 if (name.equals("max")) {
 623                     Value a = cc.getValue(op.operands().get(0));
 624                     Value b = cc.getValue(op.operands().get(1));
 625 
 626                     Op.Result result = kblock.op(ArithMathOps.maximum(a, b));
 627                     cc.mapValue(op.result(), result);
 628                 } else if (name.equals("min")) {
 629                     Value a = cc.getValue(op.operands().get(0));
 630                     Value b = cc.getValue(op.operands().get(1));
 631 
 632                     Op.Result result = kblock.op(ArithMathOps.minimum(a, b));
 633                     cc.mapValue(op.result(), result);
 634                 }
 635             }
 636             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
 637                 if (iop.invokeDescriptor().name().equals("type")) {
 638                     // Replace with constant operation to produce tensor type.
 639                     // Result may be used, but transitively it will be removed due to no uses
 640                     // contributing to the computation
 641                     Value a = op.operands().get(0);
 642                     TensorType aType = (TensorType) valueTypeMap.get(a);
 643                     Op.Result result = kblock.op(CoreOp.constant(iop.resultType(), aType));
 644                     cc.mapValue(op.result(), result);
 645                     valueTypeMap.put(result, aType);
 646                 }
 647                 // Remove
 648             }
 649             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
 650                 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
 651                 if (result != null) {
 652                     cc.mapValue(op.result(), result);
 653                 }
 654             }
 655             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
 656                 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
 657                 if (result != null) {
 658                     cc.mapValue(op.result(), result);
 659                 }
 660             }
 661             case JavaOp.ForOp fop -> {
 662                 transformToSCFFor(cc, kblock, fop, valueTypeMap, opData, fsymTable);
 663             }
 664             case ReturnOp rop -> {
 665                 if (rop.operands().isEmpty()) {
 666                     kblock.op(TritonOps.return_());
 667                 } else {
 668                     kblock.op(TritonOps.return_(
 669                             cc.getValue(rop.returnValue())));
 670                 }
 671             }
 672             default -> kblock.op(op);
 673         }
 674         return kblock;
 675     }
 676 
    /**
     * Transforms a Java {@code for} operation into an SCF for operation that is appended to
     * {@code kblock}.
     * <p>
     * The loop's start, end and step expressions are transformed into {@code kblock} ahead of
     * the loop. Variables captured by the loop body that are only loaded are loaded once before
     * the loop; captured variables that are stored to become the iteration values of the SCF
     * for operation, and the loop result is stored back to them after the loop.
     *
     * @param cc the copy context used to look up the values mapped so far
     * @param kblock the block builder receiving the hoisted operations, the SCF for operation
     *               and the trailing stores
     * @param fop the Java for operation to transform
     * @param valueTypeMap map from values to their computed types; entries are added for the
     *                     values created here
     * @param opData per-operation data; the entry for {@code fop} must be a
     *               {@code SimpleCountedForLoopInfo}
     * @param fsymTable symbol table of Triton functions, passed through to
     *                  {@code transformToTritonOperation}
     */
    static void transformToSCFFor(CopyContext cc, Block.Builder kblock, JavaOp.ForOp fop,
                                  Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
                                  Map<String, TritonOps.FuncOp> fsymTable) {
        Body body = fop.loopBody();

        // Hoist expressions for start, end, and step
        // For each expression the mapped result of its last operation is used as the bound;
        // a bound stays null if the corresponding expression has no operations
        SimpleCountedForLoopInfo li = (SimpleCountedForLoopInfo) opData.get(fop);
        Value start = null;
        for (Op o : li.startExpression()) {
            transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
            start = cc.getValue(o.result());
        }
        Value end = null;
        for (Op o : li.endExpression()) {
            transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
            end = cc.getValue(o.result());
        }
        Value step = null;
        for (Op o : li.stepExpression()) {
            transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
            step = cc.getValue(o.result());
        }

        // Obtain captured vars
        // true == stores
        // false == loads only
        Map<Boolean, Set<Value>> capturedVars = capturedVars(body);
        Set<Value> capturedAndStoredVars = capturedVars.get(true);

        // Get load values
        // Loaded values are hoisted out of the loop body
        Map<Value, Value> loadValues = new HashMap<>();
        for (Value v : capturedVars.get(false)) {
            Value load = kblock.op(varLoad(cc.getValue(v)));
            // The hoisted load has the same computed type as the variable it reads
            valueTypeMap.put(load, valueTypeMap.get(v));
            loadValues.put(v, load);
        }

        // Get iteration values -- represented by captured vars that are stored to in the loop
        // The SCF for operation returns the iteration values of the last loop iteration, which
        // are then to be stored to the iteration variables
        List<Value> iterValues = new ArrayList<>();
        for (Value v : capturedAndStoredVars) {
            iterValues.add(kblock.op(varLoad(cc.getValue(v))));
        }

        // @@@ Build in java code model, then transform?
        SCFOps.ForOp scffor = SCFOps.for_(kblock.parentBody(), start, end, step, iterValues)
                // Ensure existing context is used
                .body(CopyContext.create(cc), builder -> {
                    // Create index var initialized from entry block parameter
                    Value index = builder.parameters().get(0);
                    valueTypeMap.put(index, JavaType.INT);
                    Value varIndex = builder.op(var("index", index));
                    valueTypeMap.put(varIndex, JavaType.INT);
                    builder.context().mapValue(body.entryBlock().parameters().get(0), varIndex);

                    // Create iter vars initialized from entry block parameters
                    // Parameter 0 is the index, so iteration parameters start at position 1;
                    // each var is named after the already incremented position ("2", "3", ...)
                    int pi = 1;
                    for (Value v : capturedAndStoredVars) {
                        TypeElement type = valueTypeMap.get(v);
                        Value iter = builder.parameters().get(pi++);
                        valueTypeMap.put(iter, type);
                        Value varIter = builder.op(var(Integer.toString(pi), iter));
                        valueTypeMap.put(varIter, type);
                        // Accesses to the captured var inside the body now target the iter var
                        builder.context().mapValue(v, varIter);
                    }

                    // Transform the Java for body into the SCF for body
                    builder.body(body, List.of(), (block, op) -> {
                        // Yield iter values
                        if (op instanceof JavaOp.ContinueOp) {
                            // Replace with yield of loaded vars
                            List<Value> yieldValues = new ArrayList<>();
                            for (Value value : capturedAndStoredVars) {
                                Value varIter = block.context().getValue(value);
                                Value v = block.op(varLoad(varIter));
                                yieldValues.add(v);
                            }
                            block.op(SCFOps.yield_(yieldValues));
                        } else if (op instanceof VarAccessOp.VarLoadOp) {
                            // Replace with value loaded immediately before loop
                            Value v = op.operands().get(0);
                            if (capturedVars.get(false).contains(v)) {
                                block.context().mapValue(op.result(), loadValues.get(v));
                            } else {
                                // Not a load-only captured var: copy the load unchanged
                                block.op(op);
                            }
                        } else {
                            block = transformToTritonOperation(block, op, valueTypeMap, opData, fsymTable);
                        }
                        return block;
                    });
                });
        Op.Result forResult = kblock.op(scffor);

        // Assign back result to iter vars
        // With exactly one iter var the loop result is stored directly; otherwise each
        // component is extracted from the result with a tuple load
        if (capturedAndStoredVars.size() == 1) {
            for (Value v : capturedAndStoredVars) {
                kblock.op(varStore(cc.getValue(v), forResult));
            }
        } else {
            int i = 0;
            for (Value v : capturedAndStoredVars) {
                kblock.op(varStore(cc.getValue(v),
                        kblock.op(tupleLoad(forResult, i++))));
            }
        }
    }
 786 
 787     static Map<Boolean, Set<Value>> capturedVars(Body body) {
 788         Map<Boolean, Set<Value>> capturedValues = new HashMap<>();
 789         capturedValues.put(false, new LinkedHashSet<>());
 790         capturedValues.put(true, new LinkedHashSet<>());
 791 
 792         capturedVars(capturedValues, new ArrayDeque<>(), body);
 793         return capturedValues;
 794     }
 795 
 796     static void capturedVars(Map<Boolean, Set<Value>> capturedVars, Deque<Body> bodyStack, Body body) {
 797         bodyStack.push(body);
 798 
 799         for (Block b : body.blocks()) {
 800             for (Op op : b.ops()) {
 801                 // @@@ Nested bodies
 802                 if (!op.bodies().isEmpty()) {
 803                     throw new IllegalStateException();
 804                 }
 805 //                for (Body childBody : op.bodies()) {
 806 //                    capturedAndUpdatedVars(capturedValues, bodyStack, childBody);
 807 //                }
 808 
 809                 if (op instanceof VarAccessOp) {
 810                     Value v = op.operands().get(0);
 811                     if (!bodyStack.contains(v.declaringBlock().ancestorBody())) {
 812                         if (op instanceof VarAccessOp.VarStoreOp) {
 813                             capturedVars.get(true).add(v);
 814                             capturedVars.get(false).remove(v);
 815                         } else if (!capturedVars.get(true).contains(v)) {
 816                             capturedVars.get(false).add(v);
 817                         }
 818                     }
 819                 }
 820             }
 821         }
 822 
 823         bodyStack.pop();
 824     }
 825 
    /**
     * Scoped switch consulted by {@code cleanup}: when bound, its value decides whether the
     * SSA transformation is applied; when unbound, the transformation is applied.
     */
    public static final ScopedValue<Boolean> SV_SSA = ScopedValue.newInstance();
 827 
 828     static TritonOps.FuncOp cleanup(TritonOps.FuncOp f) {
 829         // Remove var ops
 830         boolean doSSA = SV_SSA.isBound() ? SV_SSA.get() : true;
 831         if (doSSA) {
 832             f = SSA.transform(f);
 833         }
 834         // Remove unused ops
 835         f = f.transform((fblock, op) -> {
 836             if (op instanceof Op.Pure && op.result().uses().isEmpty()) {
 837                 return fblock;
 838             } else if (op instanceof VarAccessOp.VarLoadOp && op.result().uses().isEmpty()) {
 839                 return fblock;
 840             }
 841 
 842             fblock.op(op);
 843             return fblock;
 844         });
 845         return f;
 846     }
 847 
    /**
     * Builds Triton operations for invocations, dispatching reflectively (see {@code build})
     * to the method of this class whose name matches the invoked method and whose parameters
     * describe the result and each operand as a (type, value) pair.
     */
    static class TritonBuilderInterpreter {
        // Symbol table of Triton functions by name; looked up and extended by the reduce support
        // and handed to tritonFunction
        final Map<String, TritonOps.FuncOp> fsymTable;
        // Builder to which every generated operation is appended, and whose context maps
        // original values to their transformed counterparts
        final Block.Builder block;

        /**
         * Creates an interpreter appending operations to {@code block}.
         *
         * @param fsymTable the symbol table of Triton functions
         * @param block the block builder receiving the generated operations
         */
        TritonBuilderInterpreter(Map<String, TritonOps.FuncOp> fsymTable, Block.Builder block) {
            this.fsymTable = fsymTable;
            this.block = block;
        }
 856 
 857         Value build(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
 858             // Obtain associated type-based method
 859             MethodHandle mh;
 860             try {
 861                 Optional<Method> om = Stream.of(TritonBuilderInterpreter.class.getDeclaredMethods())
 862                         .filter(m -> m.getName().equals(name))
 863                         .filter(m -> m.isVarArgs()
 864                         ? m.getParameterCount() / 2 - 1 <= op.operands().size()
 865                         : m.getParameterCount() / 2 - 1 == op.operands().size())
 866                         .findFirst();
 867                 mh = MethodHandles.lookup().unreflect(
 868                         om.orElseThrow(() -> new NoSuchMethodException(name)));
 869             } catch (ReflectiveOperationException e) {
 870                 throw new IllegalStateException(e);
 871             }
 872 
 873             List<Object> iArgs = new ArrayList<>();
 874             iArgs.add(this);
 875             iArgs.add(valueTypeMap.get(op.result()));
 876             iArgs.add(op.result());
 877             for (Value o : op.operands()) {
 878                 iArgs.add(valueTypeMap.get(o));
 879                 iArgs.add(o);
 880             }
 881             try {
 882                 return (Value) mh.invokeWithArguments(iArgs.toArray(Object[]::new));
 883             } catch (Throwable e) {
 884                 throw new IllegalStateException(e);
 885             }
 886         }
 887 
 888 
 889         public Value programId(TypeElement rType, Op.Result r,
 890                                ConstantType axisType, Value axis) {
 891             return block.op(TritonOps.getProgramId(
 892                     (int) axisType.value()));
 893         }
 894 
 895         public Value arange(TensorType rType, Op.Result r,
 896                             ConstantType startType, Value start,
 897                             ConstantType endType, Value end) {
 898             return block.op(TritonOps.makeRange(
 899                     (int) startType.value(),
 900                     (int) endType.value()));
 901         }
 902 
 903         public Value expand(TensorType rType, Op.Result r,
 904                             TensorType aType, Value a,
 905                             ConstantType axisType, Value axis) {
 906             return block.op(TritonOps.expand(
 907                     (int) axisType.value(),
 908                     rType,
 909                     block.context().getValue(a)));
 910         }
 911 
 912         public Value zeros(TensorType rType, Op.Result r,
 913                            ConstantType aType, Value a,
 914                            Object... constantsAndValues) {
 915             Object zero;
 916             try {
 917                 JavaType zeroType = (JavaType) aType.value();
 918                 zero = MethodHandles.zero((Class<?>) zeroType.resolve(MethodHandles.lookup())).invoke();
 919             } catch (Throwable e) {
 920                 throw new RuntimeException(e);
 921             }
 922             return block.op(ArithMathOps.constant(rType, zero));
 923         }
 924 
 925         public Value load(TensorType rType, Op.Result r,
 926                           TensorType ptrType, Value ptr,
 927                           TensorType maskType, Value mask) {
 928             broadcastConversionRight(ptrType, maskType, mask);
 929             return block.op(TritonOps.load(
 930                     rType,
 931                     block.context().getValue(ptr),
 932                     block.context().getValue(mask)));
 933         }
 934 
 935         public Value load(TensorType rType, Op.Result r,
 936                           TensorType ptrType, Value ptr,
 937                           TensorType maskType, Value mask,
 938                           ConstantType otherType, Value other) {
 939             broadcastConversionRight(ptrType, maskType, mask);
 940             Value mb = block.op(ArithMathOps.constant(rType, (float)otherType.value()));
 941             block.context().mapValue(other, mb);
 942             return block.op(TritonOps.load(
 943                     rType,
 944                     block.context().getValue(ptr),
 945                     block.context().getValue(mask),
 946                     block.context().getValue(other)));
 947         }
 948 
 949         public Value store(TensorType rType, Op.Result r,
 950                            TensorType ptrType, Value ptr,
 951                            TensorType valueType, Value value,
 952                            TensorType maskType, Value mask) {
 953             broadcastConversionRight(ptrType, valueType, value);
 954             broadcastConversionRight(ptrType, maskType, mask);
 955             return block.op(TritonOps.store(
 956                     block.context().getValue(ptr),
 957                     block.context().getValue(value),
 958                     block.context().getValue(mask)));
 959         }
 960 
 961         public Value broadcast(TensorType rType, Op.Result r,
 962                                TypeElement oType, Value o,
 963                                TensorType tensorTypeType, Value tensorType) {
 964             // @@@ tt.splat with scalar operand, tt.broadcast with tensor operand
 965             if (oType instanceof TensorType) {
 966                 return block.op(TritonOps.broadcast(
 967                         rType,
 968                         block.context().getValue(o)));
 969             } else {
 970                 return block.op(TritonOps.splat(
 971                         rType,
 972                         block.context().getValue(o)));
 973             }
 974         }
 975 
 976         public Value joinShape(TensorType rType, Op.Result r,
 977                                TensorType aType, Value a,
 978                                TensorType bType, Value b) {
 979             // Replace with constant operation to produce tensor type.
 980             // Result may be used, but transitively it will be removed due to no uses
 981             // contributing to the computation
 982             return block.op(CoreOp.constant(JavaType.type(TensorType.class), r.type()));
 983         }
 984 
 985 
 986         public Value add(TypeElement rType, Op.Result r,
 987                          TypeElement aType, Value a,
 988                          TypeElement bType, Value b) {
 989             broadcastConversion(rType, aType, a, bType, b);
 990             a = block.context().getValue(a);
 991             b = block.context().getValue(b);
 992 
 993             if (rType instanceof PtrType ||
 994                     rType instanceof TensorType t && t.eType() instanceof PtrType) {
 995                 return block.op(TritonOps.addptr(a, b));
 996             } else {
 997                 return block.op(ArithMathOps.add(a, b));
 998             }
 999         }
1000 
1001         public Value sub(TypeElement rType, Op.Result r,
1002                          TypeElement aType, Value a,
1003                          TypeElement bType, Value b) {
1004             broadcastConversion(rType, aType, a, bType, b);
1005             a = block.context().getValue(a);
1006             b = block.context().getValue(b);
1007 
1008             return block.op(ArithMathOps.sub(a, b));
1009         }
1010 
1011         public Value mul(TypeElement rType, Op.Result r,
1012                          TypeElement aType, Value a,
1013                          TypeElement bType, Value b) {
1014             broadcastConversion(rType, aType, a, bType, b);
1015             a = block.context().getValue(a);
1016             b = block.context().getValue(b);
1017 
1018             return block.op(ArithMathOps.mul(a, b));
1019         }
1020 
1021         public Value div(TypeElement rType, Op.Result r,
1022                          TypeElement aType, Value a,
1023                          TypeElement bType, Value b) {
1024             broadcastConversion(rType, aType, a, bType, b);
1025             a = block.context().getValue(a);
1026             b = block.context().getValue(b);
1027 
1028             return block.op(ArithMathOps.div(a, b));
1029         }
1030 
1031         public Value mod(TypeElement rType, Op.Result r,
1032                          TypeElement aType, Value a,
1033                          TypeElement bType, Value b) {
1034             broadcastConversion(rType, aType, a, bType, b);
1035             a = block.context().getValue(a);
1036             b = block.context().getValue(b);
1037 
1038             return block.op(ArithMathOps.rem(a, b));
1039         }
1040 
1041         public Value and(TypeElement rType, Op.Result r,
1042                          TypeElement aType, Value a,
1043                          TypeElement bType, Value b) {
1044             broadcastConversion(rType, aType, a, bType, b);
1045             a = block.context().getValue(a);
1046             b = block.context().getValue(b);
1047 
1048             return block.op(ArithMathOps.and(a, b));
1049         }
1050 
1051         public Value dot(TensorType rType, Op.Result r,
1052                          TypeElement aType, Value a,
1053                          TypeElement bType, Value b) {
1054             a = block.context().getValue(a);
1055             b = block.context().getValue(b);
1056             // Computed result is tensor of floats, regardless of inputs
1057             Object zero = 0.0f;
1058             var c = block.op(ArithMathOps.constant(rType, zero));
1059             return block.op(TritonOps.dot(rType, a, b, c));
1060         }
1061 
1062         public Value cdiv(TypeElement rType, Op.Result r,
1063                           TypeElement aType, Value a,
1064                           TypeElement bType, Value b) {
1065             a = block.context().getValue(a);
1066             b = block.context().getValue(b);
1067 
1068             TritonOps.FuncOp cdiv = tritonFunction(Functions.getJavaCodeModel("cdiv"),
1069                     rType, List.of(aType, bType),
1070                     fsymTable);
1071             // @@@ Generalize
1072             List<Value> args = new ArrayList<>();
1073             if (!(aType instanceof ConstantType)) {
1074                 args.add(a);
1075             }
1076             if (!(bType instanceof ConstantType)) {
1077                 args.add(b);
1078             }
1079             return block.op(TritonOps.call(cdiv, args));
1080         }
1081 
1082         public Value conv(TypeElement rType, Op.Result r,
1083                           ConstantType tType, Value t,
1084                           TypeElement aType, Value a) {
1085             a = block.context().getValue(a);
1086 
1087             TypeElement rScalarType;
1088             TypeElement aScalarType;
1089             if (rType instanceof TensorType rTensorType && aType instanceof TensorType aTensorType) {
1090                 rScalarType = rTensorType.eType();
1091                 aScalarType = aTensorType.eType();
1092             } else {
1093                 rScalarType = rType;
1094                 aScalarType = aType;
1095             }
1096 
1097             if (rScalarType.equals(Float16.FLOAT_16_TYPE) && aScalarType.equals(JavaType.FLOAT)) {
1098                 return block.op(ArithMathOps.trunc(rType, a));
1099             } else if (rType.equals(aType)) {
1100                 return a;
1101             } else {
1102                 throw new IllegalStateException();
1103             }
1104         }
1105 
1106         public Value exp(TritonType rType, Op.Result r,
1107                          TritonType aType, Value a) {
1108             return block.op(ArithMathOps.exp(
1109                     block.context().getValue(a)));
1110         }
1111 
1112         public Value compare(TensorType rType, Op.Result r,
1113                              TypeElement aType, Value a,
1114                              TypeElement bType, Value b,
1115                              ConstantType compareType, Value compare) {
1116             Triton.CompareKind ck = (Triton.CompareKind) compareType.value();
1117 
1118             ArithMathOps.CompareOp.CompareKind ack = switch (ck) {
1119                 case LessThan -> ArithMathOps.CompareOp.CompareKind.slt;
1120                 default -> throw new UnsupportedOperationException("Unsupported comparison: " + ck);
1121             };
1122 
1123             broadcastConversionRight(aType, bType, b);
1124             a = block.context().getValue(a);
1125             b = block.context().getValue(b);
1126 
1127             return block.op(ArithMathOps.cmp(ack, a, b));
1128         }
1129 
1130 
1131         public Value max(TypeElement rType, Op.Result r,
1132                          TensorType xType, Value x,
1133                          ConstantType axisType, Value axis) {
1134             TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("max"),
1135                     rType, List.of(rType, rType), fsymTable);
1136             return reduce(rType, r, xType, x, axisType, axis, f);
1137         }
1138 
1139         public Value sum(TypeElement rType, Op.Result r,
1140                          TensorType xType, Value x,
1141                          ConstantType axisType, Value axis) {
1142             TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("sum"),
1143                     rType, List.of(rType, rType), fsymTable);
1144             return reduce(rType, r, xType, x, axisType, axis, f);
1145         }
1146 
1147         Value reduce(TypeElement rType, Op.Result r,
1148                      TensorType xType, Value x,
1149                      ConstantType axisType, Value axis,
1150                      TritonOps.FuncOp f) {
1151             int axisConstant = (int) axisType.value();
1152 
1153             String signature = "reduce_" + f.funcName() + "_" + axisConstant;
1154             TritonOps.FuncOp rf = fsymTable.computeIfAbsent(signature,
1155                     s -> reduce(rType, xType, axisConstant, s, f));
1156 
1157             return block.op(TritonOps.call(rf, block.context().getValue(x)));
1158         }
1159 
1160         static TritonOps.FuncOp reduce(TypeElement elementType,
1161                                        TensorType tensorType,
1162                                        int axisConstant,
1163                                        String name, TritonOps.FuncOp scalarFunc) {
1164             return TritonOps.func(name,
1165                             functionType(elementType, tensorType))
1166                     .body(fblock -> {
1167                         TritonOps.ReduceOp reduceOp = TritonOps.reduce(fblock.parentBody(),
1168                                         axisConstant, fblock.parameters().get(0),
1169                                         functionType(elementType, elementType, elementType))
1170                                 .body(rblock -> {
1171                                     Block.Parameter a = rblock.parameters().get(0);
1172                                     Block.Parameter b = rblock.parameters().get(1);
1173                                     Op.Result _r = rblock.op(TritonOps.call(scalarFunc, a, b));
1174                                     rblock.op(TritonOps.reduceReturn(_r));
1175                                 });
1176 
1177                         Op.Result opr = fblock.op(reduceOp);
1178                         fblock.op(TritonOps.return_(opr));
1179                     });
1180         }
1181 
1182         // @@@ Test
1183         public Value consume(TypeElement rType, Op.Result r,
1184                              TypeElement aType, Value a) {
1185             return block.op(TritonTestOps.consume(block.context().getValue(a)));
1186         }
1187 
1188         void broadcastConversion(TypeElement rType,
1189                                  TypeElement aType, Value a,
1190                                  TypeElement bType, Value b) {
1191             Value ma = block.context().getValue(a);
1192             Value mb = block.context().getValue(b);
1193             if (aType instanceof TensorType at && bType instanceof TensorType bTensorType) {
1194                 TensorType rTensorType = (TensorType) rType;
1195                 if (!at.shape().equals(rTensorType.shape())) {
1196                     ma = block.op(TritonOps.broadcast(rTensorType, ma));
1197                 }
1198                 if (!bTensorType.shape().equals(rTensorType.shape())) {
1199                     if (rTensorType.eType() instanceof PtrType) {
1200                         bTensorType = new TensorType(bType, rTensorType.shape());
1201                         mb = block.op(TritonOps.broadcast(bTensorType, mb));
1202                     } else {
1203                         mb = block.op(TritonOps.broadcast(rTensorType, mb));
1204                     }
1205                 }
1206             } else if (aType instanceof TensorType) {
1207                 TensorType rTensorType = (TensorType) rType;
1208                 if (rTensorType.eType() instanceof PtrType) {
1209                     TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1210                     mb = block.op(TritonOps.splat(bTensorType, mb));
1211                 } else {
1212                     mb = block.op(TritonOps.splat(rTensorType, mb));
1213                 }
1214             } else if (bType instanceof TensorType) {
1215                 TensorType rTensorType = (TensorType) rType;
1216                 ma = block.op(TritonOps.splat(rTensorType, ma));
1217             }
1218             block.context().mapValue(a, ma);
1219             block.context().mapValue(b, mb);
1220         }
1221 
1222         void broadcastConversionRight(TypeElement aType,
1223                                       TypeElement bType, Value b) {
1224             Value mb = block.context().getValue(b);
1225             if (aType instanceof TensorType aTensorType && bType instanceof TensorType bTensorType) {
1226                 if (!bTensorType.shape().equals(aTensorType.shape())) {
1227                     if (aTensorType.eType() instanceof PtrType) {
1228                         bTensorType = new TensorType(bTensorType.eType(), aTensorType.shape());
1229                         mb = block.op(TritonOps.broadcast(bTensorType, mb));
1230                     } else {
1231                         mb = block.op(TritonOps.broadcast(aTensorType, mb));
1232                     }
1233                 }
1234             } else if (aType instanceof TensorType rTensorType) {
1235                 if (rTensorType.eType() instanceof PtrType) {
1236                     TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1237                     mb = block.op(TritonOps.splat(bTensorType, mb));
1238                 } else {
1239                     mb = block.op(TritonOps.splat(rTensorType, mb));
1240                 }
1241             }
1242             block.context().mapValue(b, mb);
1243         }
1244     }
1245 
1246     public static <O extends Op & Op.Invokable> void printTypeMap(
1247             O kernel, Map<Value, TypeElement> valueTypeMap) {
1248         AtomicInteger valueId = new AtomicInteger();
1249         Map<Value, Integer> valueIdMap = new LinkedHashMap<>();
1250         kernel.traverse(null, (o, codeElement) -> {
1251             switch (codeElement) {
1252                 case FuncOp _ -> {
1253                     // Ignore
1254                 }
1255                 case Op op when !op.result().type().equals(JavaType.VOID) -> {
1256                     valueIdMap.put(op.result(), valueId.getAndIncrement());
1257                 }
1258                 case Block block -> {
1259                     for (Block.Parameter parameter : block.parameters()) {
1260                         valueIdMap.put(parameter, valueId.getAndIncrement());
1261                     }
1262                 }
1263                 default -> {
1264                 }
1265             }
1266             return null;
1267         });
1268 
1269         valueIdMap.forEach((value, id) -> {
1270             TypeElement type = valueTypeMap.get(value);
1271             if (type != null) {
1272                 System.out.println("%" + id + " : " + value.type() + " -> " + type);
1273             }
1274         });
1275     }
1276 }