1 /*
   2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package oracle.code.triton;
  27 
  28 import java.lang.constant.ClassDesc;
  29 import java.lang.invoke.MethodHandle;
  30 import java.lang.invoke.MethodHandles;
  31 import java.lang.reflect.Field;
  32 import java.lang.reflect.Method;
  33 import jdk.incubator.code.*;
  34 import jdk.incubator.code.dialect.core.CoreOp;
  35 import jdk.incubator.code.dialect.core.SSA;
  36 import jdk.incubator.code.dialect.core.VarType;
  37 import jdk.incubator.code.dialect.java.JavaOp;
  38 import jdk.incubator.code.dialect.java.JavaType;
  39 import java.util.*;
  40 import java.util.concurrent.atomic.AtomicInteger;
  41 import java.util.stream.Stream;
  42 
  43 import static jdk.incubator.code.dialect.core.CoreOp.*;
  44 import static jdk.incubator.code.dialect.core.CoreType.functionType;
  45 import static jdk.incubator.code.dialect.java.JavaOp.*;
  46 
  47 public final class TritonTransformer {
    // Private constructor: prevents instantiation; the members visible in this file are all static.
    private TritonTransformer() {}
  49 
    // Java type of the Triton class; invocations whose reference type equals it are
    // type checked and built by method name (see typeCheckKernel / transformToTritonOperation).
    static final JavaType TYPE_Triton = JavaType.type(Triton.class);

    // Java type of oracle.code.triton.TritonTest, described by name through a ClassDesc
    // rather than a class literal (presumably because the class is not visible to this
    // class at compile time -- TODO confirm).
    static final JavaType TYPE_Triton_Test = JavaType.type(ClassDesc.of("oracle.code.triton.TritonTest"));

    // Java type of the Tensor class; only its "type" method is recognized by the type checker.
    static final JavaType TYPE_Tensor = JavaType.type(Tensor.class);

    // Java type of java.lang.Math; only "max" and "min" are recognized by the type checker.
    static final JavaType TYPE_J_L_MATH = JavaType.type(Math.class);
  57 
  58     public static <O extends Op & Op.Invokable>
  59     TritonOps.ModuleOp tritonModule(O kernel,
  60                                     TypeElement rType,
  61                                     List<? extends TypeElement> argTypes) {
  62         Map<String, TritonOps.FuncOp> fsymTable = new LinkedHashMap<>();
  63         tritonFunction(kernel, rType, argTypes, fsymTable);
  64         return TritonOps.module(fsymTable.values().stream().toList());
  65     }
  66 
  67     public static <O extends Op & Op.Invokable>
  68     TritonOps.FuncOp tritonFunction(O javaKernel,
  69                                     TypeElement rType,
  70                                     List<? extends TypeElement> argTypes,
  71                                     Map<String, TritonOps.FuncOp> fsymTable) {
  72         String name = (javaKernel instanceof FuncOp f) ? f.funcName() : "kernel";
  73         String signature = signature(name, rType, argTypes);
  74         if (fsymTable.containsKey(signature)) {
  75             return fsymTable.get(signature);
  76         }
  77 
  78         System.out.println(javaKernel.toText());
  79 
  80         Map<Value, TypeElement> valueTypeMap = new HashMap<>();
  81         Map<Op, Object> opData = new HashMap<>();
  82         TritonTransformer.typeCheckKernel(javaKernel, argTypes, valueTypeMap, opData);
  83         TritonTransformer.printTypeMap(javaKernel, valueTypeMap);
  84 
  85         return TritonTransformer.transformToTritonFunction(javaKernel, signature,
  86                 rType, valueTypeMap, opData,
  87                 fsymTable);
  88     }
  89 
  90     static String signature(String name, TypeElement rType, List<? extends TypeElement> argTypes) {
  91         StringBuilder sb = new StringBuilder(name);
  92 
  93         for (TypeElement argType : argTypes) {
  94             sb.append("_");
  95             if (argType instanceof ConstantType ct) {
  96                 sb.append(ct.value());
  97             } else {
  98                 sb.append(argType);
  99             }
 100         }
 101         sb.append("_");
 102         sb.append(rType);
 103         return sb.toString();
 104     }
 105 
 106     public static <O extends Op & Op.Invokable> void typeCheckKernel(
 107             O kernel, List<? extends TypeElement> argTypes,
 108             Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData) {
 109         kernel.elements().forEach(e -> {
 110             if (!(e instanceof Op op)) {
 111                 return;
 112             }
 113             switch (op) {
 114                 case Op.Invokable fop -> {
 115                     List<Block.Parameter> parameters = fop.body().entryBlock().parameters();
 116                     for (int i = 0; i < parameters.size(); i++) {
 117                         valueTypeMap.put(parameters.get(i), argTypes.get(i));
 118                     }
 119                 }
 120                 case VarOp _, VarAccessOp.VarLoadOp _ -> {
 121                     Value init = op.operands().get(0);
 122                     valueTypeMap.put(op.result(), valueTypeMap.get(init));
 123                 }
 124                 case VarAccessOp.VarStoreOp _ -> {
 125                     Value var = op.operands().get(0);
 126                     TypeElement varType = valueTypeMap.get(var);
 127                     Value v = op.operands().get(1);
 128                     TypeElement vType = valueTypeMap.get(v);
 129                     if (!varType.equals(vType)) {
 130                         throw new IllegalStateException("Storing to variable with different type: "
 131                                 + varType + " <- " + vType);
 132                     }
 133 
 134                     valueTypeMap.put(op.result(), valueTypeMap.get(var));
 135                 }
 136                 case ConstantOp cop -> {
 137                     valueTypeMap.put(op.result(), new ConstantType(op.result().type(), cop.value()));
 138                 }
 139                 case ArithmeticOperation _ -> {
 140                     TypeElement t = checkWithTypeInterpreter(op, op.externalizeOpName(), valueTypeMap);
 141                     valueTypeMap.put(op.result(), t);
 142                 }
 143                 case FieldAccessOp.FieldLoadOp flop -> {
 144                     if (!flop.operands().isEmpty()) {
 145                         throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor());
 146                     }
 147 
 148                     Field f;
 149                     try {
 150                         f = flop.fieldDescriptor().resolveToField(MethodHandles.lookup());
 151                     } catch (ReflectiveOperationException ex) {
 152                         throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor(), ex);
 153                     }
 154                     Object value;
 155                     try {
 156                         value = f.get(null);
 157                     } catch (IllegalAccessException ex) {
 158                         throw new IllegalStateException("Unsupported field load: " + f, ex);
 159                     }
 160                     valueTypeMap.put(op.result(), new ConstantType(JavaType.type(f.getType()), value));
 161                 }
 162                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
 163                     // Box
 164                     if (iop.invokeDescriptor().name().equals("valueOf")) {
 165                         Value a = op.operands().get(0);
 166                         valueTypeMap.put(op.result(), valueTypeMap.get(a));
 167                     } else {
 168                         throw new UnsupportedOperationException("Unsupported invocation on Integer: " + iop.invokeDescriptor());
 169                     }
 170                 }
 171                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
 172                     String name = iop.invokeDescriptor().name();
 173                     if (name.equals("max") || name.equals("min")) {
 174                         Value a = op.operands().get(0);
 175                         valueTypeMap.put(op.result(), valueTypeMap.get(a));
 176                     } else {
 177                         throw new UnsupportedOperationException("Unsupported invocation on Math: " + iop.invokeDescriptor());
 178                     }
 179                 }
 180                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
 181                     if (iop.invokeDescriptor().name().equals("type")) {
 182                         Value a = op.operands().get(0);
 183                         valueTypeMap.put(op.result(), valueTypeMap.get(a));
 184                     } else {
 185                         throw new UnsupportedOperationException("Unsupported invocation on Tensor: " + iop.invokeDescriptor());
 186                     }
 187                 }
 188                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
 189                     TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
 190                     valueTypeMap.put(op.result(), t);
 191                 }
 192                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
 193                     TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
 194                     valueTypeMap.put(op.result(), t);
 195                 }
 196                 case JavaOp.ForOp fop -> {
 197                     SimpleCountedForLoopInfo li = new SimpleCountedForLoopInfo(fop);
 198                     opData.put(fop, li);
 199 
 200                     TypeElement type = fop.init().yieldType();
 201                     if (type instanceof VarType vt && vt.valueType().equals(JavaType.INT)) {
 202                         for (Body b : List.of(fop.cond(), fop.update(), fop.loopBody())) {
 203                             valueTypeMap.put(b.entryBlock().parameters().get(0), JavaType.INT);
 204                         }
 205                     } else {
 206                         throw new IllegalStateException();
 207                     }
 208                 }
 209                 case TestOperation _ -> {
 210                 }
 211                 case JavaOp.ContinueOp _ -> {
 212                 }
 213                 case CoreOp.YieldOp _ -> {
 214                 }
 215                 case ReturnOp _ -> {
 216                 }
 217                 default -> throw new UnsupportedOperationException("Unsupported operation: " + op);
 218             }
 219         });
 220     }
 221 
 222     static TypeElement checkWithTypeInterpreter(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
 223         // Obtain associated type-based method
 224         MethodHandle mh;
 225         try {
 226             Optional<Method> om = Stream.of(TritonTypeInterpreter.class.getDeclaredMethods())
 227                     .filter(m -> m.getName().equals(name))
 228                     .filter(m -> m.isVarArgs() ? m.getParameterCount() <= op.operands().size() : m.getParameterCount() == op.operands().size())
 229                     .findFirst();
 230             mh = MethodHandles.lookup().unreflect(
 231                     om.orElseThrow(() -> new NoSuchMethodException(name)));
 232         } catch (ReflectiveOperationException e) {
 233             throw new IllegalStateException(name, e);
 234         }
 235 
 236         // Invoke with the values' types
 237         List<TypeElement> operandTypes = op.operands().stream().map(valueTypeMap::get).toList();
 238         try {
 239             return (TypeElement) mh.invokeWithArguments(operandTypes.toArray(Object[]::new));
 240         } catch (Throwable e) {
 241             throw new IllegalStateException(mh.toString(), e);
 242         }
 243     }
 244 
 245     // @@@ type check tensor shapes
 246     static class TritonTypeInterpreter {
 247         private TritonTypeInterpreter() {
 248         }
 249 
 250         //                 int programId(@Constant int axis) {
 251         public static JavaType programId(ConstantType axis) {
 252             assert axis.cType().equals(JavaType.INT);
 253             int axisValue = (int) axis.value();
 254             if (axisValue < 0 || axisValue > 3) {
 255                 throw new IllegalStateException();
 256             }
 257 
 258             return JavaType.INT;
 259         }
 260 
 261         //                Tensor arange(@Constant int start, @Constant int end)
 262         public static TensorType arange(ConstantType start, ConstantType end) {
 263             assert start.cType().equals(JavaType.INT);
 264             assert end.cType().equals(JavaType.INT);
 265 
 266             int startValue = (int) start.value();
 267             int endValue = (int) end.value();
 268 
 269             return new TensorType(JavaType.INT, List.of(endValue - startValue));
 270         }
 271 
 272         //                Tensor expand(Tensor a, int axis) {
 273         public static TensorType expand(TensorType a, ConstantType axis) {
 274             assert axis.cType().equals(JavaType.INT);
 275             int axisValue = (int) axis.value();
 276 
 277             List<Integer> s = new ArrayList<>(a.shape());
 278             if (axisValue < s.size()) {
 279                 s.add(axisValue, 1);
 280             } else {
 281                 for (int i = 0; i <= (axisValue - s.size()); i++) {
 282                     s.add(1);
 283                 }
 284             }
 285             return new TensorType(a.eType(), s);
 286         }
 287 
 288         //                Tensor load(Tensor ptr, Tensor mask)
 289         public static TensorType load(TensorType ptr, TensorType mask) {
 290             checkTensorShape(ptr, mask);
 291             if (ptr.eType() instanceof PtrType eptr) {
 292                 return new TensorType(eptr.rType(), ptr.shape());
 293             }
 294 
 295             throw new IllegalStateException();
 296         }
 297 
 298         //                Tensor load(Tensor ptr, Tensor mask, ConstantType other) {
 299         public static TensorType load(TensorType ptr, TensorType mask, ConstantType other) {
 300             checkTensorShape(ptr, mask);
 301             if (ptr.eType() instanceof PtrType eptr) {
 302                 return new TensorType(eptr.rType(), ptr.shape());
 303             }
 304 
 305             throw new IllegalStateException();
 306         }
 307 
 308         //            void store(Tensor ptr, Tensor value, Tensor mask)
 309         public static void store(TensorType ptr, TensorType value, TensorType mask) {
 310             if (!(ptr.eType() instanceof PtrType)) {
 311                 throw new IllegalStateException();
 312             }
 313         }
 314 
 315         //                Tensor zeros(TensorType type)
 316         public static TensorType zeros(ConstantType eType, ConstantType... cShape) {
 317             List<Integer> shape = Stream.of(cShape).map(s -> (int) s.value()).toList();
 318             return new TensorType((TypeElement) eType.value(), shape);
 319         }
 320 
 321         //                Tensor broadcast(Object o, TensorType type)
 322         public static TensorType broadcast(TypeElement o, TensorType type) {
 323             if (o instanceof TensorType ot) {
 324                 // @@@
 325                 if (ot.shape().size() != type.shape().size()) {
 326                     throw new IllegalStateException();
 327                 }
 328                 o = ot.eType();
 329             } if (o instanceof ConstantType oc) {
 330                 o = oc.cType();
 331             }
 332             return new TensorType(o, type.shape());
 333         }
 334 
 335         public static TensorType joinShape(TensorType a, TensorType b) {
 336             return checkTensorTypes(a, b);
 337         }
 338 
 339         //          Tensor add(Number a, Number b)
 340         //             Ptr add(Ptr a, int offset)
 341         public static TypeElement add(TypeElement a, TypeElement b) {
 342             // @@@ Pass additional argument for checking ptr
 343             return binary(a, b);
 344         }
 345 
 346         public static TypeElement sub(TypeElement a, TypeElement b) {
 347             return binary(a, b);
 348         }
 349 
 350         public static TypeElement mul(TypeElement a, TypeElement b) {
 351             return binary(a, b);
 352         }
 353 
 354         public static TypeElement div(TypeElement a, TypeElement b) {
 355             return binary(a, b);
 356         }
 357 
 358         public static TypeElement mod(TypeElement a, TypeElement b) {
 359             return binary(a, b);
 360         }
 361 
 362         public static TypeElement and(TypeElement a, TypeElement b) {
 363             return binary(a, b);
 364         }
 365 
 366         public static TypeElement cdiv(TypeElement a, TypeElement b) {
 367             a = reduceScalarType(a);
 368             b = reduceScalarType(b);
 369             if (!a.equals(JavaType.INT) && !b.equals(JavaType.INT)) {
 370                 throw new IllegalStateException();
 371             }
 372             return a;
 373         }
 374 
 375         //          Number conv(Type t, Number a) {
 376         public static TypeElement conv(ConstantType eType, TypeElement a) {
 377             return convTypes(eType, a);
 378         }
 379 
 380         public static TypeElement convTypes(ConstantType eType, TypeElement a) {
 381             if (a instanceof TensorType tb) {
 382                 TypeElement e = convScalarTypes(eType, tb.eType());
 383                 return new TensorType(e, tb.shape());
 384             } else {
 385                 return convScalarTypes(eType, a);
 386             }
 387         }
 388 
 389         public static TypeElement convScalarTypes(ConstantType eType, TypeElement a) {
 390             TypeElement t = (TypeElement) eType.value();
 391             if (t.equals(Float16.FLOAT_16_TYPE) && a.equals(JavaType.FLOAT)) {
 392                 return Float16.FLOAT_16_TYPE;
 393             } else if (t.equals(a)) {
 394                 return t;
 395             } else {
 396                 // @@@ Conversions;
 397                 throw new IllegalStateException();
 398             }
 399         }
 400 
 401         //          Tensor exp(Tensor a)
 402         public static TypeElement exp(TypeElement a) {
 403             return unary(a);
 404         }
 405 
 406         static TypeElement unary(TypeElement a) {
 407             return a;
 408         }
 409 
 410         //                Tensor compare(Number a, Number b, @Constant CompareKind ck) {
 411         public static TypeElement compare(TypeElement a, TypeElement b, ConstantType kind) {
 412             assert kind.cType().equals(JavaType.type(Triton.CompareKind.class));
 413 
 414             TypeElement t = binary(a, b);
 415             if (t instanceof TensorType tt) {
 416                 return new TensorType(JavaType.BOOLEAN, tt.shape());
 417             } else {
 418                 return t;
 419             }
 420         }
 421 
 422         //                Tensor dot(Tensor a, Tensor b)
 423         public static TensorType dot(TensorType a, TensorType b) {
 424             if (a.shape().size() != 2 || b.shape().size() != 2) {
 425                 throw new IllegalStateException();
 426             }
 427 
 428             if (!a.shape().get(1).equals(b.shape().get(0))) {
 429                 throw new IllegalStateException();
 430             }
 431 
 432             if (a.eType() != b.eType()) {
 433                 // @@@ Conversion, type checking
 434                 throw new IllegalStateException();
 435             }
 436 
 437             // Computed result is tensor of floats, regardless of inputs
 438             return new TensorType(JavaType.FLOAT, List.of(a.shape().get(0), b.shape().get(1)));
 439         }
 440 
 441 
 442         //                Tensor max(Tensor a, @Constant int axis) {
 443         public static TypeElement max(TensorType a, ConstantType axis) {
 444             return reduce(a, axis);
 445         }
 446 
 447         //                Tensor sum(Tensor a, @Constant int axis) {
 448         public static TypeElement sum(TensorType a, ConstantType axis) {
 449             return reduce(a, axis);
 450         }
 451 
 452         static TypeElement reduce(TensorType a, ConstantType axis) {
 453             assert axis.cType().equals(JavaType.INT);
 454             int axisValue = (int) axis.value();
 455             if (axisValue < 0 || axisValue > 3) {
 456                 throw new IllegalStateException();
 457             }
 458 
 459             List<Integer> reduceShape = new ArrayList<>();
 460             for (int i = 0; i < a.shape().size(); i++) {
 461                 if (i != axisValue) {
 462                     reduceShape.add(a.shape().get(i));
 463                 } else {
 464                     reduceShape.add(1);
 465                 }
 466             }
 467 
 468             if (reduceShape.size() == 1 && reduceShape.getFirst() == 1) {
 469                 return a.eType();
 470             } else {
 471                 return new TensorType(a.eType(), reduceShape);
 472             }
 473         }
 474 
 475         // @@@ Test
 476         public static void consume(TypeElement a) {
 477         }
 478 
 479 
 480         static TypeElement binary(TypeElement a, TypeElement b) {
 481             if (a instanceof TensorType ta && b instanceof TensorType tb) {
 482                 return checkTensorTypes(ta, tb);
 483             } else if (a instanceof TensorType ta) {
 484                 return new TensorType(checkScalarTypes(ta.eType(), b), ta.shape());
 485             } else if (b instanceof TensorType tb) {
 486                 return new TensorType(checkScalarTypes(a, tb.eType()), tb.shape());
 487             } else {
 488                 return checkScalarTypes(a, b);
 489             }
 490         }
 491 
 492         static TensorType checkTensorTypes(TensorType a, TensorType b) {
 493             List<Integer> s = checkTensorShape(a, b);
 494             TypeElement e = checkScalarTypes(a.eType(), b.eType());
 495             return new TensorType(e, s);
 496         }
 497 
 498         static List<Integer> checkTensorShape(TensorType a, TensorType b) {
 499             if (a.shape().size() != b.shape().size()) {
 500                 // Shape mismatch
 501                 throw new IllegalStateException();
 502             }
 503 
 504             List<Integer> s = new ArrayList<>();
 505             for (int i = 0; i < a.shape().size(); i++) {
 506                 int ad = a.shape().get(i);
 507                 int bd = b.shape().get(i);
 508 
 509                 // Expand dimensions
 510                 int d;
 511                 if (ad == bd) {
 512                     d = ad;
 513                 } else {
 514                     if (ad != 1 && bd == 1) {
 515                         d = ad;
 516                     } else if (ad == 1) {
 517                         d = bd;
 518                     } else {
 519                         // Shape mismatch
 520                         throw new IllegalStateException();
 521                     }
 522                 }
 523 
 524                 s.add(d);
 525             }
 526 
 527             return s;
 528         }
 529 
 530         static TypeElement checkScalarTypes(TypeElement a, TypeElement b) {
 531             // @@@ Optional ptr checking
 532             if (a instanceof PtrType) {
 533                 if (!b.equals(JavaType.INT)) {
 534                     throw new IllegalStateException();
 535                 }
 536             } else if (b instanceof PtrType) {
 537                 // Pointer must be first argument
 538                 throw new IllegalStateException();
 539             } else if (a instanceof ConstantType || b instanceof ConstantType) {
 540                 return checkScalarTypes(reduceScalarType(a), reduceScalarType(b));
 541             } else if (!a.equals(b)) {
 542                 // @@@ Conversion
 543                 throw new IllegalStateException();
 544             }
 545             return a;
 546         }
 547 
 548         static TypeElement reduceScalarType(TypeElement a) {
 549             return a instanceof ConstantType ct ? ct.cType() : a;
 550         }
 551     }
 552 
 553     public static <O extends Op & Op.Invokable> TritonOps.FuncOp transformToTritonFunction(
 554             O kernel,
 555             String signature,
 556             TypeElement rType,
 557             Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
 558             Map<String, TritonOps.FuncOp> fsymTable) {
 559         TritonOps.FuncOp ttKernel = TritonOps.func(signature, functionType(rType))
 560                 .body(fblock -> {
 561                     // Process kernel parameters
 562                     List<Value> args = new ArrayList<>();
 563                     for (Block.Parameter kp : kernel.body().entryBlock().parameters()) {
 564                         TypeElement type = valueTypeMap.get(kp);
 565                         if (type instanceof ConstantType ct) {
 566                             // Constant
 567                             Op.Result cr = fblock.op(ArithMathOps.constant(
 568                                     ct.cType(), ct.value()));
 569                             args.add(cr);
 570                         } else {
 571                             args.add(fblock.parameter(type));
 572                         }
 573                     }
 574 
 575                     // Transform kernel body
 576                     fblock.body(kernel.body(), args, (kblock, op) -> {
 577                         return transformToTritonOperation(kblock, op, valueTypeMap, opData, fsymTable);
 578                     });
 579                 });
 580 
 581         ttKernel = cleanup(ttKernel);
 582         fsymTable.put(ttKernel.funcName(), ttKernel);
 583         return ttKernel;
 584     }
 585 
 586     static Block.Builder transformToTritonOperation(Block.Builder kblock, Op op,
 587                                                     Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
 588                                                     Map<String, TritonOps.FuncOp> fsymTable) {
 589         // @@@ Avoid constructing for each operation -- block builder passed as argument or a scoped value
 590         TritonBuilderInterpreter tbi = new TritonBuilderInterpreter(fsymTable, kblock);
 591         CodeContext cc = kblock.context();
 592         switch (op) {
 593             case VarOp varOp -> {
 594                 // @@@ Cannot copy op because the result type
 595                 //     is derived from init type
 596                 Value init = cc.getValue(op.operands().get(0));
 597                 Op.Result r = kblock.op(var(varOp.varName(), init));
 598                 cc.mapValue(op.result(), r);
 599             }
 600             case ConstantOp cop -> {
 601                 TypeElement t = valueTypeMap.get(cop.result());
 602                 if (t instanceof ConstantType ct) {
 603                     Op.Result r = kblock.op(ArithMathOps.constant(
 604                             ct.cType(), ct.value()));
 605                     cc.mapValue(op.result(), r);
 606                 } else {
 607                     kblock.op(op);
 608                 }
 609             }
 610             case ArithmeticOperation _ -> {
 611                 Value result = tbi.build(op, op.externalizeOpName(), valueTypeMap);
 612                 if (result != null) {
 613                     cc.mapValue(op.result(), result);
 614                 }
 615             }
 616             case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
 617                 // Replace box with its value
 618                 Value a = cc.getValue(op.operands().get(0));
 619                 cc.mapValue(op.result(), a);
 620             }
 621             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
 622                 String name = iop.invokeDescriptor().name();
 623                 if (name.equals("max")) {
 624                     Value a = cc.getValue(op.operands().get(0));
 625                     Value b = cc.getValue(op.operands().get(1));
 626 
 627                     Op.Result result = kblock.op(ArithMathOps.maximum(a, b));
 628                     cc.mapValue(op.result(), result);
 629                 } else if (name.equals("min")) {
 630                     Value a = cc.getValue(op.operands().get(0));
 631                     Value b = cc.getValue(op.operands().get(1));
 632 
 633                     Op.Result result = kblock.op(ArithMathOps.minimum(a, b));
 634                     cc.mapValue(op.result(), result);
 635                 }
 636             }
 637             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
 638                 if (iop.invokeDescriptor().name().equals("type")) {
 639                     // Replace with constant operation to produce tensor type.
 640                     // Result may be used, but transitively it will be removed due to no uses
 641                     // contributing to the computation
 642                     Value a = op.operands().get(0);
 643                     TensorType aType = (TensorType) valueTypeMap.get(a);
 644                     Op.Result result = kblock.op(CoreOp.constant(iop.resultType(), aType));
 645                     cc.mapValue(op.result(), result);
 646                     valueTypeMap.put(result, aType);
 647                 }
 648                 // Remove
 649             }
 650             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
 651                 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
 652                 if (result != null) {
 653                     cc.mapValue(op.result(), result);
 654                 }
 655             }
 656             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
 657                 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
 658                 if (result != null) {
 659                     cc.mapValue(op.result(), result);
 660                 }
 661             }
 662             case JavaOp.ForOp fop -> {
 663                 transformToSCFFor(cc, kblock, fop, valueTypeMap, opData, fsymTable);
 664             }
 665             case ReturnOp rop -> {
 666                 if (rop.operands().isEmpty()) {
 667                     kblock.op(TritonOps.return_());
 668                 } else {
 669                     kblock.op(TritonOps.return_(
 670                             cc.getValue(rop.returnValue())));
 671                 }
 672             }
 673             default -> kblock.op(op);
 674         }
 675         return kblock;
 676     }
 677 
 678     static void transformToSCFFor(CodeContext cc, Block.Builder kblock, JavaOp.ForOp fop,
 679                                   Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
 680                                   Map<String, TritonOps.FuncOp> fsymTable) {
 681         Body body = fop.loopBody();
 682 
 683         // Hoist expressions for start, end, and step
 684         SimpleCountedForLoopInfo li = (SimpleCountedForLoopInfo) opData.get(fop);
 685         Value start = null;
 686         for (Op o : li.startExpression()) {
 687             transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
 688             start = cc.getValue(o.result());
 689         }
 690         Value end = null;
 691         for (Op o : li.endExpression()) {
 692             transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
 693             end = cc.getValue(o.result());
 694         }
 695         Value step = null;
 696         for (Op o : li.stepExpression()) {
 697             transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
 698             step = cc.getValue(o.result());
 699         }
 700 
 701         // Obtain captured vars
 702         // true == stores
 703         // false == loads only
 704         Map<Boolean, Set<Value>> capturedVars = capturedVars(body);
 705         Set<Value> capturedAndStoredVars = capturedVars.get(true);
 706 
 707         // Get load values
 708         // Loaded values are hoisted out of the loop body
 709         Map<Value, Value> loadValues = new HashMap<>();
 710         for (Value v : capturedVars.get(false)) {
 711             Value load = kblock.op(varLoad(cc.getValue(v)));
 712             valueTypeMap.put(load, valueTypeMap.get(v));
 713             loadValues.put(v, load);
 714         }
 715 
 716         // Get iteration values -- represented by captured vars that are stored to in the loop
 717         // The SCF for operation returns the iteration values of the last loop iteration, which
 718         // are then to be stored to the iteration variables
 719         List<Value> iterValues = new ArrayList<>();
 720         for (Value v : capturedAndStoredVars) {
 721             iterValues.add(kblock.op(varLoad(cc.getValue(v))));
 722         }
 723 
 724         // @@@ Build in java code model, then transform?
 725         SCFOps.ForOp scffor = SCFOps.for_(kblock.parentBody(), start, end, step, iterValues)
 726                 // Ensure existing context is used
 727                 .body(CodeContext.create(cc), builder -> {
 728                     // Create index var initialized from entry block parameter
 729                     Value index = builder.parameters().get(0);
 730                     valueTypeMap.put(index, JavaType.INT);
 731                     Value varIndex = builder.op(var("index", index));
 732                     valueTypeMap.put(varIndex, JavaType.INT);
 733                     builder.context().mapValue(body.entryBlock().parameters().get(0), varIndex);
 734 
 735                     // Create iter vars initialized from entry block parameters
 736                     int pi = 1;
 737                     for (Value v : capturedAndStoredVars) {
 738                         TypeElement type = valueTypeMap.get(v);
 739                         Value iter = builder.parameters().get(pi++);
 740                         valueTypeMap.put(iter, type);
 741                         Value varIter = builder.op(var(Integer.toString(pi), iter));
 742                         valueTypeMap.put(varIter, type);
 743                         builder.context().mapValue(v, varIter);
 744                     }
 745 
 746                     // Transform the Java for body into the SCF for body
 747                     builder.body(body, List.of(), (block, op) -> {
 748                         // Yield iter values
 749                         if (op instanceof JavaOp.ContinueOp) {
 750                             // Replace with yield of loaded vars
 751                             List<Value> yieldValues = new ArrayList<>();
 752                             for (Value value : capturedAndStoredVars) {
 753                                 Value varIter = block.context().getValue(value);
 754                                 Value v = block.op(varLoad(varIter));
 755                                 yieldValues.add(v);
 756                             }
 757                             block.op(SCFOps.yield_(yieldValues));
 758                         } else if (op instanceof VarAccessOp.VarLoadOp) {
 759                             // Replace with value loaded immediately before loop
 760                             Value v = op.operands().get(0);
 761                             if (capturedVars.get(false).contains(v)) {
 762                                 block.context().mapValue(op.result(), loadValues.get(v));
 763                             } else {
 764                                 block.op(op);
 765                             }
 766                         } else {
 767                             block = transformToTritonOperation(block, op, valueTypeMap, opData, fsymTable);
 768                         }
 769                         return block;
 770                     });
 771                 });
 772         Op.Result forResult = kblock.op(scffor);
 773 
 774         // Assign back result to iter vars
 775         if (capturedAndStoredVars.size() == 1) {
 776             for (Value v : capturedAndStoredVars) {
 777                 kblock.op(varStore(cc.getValue(v), forResult));
 778             }
 779         } else {
 780             int i = 0;
 781             for (Value v : capturedAndStoredVars) {
 782                 kblock.op(varStore(cc.getValue(v),
 783                         kblock.op(tupleLoad(forResult, i++))));
 784             }
 785         }
 786     }
 787 
 788     static Map<Boolean, Set<Value>> capturedVars(Body body) {
 789         Map<Boolean, Set<Value>> capturedValues = new HashMap<>();
 790         capturedValues.put(false, new LinkedHashSet<>());
 791         capturedValues.put(true, new LinkedHashSet<>());
 792 
 793         capturedVars(capturedValues, new ArrayDeque<>(), body);
 794         return capturedValues;
 795     }
 796 
 797     static void capturedVars(Map<Boolean, Set<Value>> capturedVars, Deque<Body> bodyStack, Body body) {
 798         bodyStack.push(body);
 799 
 800         for (Block b : body.blocks()) {
 801             for (Op op : b.ops()) {
 802                 // @@@ Nested bodies
 803                 if (!op.bodies().isEmpty()) {
 804                     throw new IllegalStateException();
 805                 }
 806 //                for (Body childBody : op.bodies()) {
 807 //                    capturedAndUpdatedVars(capturedValues, bodyStack, childBody);
 808 //                }
 809 
 810                 if (op instanceof VarAccessOp) {
 811                     Value v = op.operands().get(0);
 812                     if (!bodyStack.contains(v.declaringBlock().ancestorBody())) {
 813                         if (op instanceof VarAccessOp.VarStoreOp) {
 814                             capturedVars.get(true).add(v);
 815                             capturedVars.get(false).remove(v);
 816                         } else if (!capturedVars.get(true).contains(v)) {
 817                             capturedVars.get(false).add(v);
 818                         }
 819                     }
 820                 }
 821             }
 822         }
 823 
 824         bodyStack.pop();
 825     }
 826 
 827     public static final ScopedValue<Boolean> SV_SSA = ScopedValue.newInstance();
 828 
 829     static TritonOps.FuncOp cleanup(TritonOps.FuncOp f) {
 830         // Remove var ops
 831         boolean doSSA = SV_SSA.isBound() ? SV_SSA.get() : true;
 832         if (doSSA) {
 833             f = SSA.transform(f);
 834         }
 835         // Remove unused ops
 836         f = f.transform((fblock, op) -> {
 837             if (op instanceof Op.Pure && op.result().uses().isEmpty()) {
 838                 return fblock;
 839             } else if (op instanceof VarAccessOp.VarLoadOp && op.result().uses().isEmpty()) {
 840                 return fblock;
 841             }
 842 
 843             fblock.op(op);
 844             return fblock;
 845         });
 846         return f;
 847     }
 848 
 849     static class TritonBuilderInterpreter {
 850         final Map<String, TritonOps.FuncOp> fsymTable;
 851         final Block.Builder block;
 852 
 853         TritonBuilderInterpreter(Map<String, TritonOps.FuncOp> fsymTable, Block.Builder block) {
 854             this.fsymTable = fsymTable;
 855             this.block = block;
 856         }
 857 
 858         Value build(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
 859             // Obtain associated type-based method
 860             MethodHandle mh;
 861             try {
 862                 Optional<Method> om = Stream.of(TritonBuilderInterpreter.class.getDeclaredMethods())
 863                         .filter(m -> m.getName().equals(name))
 864                         .filter(m -> m.isVarArgs()
 865                         ? m.getParameterCount() / 2 - 1 <= op.operands().size()
 866                         : m.getParameterCount() / 2 - 1 == op.operands().size())
 867                         .findFirst();
 868                 mh = MethodHandles.lookup().unreflect(
 869                         om.orElseThrow(() -> new NoSuchMethodException(name)));
 870             } catch (ReflectiveOperationException e) {
 871                 throw new IllegalStateException(e);
 872             }
 873 
 874             List<Object> iArgs = new ArrayList<>();
 875             iArgs.add(this);
 876             iArgs.add(valueTypeMap.get(op.result()));
 877             iArgs.add(op.result());
 878             for (Value o : op.operands()) {
 879                 iArgs.add(valueTypeMap.get(o));
 880                 iArgs.add(o);
 881             }
 882             try {
 883                 return (Value) mh.invokeWithArguments(iArgs.toArray(Object[]::new));
 884             } catch (Throwable e) {
 885                 throw new IllegalStateException(e);
 886             }
 887         }
 888 
 889 
 890         public Value programId(TypeElement rType, Op.Result r,
 891                                ConstantType axisType, Value axis) {
 892             return block.op(TritonOps.getProgramId(
 893                     (int) axisType.value()));
 894         }
 895 
 896         public Value arange(TensorType rType, Op.Result r,
 897                             ConstantType startType, Value start,
 898                             ConstantType endType, Value end) {
 899             return block.op(TritonOps.makeRange(
 900                     (int) startType.value(),
 901                     (int) endType.value()));
 902         }
 903 
 904         public Value expand(TensorType rType, Op.Result r,
 905                             TensorType aType, Value a,
 906                             ConstantType axisType, Value axis) {
 907             return block.op(TritonOps.expand(
 908                     (int) axisType.value(),
 909                     rType,
 910                     block.context().getValue(a)));
 911         }
 912 
 913         public Value zeros(TensorType rType, Op.Result r,
 914                            ConstantType aType, Value a,
 915                            Object... constantsAndValues) {
 916             Object zero;
 917             try {
 918                 JavaType zeroType = (JavaType) aType.value();
 919                 zero = MethodHandles.zero((Class<?>) zeroType.resolve(MethodHandles.lookup())).invoke();
 920             } catch (Throwable e) {
 921                 throw new RuntimeException(e);
 922             }
 923             return block.op(ArithMathOps.constant(rType, zero));
 924         }
 925 
 926         public Value load(TensorType rType, Op.Result r,
 927                           TensorType ptrType, Value ptr,
 928                           TensorType maskType, Value mask) {
 929             broadcastConversionRight(ptrType, maskType, mask);
 930             return block.op(TritonOps.load(
 931                     rType,
 932                     block.context().getValue(ptr),
 933                     block.context().getValue(mask)));
 934         }
 935 
 936         public Value load(TensorType rType, Op.Result r,
 937                           TensorType ptrType, Value ptr,
 938                           TensorType maskType, Value mask,
 939                           ConstantType otherType, Value other) {
 940             broadcastConversionRight(ptrType, maskType, mask);
 941             Value mb = block.op(ArithMathOps.constant(rType, (float)otherType.value()));
 942             block.context().mapValue(other, mb);
 943             return block.op(TritonOps.load(
 944                     rType,
 945                     block.context().getValue(ptr),
 946                     block.context().getValue(mask),
 947                     block.context().getValue(other)));
 948         }
 949 
 950         public Value store(TensorType rType, Op.Result r,
 951                            TensorType ptrType, Value ptr,
 952                            TensorType valueType, Value value,
 953                            TensorType maskType, Value mask) {
 954             broadcastConversionRight(ptrType, valueType, value);
 955             broadcastConversionRight(ptrType, maskType, mask);
 956             return block.op(TritonOps.store(
 957                     block.context().getValue(ptr),
 958                     block.context().getValue(value),
 959                     block.context().getValue(mask)));
 960         }
 961 
 962         public Value broadcast(TensorType rType, Op.Result r,
 963                                TypeElement oType, Value o,
 964                                TensorType tensorTypeType, Value tensorType) {
 965             // @@@ tt.splat with scalar operand, tt.broadcast with tensor operand
 966             if (oType instanceof TensorType) {
 967                 return block.op(TritonOps.broadcast(
 968                         rType,
 969                         block.context().getValue(o)));
 970             } else {
 971                 return block.op(TritonOps.splat(
 972                         rType,
 973                         block.context().getValue(o)));
 974             }
 975         }
 976 
 977         public Value joinShape(TensorType rType, Op.Result r,
 978                                TensorType aType, Value a,
 979                                TensorType bType, Value b) {
 980             // Replace with constant operation to produce tensor type.
 981             // Result may be used, but transitively it will be removed due to no uses
 982             // contributing to the computation
 983             return block.op(CoreOp.constant(JavaType.type(TensorType.class), r.type()));
 984         }
 985 
 986 
 987         public Value add(TypeElement rType, Op.Result r,
 988                          TypeElement aType, Value a,
 989                          TypeElement bType, Value b) {
 990             broadcastConversion(rType, aType, a, bType, b);
 991             a = block.context().getValue(a);
 992             b = block.context().getValue(b);
 993 
 994             if (rType instanceof PtrType ||
 995                     rType instanceof TensorType t && t.eType() instanceof PtrType) {
 996                 return block.op(TritonOps.addptr(a, b));
 997             } else {
 998                 return block.op(ArithMathOps.add(a, b));
 999             }
1000         }
1001 
1002         public Value sub(TypeElement rType, Op.Result r,
1003                          TypeElement aType, Value a,
1004                          TypeElement bType, Value b) {
1005             broadcastConversion(rType, aType, a, bType, b);
1006             a = block.context().getValue(a);
1007             b = block.context().getValue(b);
1008 
1009             return block.op(ArithMathOps.sub(a, b));
1010         }
1011 
1012         public Value mul(TypeElement rType, Op.Result r,
1013                          TypeElement aType, Value a,
1014                          TypeElement bType, Value b) {
1015             broadcastConversion(rType, aType, a, bType, b);
1016             a = block.context().getValue(a);
1017             b = block.context().getValue(b);
1018 
1019             return block.op(ArithMathOps.mul(a, b));
1020         }
1021 
1022         public Value div(TypeElement rType, Op.Result r,
1023                          TypeElement aType, Value a,
1024                          TypeElement bType, Value b) {
1025             broadcastConversion(rType, aType, a, bType, b);
1026             a = block.context().getValue(a);
1027             b = block.context().getValue(b);
1028 
1029             return block.op(ArithMathOps.div(a, b));
1030         }
1031 
1032         public Value mod(TypeElement rType, Op.Result r,
1033                          TypeElement aType, Value a,
1034                          TypeElement bType, Value b) {
1035             broadcastConversion(rType, aType, a, bType, b);
1036             a = block.context().getValue(a);
1037             b = block.context().getValue(b);
1038 
1039             return block.op(ArithMathOps.rem(a, b));
1040         }
1041 
1042         public Value and(TypeElement rType, Op.Result r,
1043                          TypeElement aType, Value a,
1044                          TypeElement bType, Value b) {
1045             broadcastConversion(rType, aType, a, bType, b);
1046             a = block.context().getValue(a);
1047             b = block.context().getValue(b);
1048 
1049             return block.op(ArithMathOps.and(a, b));
1050         }
1051 
1052         public Value dot(TensorType rType, Op.Result r,
1053                          TypeElement aType, Value a,
1054                          TypeElement bType, Value b) {
1055             a = block.context().getValue(a);
1056             b = block.context().getValue(b);
1057             // Computed result is tensor of floats, regardless of inputs
1058             Object zero = 0.0f;
1059             var c = block.op(ArithMathOps.constant(rType, zero));
1060             return block.op(TritonOps.dot(rType, a, b, c));
1061         }
1062 
1063         public Value cdiv(TypeElement rType, Op.Result r,
1064                           TypeElement aType, Value a,
1065                           TypeElement bType, Value b) {
1066             a = block.context().getValue(a);
1067             b = block.context().getValue(b);
1068 
1069             TritonOps.FuncOp cdiv = tritonFunction(Functions.getJavaCodeModel("cdiv"),
1070                     rType, List.of(aType, bType),
1071                     fsymTable);
1072             // @@@ Generalize
1073             List<Value> args = new ArrayList<>();
1074             if (!(aType instanceof ConstantType)) {
1075                 args.add(a);
1076             }
1077             if (!(bType instanceof ConstantType)) {
1078                 args.add(b);
1079             }
1080             return block.op(TritonOps.call(cdiv, args));
1081         }
1082 
1083         public Value conv(TypeElement rType, Op.Result r,
1084                           ConstantType tType, Value t,
1085                           TypeElement aType, Value a) {
1086             a = block.context().getValue(a);
1087 
1088             TypeElement rScalarType;
1089             TypeElement aScalarType;
1090             if (rType instanceof TensorType rTensorType && aType instanceof TensorType aTensorType) {
1091                 rScalarType = rTensorType.eType();
1092                 aScalarType = aTensorType.eType();
1093             } else {
1094                 rScalarType = rType;
1095                 aScalarType = aType;
1096             }
1097 
1098             if (rScalarType.equals(Float16.FLOAT_16_TYPE) && aScalarType.equals(JavaType.FLOAT)) {
1099                 return block.op(ArithMathOps.trunc(rType, a));
1100             } else if (rType.equals(aType)) {
1101                 return a;
1102             } else {
1103                 throw new IllegalStateException();
1104             }
1105         }
1106 
1107         public Value exp(TritonType rType, Op.Result r,
1108                          TritonType aType, Value a) {
1109             return block.op(ArithMathOps.exp(
1110                     block.context().getValue(a)));
1111         }
1112 
1113         public Value compare(TensorType rType, Op.Result r,
1114                              TypeElement aType, Value a,
1115                              TypeElement bType, Value b,
1116                              ConstantType compareType, Value compare) {
1117             Triton.CompareKind ck = (Triton.CompareKind) compareType.value();
1118 
1119             ArithMathOps.CompareOp.CompareKind ack = switch (ck) {
1120                 case LessThan -> ArithMathOps.CompareOp.CompareKind.slt;
1121                 default -> throw new UnsupportedOperationException("Unsupported comparison: " + ck);
1122             };
1123 
1124             broadcastConversionRight(aType, bType, b);
1125             a = block.context().getValue(a);
1126             b = block.context().getValue(b);
1127 
1128             return block.op(ArithMathOps.cmp(ack, a, b));
1129         }
1130 
1131 
1132         public Value max(TypeElement rType, Op.Result r,
1133                          TensorType xType, Value x,
1134                          ConstantType axisType, Value axis) {
1135             TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("max"),
1136                     rType, List.of(rType, rType), fsymTable);
1137             return reduce(rType, r, xType, x, axisType, axis, f);
1138         }
1139 
1140         public Value sum(TypeElement rType, Op.Result r,
1141                          TensorType xType, Value x,
1142                          ConstantType axisType, Value axis) {
1143             TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("sum"),
1144                     rType, List.of(rType, rType), fsymTable);
1145             return reduce(rType, r, xType, x, axisType, axis, f);
1146         }
1147 
1148         Value reduce(TypeElement rType, Op.Result r,
1149                      TensorType xType, Value x,
1150                      ConstantType axisType, Value axis,
1151                      TritonOps.FuncOp f) {
1152             int axisConstant = (int) axisType.value();
1153 
1154             String signature = "reduce_" + f.funcName() + "_" + axisConstant;
1155             TritonOps.FuncOp rf = fsymTable.computeIfAbsent(signature,
1156                     s -> reduce(rType, xType, axisConstant, s, f));
1157 
1158             return block.op(TritonOps.call(rf, block.context().getValue(x)));
1159         }
1160 
1161         static TritonOps.FuncOp reduce(TypeElement elementType,
1162                                        TensorType tensorType,
1163                                        int axisConstant,
1164                                        String name, TritonOps.FuncOp scalarFunc) {
1165             return TritonOps.func(name,
1166                             functionType(elementType, tensorType))
1167                     .body(fblock -> {
1168                         TritonOps.ReduceOp reduceOp = TritonOps.reduce(fblock.parentBody(),
1169                                         axisConstant, fblock.parameters().get(0),
1170                                         functionType(elementType, elementType, elementType))
1171                                 .body(rblock -> {
1172                                     Block.Parameter a = rblock.parameters().get(0);
1173                                     Block.Parameter b = rblock.parameters().get(1);
1174                                     Op.Result _r = rblock.op(TritonOps.call(scalarFunc, a, b));
1175                                     rblock.op(TritonOps.reduceReturn(_r));
1176                                 });
1177 
1178                         Op.Result opr = fblock.op(reduceOp);
1179                         fblock.op(TritonOps.return_(opr));
1180                     });
1181         }
1182 
1183         // @@@ Test
1184         public Value consume(TypeElement rType, Op.Result r,
1185                              TypeElement aType, Value a) {
1186             return block.op(TritonTestOps.consume(block.context().getValue(a)));
1187         }
1188 
1189         void broadcastConversion(TypeElement rType,
1190                                  TypeElement aType, Value a,
1191                                  TypeElement bType, Value b) {
1192             Value ma = block.context().getValue(a);
1193             Value mb = block.context().getValue(b);
1194             if (aType instanceof TensorType at && bType instanceof TensorType bTensorType) {
1195                 TensorType rTensorType = (TensorType) rType;
1196                 if (!at.shape().equals(rTensorType.shape())) {
1197                     ma = block.op(TritonOps.broadcast(rTensorType, ma));
1198                 }
1199                 if (!bTensorType.shape().equals(rTensorType.shape())) {
1200                     if (rTensorType.eType() instanceof PtrType) {
1201                         bTensorType = new TensorType(bType, rTensorType.shape());
1202                         mb = block.op(TritonOps.broadcast(bTensorType, mb));
1203                     } else {
1204                         mb = block.op(TritonOps.broadcast(rTensorType, mb));
1205                     }
1206                 }
1207             } else if (aType instanceof TensorType) {
1208                 TensorType rTensorType = (TensorType) rType;
1209                 if (rTensorType.eType() instanceof PtrType) {
1210                     TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1211                     mb = block.op(TritonOps.splat(bTensorType, mb));
1212                 } else {
1213                     mb = block.op(TritonOps.splat(rTensorType, mb));
1214                 }
1215             } else if (bType instanceof TensorType) {
1216                 TensorType rTensorType = (TensorType) rType;
1217                 ma = block.op(TritonOps.splat(rTensorType, ma));
1218             }
1219             block.context().mapValue(a, ma);
1220             block.context().mapValue(b, mb);
1221         }
1222 
1223         void broadcastConversionRight(TypeElement aType,
1224                                       TypeElement bType, Value b) {
1225             Value mb = block.context().getValue(b);
1226             if (aType instanceof TensorType aTensorType && bType instanceof TensorType bTensorType) {
1227                 if (!bTensorType.shape().equals(aTensorType.shape())) {
1228                     if (aTensorType.eType() instanceof PtrType) {
1229                         bTensorType = new TensorType(bTensorType.eType(), aTensorType.shape());
1230                         mb = block.op(TritonOps.broadcast(bTensorType, mb));
1231                     } else {
1232                         mb = block.op(TritonOps.broadcast(aTensorType, mb));
1233                     }
1234                 }
1235             } else if (aType instanceof TensorType rTensorType) {
1236                 if (rTensorType.eType() instanceof PtrType) {
1237                     TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1238                     mb = block.op(TritonOps.splat(bTensorType, mb));
1239                 } else {
1240                     mb = block.op(TritonOps.splat(rTensorType, mb));
1241                 }
1242             }
1243             block.context().mapValue(b, mb);
1244         }
1245     }
1246 
1247     public static <O extends Op & Op.Invokable> void printTypeMap(
1248             O kernel, Map<Value, TypeElement> valueTypeMap) {
1249         AtomicInteger valueId = new AtomicInteger();
1250         Map<Value, Integer> valueIdMap = new LinkedHashMap<>();
1251         kernel.elements().forEach(codeElement -> {
1252             switch (codeElement) {
1253                 case FuncOp _ -> {
1254                     // Ignore
1255                 }
1256                 case Op op when !op.result().type().equals(JavaType.VOID) -> {
1257                     valueIdMap.put(op.result(), valueId.getAndIncrement());
1258                 }
1259                 case Block block -> {
1260                     for (Block.Parameter parameter : block.parameters()) {
1261                         valueIdMap.put(parameter, valueId.getAndIncrement());
1262                     }
1263                 }
1264                 default -> {
1265                 }
1266             }
1267         });
1268 
1269         valueIdMap.forEach((value, id) -> {
1270             TypeElement type = valueTypeMap.get(value);
1271             if (type != null) {
1272                 System.out.println("%" + id + " : " + value.type() + " -> " + type);
1273             }
1274         });
1275     }
1276 }