1 /*
   2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package oracle.code.triton;
  27 
  28 import java.lang.invoke.MethodHandle;
  29 import java.lang.invoke.MethodHandles;
  30 import java.lang.reflect.Field;
  31 import java.lang.reflect.Method;
  32 import java.lang.reflect.code.*;
  33 import java.lang.reflect.code.analysis.SSA;
  34 import java.lang.reflect.code.op.CoreOp;
  35 import java.lang.reflect.code.op.ExtendedOp;
  36 import java.lang.reflect.code.type.JavaType;
  37 import java.lang.reflect.code.type.VarType;
  38 import java.util.*;
  39 import java.util.concurrent.atomic.AtomicInteger;
  40 import java.util.stream.Stream;
  41 
  42 import static java.lang.reflect.code.op.CoreOp.*;
  43 import static java.lang.reflect.code.type.FunctionType.functionType;
  44 
  45 public final class TritonTransformer {
    // Private constructor: prevents instantiation of this final class.
    private TritonTransformer() {}

    // Java type of the Triton API class; invocations whose descriptor refers to it are
    // type checked and transformed by method name.
    static final JavaType TYPE_Triton = JavaType.type(Triton.class);

    // Java type of TritonTest, created from its name rather than from a class literal
    // (presumably because the test class is not visible from this source -- TODO confirm).
    static final JavaType TYPE_Triton_Test = JavaType.ofString("oracle.code.triton.TritonTest");

    // Java type of Tensor; used to recognize invocations on Tensor, of which only "type" is type checked.
    static final JavaType TYPE_Tensor = JavaType.type(Tensor.class);

    // Java type of java.lang.Math; used to recognize invocations of Math.max and Math.min.
    static final JavaType TYPE_J_L_MATH = JavaType.type(Math.class);
  55 
  56     public static <O extends Op & Op.Invokable>
  57     TritonOps.ModuleOp tritonModule(O kernel,
  58                                     TypeElement rType,
  59                                     List<? extends TypeElement> argTypes) {
  60         Map<String, TritonOps.FuncOp> fsymTable = new LinkedHashMap<>();
  61         tritonFunction(kernel, rType, argTypes, fsymTable);
  62         return TritonOps.module(fsymTable.values().stream().toList());
  63     }
  64 
  65     public static <O extends Op & Op.Invokable>
  66     TritonOps.FuncOp tritonFunction(O javaKernel,
  67                                     TypeElement rType,
  68                                     List<? extends TypeElement> argTypes,
  69                                     Map<String, TritonOps.FuncOp> fsymTable) {
  70         String name = (javaKernel instanceof FuncOp f) ? f.funcName() : "kernel";
  71         String signature = signature(name, rType, argTypes);
  72         if (fsymTable.containsKey(signature)) {
  73             return fsymTable.get(signature);
  74         }
  75 
  76         System.out.println(javaKernel.toText());
  77 
  78         Map<Value, TypeElement> valueTypeMap = new HashMap<>();
  79         Map<Op, Object> opData = new HashMap<>();
  80         TritonTransformer.typeCheckKernel(javaKernel, argTypes, valueTypeMap, opData);
  81         TritonTransformer.printTypeMap(javaKernel, valueTypeMap);
  82 
  83         return TritonTransformer.transformToTritonFunction(javaKernel, signature,
  84                 rType, valueTypeMap, opData,
  85                 fsymTable);
  86     }
  87 
  88     static String signature(String name, TypeElement rType, List<? extends TypeElement> argTypes) {
  89         StringBuilder sb = new StringBuilder(name);
  90 
  91         for (TypeElement argType : argTypes) {
  92             sb.append("_");
  93             if (argType instanceof ConstantType ct) {
  94                 sb.append(ct.value());
  95             } else {
  96                 sb.append(argType);
  97             }
  98         }
  99         sb.append("_");
 100         sb.append(rType);
 101         return sb.toString();
 102     }
 103 
 104     public static <O extends Op & Op.Invokable> void typeCheckKernel(
 105             O kernel, List<? extends TypeElement> argTypes,
 106             Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData) {
 107         kernel.traverse(null, CodeElement.opVisitor((o, op) -> {
 108             switch (op) {
 109                 case Op.Invokable fop -> {
 110                     List<Block.Parameter> parameters = fop.body().entryBlock().parameters();
 111                     for (int i = 0; i < parameters.size(); i++) {
 112                         valueTypeMap.put(parameters.get(i), argTypes.get(i));
 113                     }
 114                 }
 115                 case VarOp _, VarAccessOp.VarLoadOp _ -> {
 116                     Value init = op.operands().get(0);
 117                     valueTypeMap.put(op.result(), valueTypeMap.get(init));
 118                 }
 119                 case VarAccessOp.VarStoreOp _ -> {
 120                     Value var = op.operands().get(0);
 121                     TypeElement varType = valueTypeMap.get(var);
 122                     Value v = op.operands().get(1);
 123                     TypeElement vType = valueTypeMap.get(v);
 124                     if (!varType.equals(vType)) {
 125                         throw new IllegalStateException("Storing to variable with different type: "
 126                                 + varType + " <- " + vType);
 127                     }
 128 
 129                     valueTypeMap.put(op.result(), valueTypeMap.get(var));
 130                 }
 131                 case ConstantOp cop -> {
 132                     valueTypeMap.put(op.result(), new ConstantType(op.result().type(), cop.value()));
 133                 }
 134                 case ArithmeticOperation _ -> {
 135                     TypeElement t = checkWithTypeInterpreter(op, op.opName(), valueTypeMap);
 136                     valueTypeMap.put(op.result(), t);
 137                 }
 138                 case FieldAccessOp.FieldLoadOp flop -> {
 139                     if (!flop.operands().isEmpty()) {
 140                         throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor());
 141                     }
 142 
 143                     Field f;
 144                     try {
 145                         f = flop.fieldDescriptor().resolveToMember(MethodHandles.lookup());
 146                     } catch (ReflectiveOperationException e) {
 147                         throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor(), e);
 148                     }
 149                     Object value;
 150                     try {
 151                         value = f.get(null);
 152                     } catch (IllegalAccessException e) {
 153                         throw new IllegalStateException("Unsupported field load: " + f, e);
 154                     }
 155                     valueTypeMap.put(op.result(), new ConstantType(JavaType.type(f.getType()), value));
 156                 }
 157                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
 158                     // Box
 159                     if (iop.invokeDescriptor().name().equals("valueOf")) {
 160                         Value a = op.operands().get(0);
 161                         valueTypeMap.put(op.result(), valueTypeMap.get(a));
 162                     } else {
 163                         throw new UnsupportedOperationException("Unsupported invocation on Integer: " + iop.invokeDescriptor());
 164                     }
 165                 }
 166                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
 167                     String name = iop.invokeDescriptor().name();
 168                     if (name.equals("max") || name.equals("min")) {
 169                         Value a = op.operands().get(0);
 170                         valueTypeMap.put(op.result(), valueTypeMap.get(a));
 171                     } else {
 172                         throw new UnsupportedOperationException("Unsupported invocation on Math: " + iop.invokeDescriptor());
 173                     }
 174                 }
 175                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
 176                     if (iop.invokeDescriptor().name().equals("type")) {
 177                         Value a = op.operands().get(0);
 178                         valueTypeMap.put(op.result(), valueTypeMap.get(a));
 179                     } else {
 180                         throw new UnsupportedOperationException("Unsupported invocation on Tensor: " + iop.invokeDescriptor());
 181                     }
 182                 }
 183                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
 184                     TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
 185                     valueTypeMap.put(op.result(), t);
 186                 }
 187                 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
 188                     TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
 189                     valueTypeMap.put(op.result(), t);
 190                 }
 191                 case ExtendedOp.JavaForOp fop -> {
 192                     SimpleCountedForLoopInfo li = new SimpleCountedForLoopInfo(fop);
 193                     opData.put(fop, li);
 194 
 195                     TypeElement type = fop.init().yieldType();
 196                     if (type instanceof VarType vt && vt.valueType().equals(JavaType.INT)) {
 197                         for (Body b : List.of(fop.cond(), fop.update(), fop.loopBody())) {
 198                             valueTypeMap.put(b.entryBlock().parameters().get(0), JavaType.INT);
 199                         }
 200                     } else {
 201                         throw new IllegalStateException();
 202                     }
 203                 }
 204                 case TestOperation _ -> {
 205                 }
 206                 case ExtendedOp.JavaContinueOp _ -> {
 207                 }
 208                 case YieldOp _ -> {
 209                 }
 210                 case ReturnOp _ -> {
 211                 }
 212                 default -> throw new UnsupportedOperationException("Unsupported operation: " + op);
 213             }
 214 
 215             return null;
 216         }));
 217     }
 218 
 219     static TypeElement checkWithTypeInterpreter(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
 220         // Obtain associated type-based method
 221         MethodHandle mh;
 222         try {
 223             Optional<Method> om = Stream.of(TritonTypeInterpreter.class.getDeclaredMethods())
 224                     .filter(m -> m.getName().equals(name))
 225                     .findFirst();
 226             mh = MethodHandles.lookup().unreflect(
 227                     om.orElseThrow(() -> new NoSuchMethodException(name)));
 228         } catch (ReflectiveOperationException e) {
 229             throw new IllegalStateException(name, e);
 230         }
 231 
 232         // Invoke with the values' types
 233         List<TypeElement> operandTypes = op.operands().stream().map(valueTypeMap::get).toList();
 234         try {
 235             return (TypeElement) mh.invokeWithArguments(operandTypes.toArray(Object[]::new));
 236         } catch (Throwable e) {
 237             throw new IllegalStateException(mh.toString(), e);
 238         }
 239     }
 240 
 241     // @@@ type check tensor shapes
 242     static class TritonTypeInterpreter {
 243         private TritonTypeInterpreter() {
 244         }
 245 
 246         //                 int programId(@Constant int axis) {
 247         public static JavaType programId(ConstantType axis) {
 248             assert axis.cType().equals(JavaType.INT);
 249             int axisValue = (int) axis.value();
 250             if (axisValue < 0 || axisValue > 3) {
 251                 throw new IllegalStateException();
 252             }
 253 
 254             return JavaType.INT;
 255         }
 256 
 257         //                Tensor arange(@Constant int start, @Constant int end)
 258         public static TensorType arange(ConstantType start, ConstantType end) {
 259             assert start.cType().equals(JavaType.INT);
 260             assert end.cType().equals(JavaType.INT);
 261 
 262             int startValue = (int) start.value();
 263             int endValue = (int) end.value();
 264 
 265             return new TensorType(JavaType.INT, List.of(endValue - startValue));
 266         }
 267 
 268         //                Tensor expand(Tensor a, int axis) {
 269         public static TensorType expand(TensorType a, ConstantType axis) {
 270             assert axis.cType().equals(JavaType.INT);
 271             int axisValue = (int) axis.value();
 272 
 273             List<Integer> s = new ArrayList<>(a.shape());
 274             if (axisValue < s.size()) {
 275                 s.add(axisValue, 1);
 276             } else {
 277                 for (int i = 0; i <= (axisValue - s.size()); i++) {
 278                     s.add(1);
 279                 }
 280             }
 281             return new TensorType(a.eType(), s);
 282         }
 283 
 284         //                Tensor load(Tensor ptr, Tensor mask)
 285         public static TensorType load(TensorType ptr, TensorType mask) {
 286             checkTensorShape(ptr, mask);
 287             if (ptr.eType() instanceof PtrType eptr) {
 288                 return new TensorType(eptr.rType(), ptr.shape());
 289             }
 290 
 291             throw new IllegalStateException();
 292         }
 293 
 294         //            void store(Tensor ptr, Tensor value, Tensor mask)
 295         public static void store(TensorType ptr, TensorType value, TensorType mask) {
 296             if (!(ptr.eType() instanceof PtrType)) {
 297                 throw new IllegalStateException();
 298             }
 299         }
 300 
 301         //                Tensor zeros(TensorType type)
 302         public static TensorType zeros(ConstantType eType, ConstantType... cShape) {
 303             List<Integer> shape = Stream.of(cShape).map(s -> (int) s.value()).toList();
 304             return new TensorType((TypeElement) eType.value(), shape);
 305         }
 306 
 307         //                Tensor broadcast(Object o, TensorType type)
 308         public static TensorType broadcast(TypeElement o, TensorType type) {
 309             if (o instanceof TensorType ot) {
 310                 // @@@
 311                 if (ot.shape().size() != type.shape().size()) {
 312                     throw new IllegalStateException();
 313                 }
 314                 o = ot.eType();
 315             } if (o instanceof ConstantType oc) {
 316                 o = oc.cType();
 317             }
 318             return new TensorType(o, type.shape());
 319         }
 320 
 321         public static TensorType joinShape(TensorType a, TensorType b) {
 322             return checkTensorTypes(a, b);
 323         }
 324 
 325         //          Tensor add(Number a, Number b)
 326         //             Ptr add(Ptr a, int offset)
 327         public static TypeElement add(TypeElement a, TypeElement b) {
 328             // @@@ Pass additional argument for checking ptr
 329             return binary(a, b);
 330         }
 331 
 332         public static TypeElement sub(TypeElement a, TypeElement b) {
 333             return binary(a, b);
 334         }
 335 
 336         public static TypeElement mul(TypeElement a, TypeElement b) {
 337             return binary(a, b);
 338         }
 339 
 340         public static TypeElement div(TypeElement a, TypeElement b) {
 341             return binary(a, b);
 342         }
 343 
 344         public static TypeElement mod(TypeElement a, TypeElement b) {
 345             return binary(a, b);
 346         }
 347 
 348         public static TypeElement and(TypeElement a, TypeElement b) {
 349             return binary(a, b);
 350         }
 351 
 352         public static TypeElement cdiv(TypeElement a, TypeElement b) {
 353             a = reduceScalarType(a);
 354             b = reduceScalarType(b);
 355             if (!a.equals(JavaType.INT) && !b.equals(JavaType.INT)) {
 356                 throw new IllegalStateException();
 357             }
 358             return a;
 359         }
 360 
 361         //          Number conv(Type t, Number a) {
 362         public static TypeElement conv(ConstantType eType, TypeElement a) {
 363             return convTypes(eType, a);
 364         }
 365 
 366         public static TypeElement convTypes(ConstantType eType, TypeElement a) {
 367             if (a instanceof TensorType tb) {
 368                 TypeElement e = convScalarTypes(eType, tb.eType());
 369                 return new TensorType(e, tb.shape());
 370             } else {
 371                 return convScalarTypes(eType, a);
 372             }
 373         }
 374 
 375         public static TypeElement convScalarTypes(ConstantType eType, TypeElement a) {
 376             TypeElement t = (TypeElement) eType.value();
 377             if (t.equals(Float16.FLOAT_16_TYPE) && a.equals(JavaType.FLOAT)) {
 378                 return Float16.FLOAT_16_TYPE;
 379             } else if (t.equals(a)) {
 380                 return t;
 381             } else {
 382                 // @@@ Conversions;
 383                 throw new IllegalStateException();
 384             }
 385         }
 386 
 387         //          Tensor exp(Tensor a)
 388         public static TypeElement exp(TypeElement a) {
 389             return unary(a);
 390         }
 391 
 392         static TypeElement unary(TypeElement a) {
 393             return a;
 394         }
 395 
 396         //                Tensor compare(Number a, Number b, @Constant CompareKind ck) {
 397         public static TypeElement compare(TypeElement a, TypeElement b, ConstantType kind) {
 398             assert kind.cType().equals(JavaType.type(Triton.CompareKind.class));
 399 
 400             return binary(a, b);
 401         }
 402 
 403         //                Tensor dot(Tensor a, Tensor b)
 404         public static TensorType dot(TensorType a, TensorType b) {
 405             if (a.shape().size() != 2 || b.shape().size() != 2) {
 406                 throw new IllegalStateException();
 407             }
 408 
 409             if (!a.shape().get(1).equals(b.shape().get(0))) {
 410                 throw new IllegalStateException();
 411             }
 412 
 413             if (a.eType() != b.eType()) {
 414                 // @@@ Conversion, type checking
 415                 throw new IllegalStateException();
 416             }
 417 
 418             // Computed result is tensor of floats, regardless of inputs
 419             return new TensorType(JavaType.FLOAT, List.of(a.shape().get(0), b.shape().get(1)));
 420         }
 421 
 422 
 423         //                Tensor max(Tensor a, @Constant int axis) {
 424         public static TypeElement max(TensorType a, ConstantType axis) {
 425             return reduce(a, axis);
 426         }
 427 
 428         //                Tensor sum(Tensor a, @Constant int axis) {
 429         public static TypeElement sum(TensorType a, ConstantType axis) {
 430             return reduce(a, axis);
 431         }
 432 
 433         static TypeElement reduce(TensorType a, ConstantType axis) {
 434             assert axis.cType().equals(JavaType.INT);
 435             int axisValue = (int) axis.value();
 436             if (axisValue < 0 || axisValue > 3) {
 437                 throw new IllegalStateException();
 438             }
 439 
 440             List<Integer> reduceShape = new ArrayList<>();
 441             for (int i = 0; i < a.shape().size(); i++) {
 442                 if (i != axisValue) {
 443                     reduceShape.add(a.shape().get(i));
 444                 } else {
 445                     reduceShape.add(1);
 446                 }
 447             }
 448 
 449             if (reduceShape.size() == 1 && reduceShape.getFirst() == 1) {
 450                 return a.eType();
 451             } else {
 452                 return new TensorType(a.eType(), reduceShape);
 453             }
 454         }
 455 
 456         // @@@ Test
 457         public static void consume(TypeElement a) {
 458         }
 459 
 460 
 461         static TypeElement binary(TypeElement a, TypeElement b) {
 462             if (a instanceof TensorType ta && b instanceof TensorType tb) {
 463                 return checkTensorTypes(ta, tb);
 464             } else if (a instanceof TensorType ta) {
 465                 return new TensorType(checkScalarTypes(ta.eType(), b), ta.shape());
 466             } else if (b instanceof TensorType tb) {
 467                 return new TensorType(checkScalarTypes(a, tb.eType()), tb.shape());
 468             } else {
 469                 return checkScalarTypes(a, b);
 470             }
 471         }
 472 
 473         static TensorType checkTensorTypes(TensorType a, TensorType b) {
 474             List<Integer> s = checkTensorShape(a, b);
 475             TypeElement e = checkScalarTypes(a.eType(), b.eType());
 476             return new TensorType(e, s);
 477         }
 478 
 479         static List<Integer> checkTensorShape(TensorType a, TensorType b) {
 480             if (a.shape().size() != b.shape().size()) {
 481                 // Shape mismatch
 482                 throw new IllegalStateException();
 483             }
 484 
 485             List<Integer> s = new ArrayList<>();
 486             for (int i = 0; i < a.shape().size(); i++) {
 487                 int ad = a.shape().get(i);
 488                 int bd = b.shape().get(i);
 489 
 490                 // Expand dimensions
 491                 int d;
 492                 if (ad == bd) {
 493                     d = ad;
 494                 } else {
 495                     if (ad != 1 && bd == 1) {
 496                         d = ad;
 497                     } else if (ad == 1) {
 498                         d = bd;
 499                     } else {
 500                         // Shape mismatch
 501                         throw new IllegalStateException();
 502                     }
 503                 }
 504 
 505                 s.add(d);
 506             }
 507 
 508             return s;
 509         }
 510 
 511         static TypeElement checkScalarTypes(TypeElement a, TypeElement b) {
 512             // @@@ Optional ptr checking
 513             if (a instanceof PtrType) {
 514                 if (!b.equals(JavaType.INT)) {
 515                     throw new IllegalStateException();
 516                 }
 517             } else if (b instanceof PtrType) {
 518                 // Pointer must be first argument
 519                 throw new IllegalStateException();
 520             } else if (a instanceof ConstantType || b instanceof ConstantType) {
 521                 return checkScalarTypes(reduceScalarType(a), reduceScalarType(b));
 522             } else if (!a.equals(b)) {
 523                 // @@@ Conversion
 524                 throw new IllegalStateException();
 525             }
 526             return a;
 527         }
 528 
 529         static TypeElement reduceScalarType(TypeElement a) {
 530             return a instanceof ConstantType ct ? ct.cType() : a;
 531         }
 532     }
 533 
 534     public static <O extends Op & Op.Invokable> TritonOps.FuncOp transformToTritonFunction(
 535             O kernel,
 536             String signature,
 537             TypeElement rType,
 538             Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
 539             Map<String, TritonOps.FuncOp> fsymTable) {
 540         TritonOps.FuncOp ttKernel = TritonOps.func(signature, functionType(rType))
 541                 .body(fblock -> {
 542                     // Process kernel parameters
 543                     List<Value> args = new ArrayList<>();
 544                     for (Block.Parameter kp : kernel.body().entryBlock().parameters()) {
 545                         TypeElement type = valueTypeMap.get(kp);
 546                         if (type instanceof ConstantType ct) {
 547                             // Constant
 548                             Op.Result cr = fblock.op(ArithMathOps.constant(
 549                                     ct.cType(), ct.value()));
 550                             args.add(cr);
 551                         } else {
 552                             args.add(fblock.parameter(type));
 553                         }
 554                     }
 555 
 556                     // Transform kernel body
 557                     fblock.transformBody(kernel.body(), args, (kblock, op) -> {
 558                         return transformToTritonOperation(kblock, op, valueTypeMap, opData, fsymTable);
 559                     });
 560                 });
 561 
 562         ttKernel = cleanup(ttKernel);
 563         fsymTable.put(ttKernel.funcName(), ttKernel);
 564         return ttKernel;
 565     }
 566 
 567     static Block.Builder transformToTritonOperation(Block.Builder kblock, Op op,
 568                                                     Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
 569                                                     Map<String, TritonOps.FuncOp> fsymTable) {
 570         // @@@ Avoid constructing for each operation -- block builder passed as argument or a scoped value
 571         TritonBuilderInterpreter tbi = new TritonBuilderInterpreter(fsymTable, kblock);
 572         CopyContext cc = kblock.context();
 573         switch (op) {
 574             case VarOp varOp -> {
 575                 // @@@ Cannot copy op because the result type
 576                 //     is derived from init type
 577                 Value init = cc.getValue(op.operands().get(0));
 578                 Op.Result r = kblock.op(var(varOp.varName(), init));
 579                 cc.mapValue(op.result(), r);
 580             }
 581             case ConstantOp cop -> {
 582                 TypeElement t = valueTypeMap.get(cop.result());
 583                 if (t instanceof ConstantType ct) {
 584                     Op.Result r = kblock.op(ArithMathOps.constant(
 585                             ct.cType(), ct.value()));
 586                     cc.mapValue(op.result(), r);
 587                 } else {
 588                     kblock.op(op);
 589                 }
 590             }
 591             case ArithmeticOperation _ -> {
 592                 Value result = tbi.build(op, op.opName(), valueTypeMap);
 593                 if (result != null) {
 594                     cc.mapValue(op.result(), result);
 595                 }
 596             }
 597             case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
 598                 // Replace box with its value
 599                 Value a = cc.getValue(op.operands().get(0));
 600                 cc.mapValue(op.result(), a);
 601             }
 602             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
 603                 String name = iop.invokeDescriptor().name();
 604                 if (name.equals("max")) {
 605                     Value a = cc.getValue(op.operands().get(0));
 606                     Value b = cc.getValue(op.operands().get(1));
 607 
 608                     Op.Result result = kblock.op(ArithMathOps.maximum(a, b));
 609                     cc.mapValue(op.result(), result);
 610                 } else if (name.equals("min")) {
 611                     Value a = cc.getValue(op.operands().get(0));
 612                     Value b = cc.getValue(op.operands().get(1));
 613 
 614                     Op.Result result = kblock.op(ArithMathOps.minimum(a, b));
 615                     cc.mapValue(op.result(), result);
 616                 }
 617             }
 618             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
 619                 if (iop.invokeDescriptor().name().equals("type")) {
 620                     // Replace with constant operation to produce tensor type.
 621                     // Result may be used, but transitively it will be removed due to no uses
 622                     // contributing to the computation
 623                     Value a = op.operands().get(0);
 624                     TensorType aType = (TensorType) valueTypeMap.get(a);
 625                     Op.Result result = kblock.op(CoreOp.constant(iop.resultType(), aType));
 626                     cc.mapValue(op.result(), result);
 627                     valueTypeMap.put(result, aType);
 628                 }
 629                 // Remove
 630             }
 631             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
 632                 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
 633                 if (result != null) {
 634                     cc.mapValue(op.result(), result);
 635                 }
 636             }
 637             case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
 638                 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
 639                 if (result != null) {
 640                     cc.mapValue(op.result(), result);
 641                 }
 642             }
 643             case ExtendedOp.JavaForOp fop -> {
 644                 transformToSCFFor(cc, kblock, fop, valueTypeMap, opData, fsymTable);
 645             }
 646             case ReturnOp rop -> {
 647                 if (rop.operands().isEmpty()) {
 648                     kblock.op(TritonOps.return_());
 649                 } else {
 650                     kblock.op(TritonOps.return_(
 651                             cc.getValue(rop.returnValue())));
 652                 }
 653             }
 654             default -> kblock.op(op);
 655         }
 656         return kblock;
 657     }
 658 
 659     static void transformToSCFFor(CopyContext cc, Block.Builder kblock, ExtendedOp.JavaForOp fop,
 660                                   Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
 661                                   Map<String, TritonOps.FuncOp> fsymTable) {
 662         Body body = fop.loopBody();
 663 
 664         // Hoist expressions for start, end, and step
 665         SimpleCountedForLoopInfo li = (SimpleCountedForLoopInfo) opData.get(fop);
 666         Value start = null;
 667         for (Op o : li.startExpression()) {
 668             transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
 669             start = cc.getValue(o.result());
 670         }
 671         Value end = null;
 672         for (Op o : li.endExpression()) {
 673             transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
 674             end = cc.getValue(o.result());
 675         }
 676         Value step = null;
 677         for (Op o : li.stepExpression()) {
 678             transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
 679             step = cc.getValue(o.result());
 680         }
 681 
 682         // Obtain captured vars
 683         // true == stores
 684         // false == loads only
 685         Map<Boolean, Set<Value>> capturedVars = capturedVars(body);
 686         Set<Value> capturedAndStoredVars = capturedVars.get(true);
 687 
 688         // Get load values
 689         // Loaded values are hoisted out of the loop body
 690         Map<Value, Value> loadValues = new HashMap<>();
 691         for (Value v : capturedVars.get(false)) {
 692             Value load = kblock.op(varLoad(cc.getValue(v)));
 693             valueTypeMap.put(load, valueTypeMap.get(v));
 694             loadValues.put(v, load);
 695         }
 696 
 697         // Get iteration values -- represented by captured vars that are stored to in the loop
 698         // The SCF for operation returns the iteration values of the last loop iteration, which
 699         // are then to be stored to the iteration variables
 700         List<Value> iterValues = new ArrayList<>();
 701         for (Value v : capturedAndStoredVars) {
 702             iterValues.add(kblock.op(varLoad(cc.getValue(v))));
 703         }
 704 
 705         // @@@ Build in java code model, then transform?
 706         SCFOps.ForOp scffor = SCFOps.for_(kblock.parentBody(), start, end, step, iterValues)
 707                 // Ensure existing context is used
 708                 .body(CopyContext.create(cc), builder -> {
 709                     // Create index var initialized from entry block parameter
 710                     Value index = builder.parameters().get(0);
 711                     valueTypeMap.put(index, JavaType.INT);
 712                     Value varIndex = builder.op(var("index", index));
 713                     valueTypeMap.put(varIndex, JavaType.INT);
 714                     builder.context().mapValue(body.entryBlock().parameters().get(0), varIndex);
 715 
 716                     // Create iter vars initialized from entry block parameters
 717                     int pi = 1;
 718                     for (Value v : capturedAndStoredVars) {
 719                         TypeElement type = valueTypeMap.get(v);
 720                         Value iter = builder.parameters().get(pi++);
 721                         valueTypeMap.put(iter, type);
 722                         Value varIter = builder.op(var(Integer.toString(pi), iter));
 723                         valueTypeMap.put(varIter, type);
 724                         builder.context().mapValue(v, varIter);
 725                     }
 726 
 727                     // Transform the Java for body into the SCF for body
 728                     builder.transformBody(body, List.of(), (block, op) -> {
 729                         // Yield iter values
 730                         if (op instanceof ExtendedOp.JavaContinueOp) {
 731                             // Replace with yield of loaded vars
 732                             List<Value> yieldValues = new ArrayList<>();
 733                             for (Value value : capturedAndStoredVars) {
 734                                 Value varIter = block.context().getValue(value);
 735                                 Value v = block.op(varLoad(varIter));
 736                                 yieldValues.add(v);
 737                             }
 738                             block.op(SCFOps.yield_(yieldValues));
 739                         } else if (op instanceof VarAccessOp.VarLoadOp) {
 740                             // Replace with value loaded immediately before loop
 741                             Value v = op.operands().get(0);
 742                             if (capturedVars.get(false).contains(v)) {
 743                                 block.context().mapValue(op.result(), loadValues.get(v));
 744                             } else {
 745                                 block.op(op);
 746                             }
 747                         } else {
 748                             block = transformToTritonOperation(block, op, valueTypeMap, opData, fsymTable);
 749                         }
 750                         return block;
 751                     });
 752                 });
 753         Op.Result forResult = kblock.op(scffor);
 754 
 755         // Assign back result to iter vars
 756         if (capturedAndStoredVars.size() == 1) {
 757             for (Value v : capturedAndStoredVars) {
 758                 kblock.op(varStore(cc.getValue(v), forResult));
 759             }
 760         } else {
 761             int i = 0;
 762             for (Value v : capturedAndStoredVars) {
 763                 kblock.op(varStore(cc.getValue(v),
 764                         kblock.op(tupleLoad(forResult, i++))));
 765             }
 766         }
 767     }
 768 
 769     static Map<Boolean, Set<Value>> capturedVars(Body body) {
 770         Map<Boolean, Set<Value>> capturedValues = new HashMap<>();
 771         capturedValues.put(false, new LinkedHashSet<>());
 772         capturedValues.put(true, new LinkedHashSet<>());
 773 
 774         capturedVars(capturedValues, new ArrayDeque<>(), body);
 775         return capturedValues;
 776     }
 777 
 778     static void capturedVars(Map<Boolean, Set<Value>> capturedVars, Deque<Body> bodyStack, Body body) {
 779         bodyStack.push(body);
 780 
 781         for (Block b : body.blocks()) {
 782             for (Op op : b.ops()) {
 783                 // @@@ Nested bodies
 784                 if (!op.bodies().isEmpty()) {
 785                     throw new IllegalStateException();
 786                 }
 787 //                for (Body childBody : op.bodies()) {
 788 //                    capturedAndUpdatedVars(capturedValues, bodyStack, childBody);
 789 //                }
 790 
 791                 if (op instanceof VarAccessOp) {
 792                     Value v = op.operands().get(0);
 793                     if (!bodyStack.contains(v.declaringBlock().parentBody())) {
 794                         if (op instanceof VarAccessOp.VarStoreOp) {
 795                             capturedVars.get(true).add(v);
 796                             capturedVars.get(false).remove(v);
 797                         } else if (!capturedVars.get(true).contains(v)) {
 798                             capturedVars.get(false).add(v);
 799                         }
 800                     }
 801                 }
 802             }
 803         }
 804 
 805         bodyStack.pop();
 806     }
 807 
 808     public static final ScopedValue<Boolean> SV_SSA = ScopedValue.newInstance();
 809 
 810     static TritonOps.FuncOp cleanup(TritonOps.FuncOp f) {
 811         // Remove var ops
 812         boolean doSSA = SV_SSA.isBound() ? SV_SSA.get() : true;
 813         if (doSSA) {
 814             f = SSA.transform(f);
 815         }
 816         // Remove unused ops
 817         f = f.transform((fblock, op) -> {
 818             if (op instanceof Op.Pure && op.result().uses().isEmpty()) {
 819                 return fblock;
 820             } else if (op instanceof VarAccessOp.VarLoadOp && op.result().uses().isEmpty()) {
 821                 return fblock;
 822             }
 823 
 824             fblock.op(op);
 825             return fblock;
 826         });
 827         return f;
 828     }
 829 
 830     static class TritonBuilderInterpreter {
 831         final Map<String, TritonOps.FuncOp> fsymTable;
 832         final Block.Builder block;
 833 
 834         TritonBuilderInterpreter(Map<String, TritonOps.FuncOp> fsymTable, Block.Builder block) {
 835             this.fsymTable = fsymTable;
 836             this.block = block;
 837         }
 838 
 839         Value build(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
 840             // Obtain associated type-based method
 841             MethodHandle mh;
 842             try {
 843                 Optional<Method> om = Stream.of(TritonBuilderInterpreter.class.getDeclaredMethods())
 844                         .filter(m -> m.getName().equals(name))
 845                         .findFirst();
 846                 mh = MethodHandles.lookup().unreflect(
 847                         om.orElseThrow(() -> new NoSuchMethodException(name)));
 848             } catch (ReflectiveOperationException e) {
 849                 throw new IllegalStateException(e);
 850             }
 851 
 852             List<Object> iArgs = new ArrayList<>();
 853             iArgs.add(this);
 854             iArgs.add(valueTypeMap.get(op.result()));
 855             iArgs.add(op.result());
 856             for (Value o : op.operands()) {
 857                 iArgs.add(valueTypeMap.get(o));
 858                 iArgs.add(o);
 859             }
 860             try {
 861                 return (Value) mh.invokeWithArguments(iArgs.toArray(Object[]::new));
 862             } catch (Throwable e) {
 863                 throw new IllegalStateException(e);
 864             }
 865         }
 866 
 867 
 868         public Value programId(TypeElement rType, Op.Result r,
 869                                ConstantType axisType, Value axis) {
 870             return block.op(TritonOps.getProgramId(
 871                     (int) axisType.value()));
 872         }
 873 
 874         public Value arange(TensorType rType, Op.Result r,
 875                             ConstantType startType, Value start,
 876                             ConstantType endType, Value end) {
 877             return block.op(TritonOps.makeRange(
 878                     (int) startType.value(),
 879                     (int) endType.value()));
 880         }
 881 
 882         public Value expand(TensorType rType, Op.Result r,
 883                             TensorType aType, Value a,
 884                             ConstantType axisType, Value axis) {
 885             return block.op(TritonOps.expand(
 886                     (int) axisType.value(),
 887                     rType,
 888                     block.context().getValue(a)));
 889         }
 890 
 891         public Value zeros(TensorType rType, Op.Result r,
 892                            ConstantType aType, Value a,
 893                            Object... constantsAndValues) {
 894             Object zero;
 895             try {
 896                 JavaType zeroType = (JavaType) aType.value();
 897                 zero = MethodHandles.zero((Class<?>) zeroType.resolve(MethodHandles.lookup())).invoke();
 898             } catch (Throwable e) {
 899                 throw new RuntimeException(e);
 900             }
 901             return block.op(ArithMathOps.constant(rType, zero));
 902         }
 903 
 904         public Value load(TensorType rType, Op.Result r,
 905                           TensorType ptrType, Value ptr,
 906                           TensorType maskType, Value mask) {
 907             broadcastConversionRight(ptrType, maskType, mask);
 908             return block.op(TritonOps.load(
 909                     rType,
 910                     block.context().getValue(ptr),
 911                     block.context().getValue(mask)));
 912         }
 913 
 914         public Value store(TensorType rType, Op.Result r,
 915                            TensorType ptrType, Value ptr,
 916                            TensorType valueType, Value value,
 917                            TensorType maskType, Value mask) {
 918             broadcastConversionRight(ptrType, valueType, value);
 919             broadcastConversionRight(ptrType, maskType, mask);
 920             return block.op(TritonOps.store(
 921                     block.context().getValue(ptr),
 922                     block.context().getValue(value),
 923                     block.context().getValue(mask)));
 924         }
 925 
 926         public Value broadcast(TensorType rType, Op.Result r,
 927                                TypeElement oType, Value o,
 928                                TensorType tensorTypeType, Value tensorType) {
 929             // @@@ tt.splat with scalar operand, tt.broadcast with tensor operand
 930             if (oType instanceof TensorType) {
 931                 return block.op(TritonOps.broadcast(
 932                         rType,
 933                         block.context().getValue(o)));
 934             } else {
 935                 return block.op(TritonOps.splat(
 936                         rType,
 937                         block.context().getValue(o)));
 938             }
 939         }
 940 
 941         public Value joinShape(TensorType rType, Op.Result r,
 942                                TensorType aType, Value a,
 943                                TensorType bType, Value b) {
 944             // Replace with constant operation to produce tensor type.
 945             // Result may be used, but transitively it will be removed due to no uses
 946             // contributing to the computation
 947             return block.op(CoreOp.constant(JavaType.type(TensorType.class), r.type()));
 948         }
 949 
 950 
 951         public Value add(TypeElement rType, Op.Result r,
 952                          TypeElement aType, Value a,
 953                          TypeElement bType, Value b) {
 954             broadcastConversion(rType, aType, a, bType, b);
 955             a = block.context().getValue(a);
 956             b = block.context().getValue(b);
 957 
 958             if (rType instanceof PtrType ||
 959                     rType instanceof TensorType t && t.eType() instanceof PtrType) {
 960                 return block.op(TritonOps.addptr(a, b));
 961             } else {
 962                 return block.op(ArithMathOps.add(a, b));
 963             }
 964         }
 965 
 966         public Value sub(TypeElement rType, Op.Result r,
 967                          TypeElement aType, Value a,
 968                          TypeElement bType, Value b) {
 969             broadcastConversion(rType, aType, a, bType, b);
 970             a = block.context().getValue(a);
 971             b = block.context().getValue(b);
 972 
 973             return block.op(ArithMathOps.sub(a, b));
 974         }
 975 
 976         public Value mul(TypeElement rType, Op.Result r,
 977                          TypeElement aType, Value a,
 978                          TypeElement bType, Value b) {
 979             broadcastConversion(rType, aType, a, bType, b);
 980             a = block.context().getValue(a);
 981             b = block.context().getValue(b);
 982 
 983             return block.op(ArithMathOps.mul(a, b));
 984         }
 985 
 986         public Value div(TypeElement rType, Op.Result r,
 987                          TypeElement aType, Value a,
 988                          TypeElement bType, Value b) {
 989             broadcastConversion(rType, aType, a, bType, b);
 990             a = block.context().getValue(a);
 991             b = block.context().getValue(b);
 992 
 993             return block.op(ArithMathOps.div(a, b));
 994         }
 995 
 996         public Value mod(TypeElement rType, Op.Result r,
 997                          TypeElement aType, Value a,
 998                          TypeElement bType, Value b) {
 999             broadcastConversion(rType, aType, a, bType, b);
1000             a = block.context().getValue(a);
1001             b = block.context().getValue(b);
1002 
1003             return block.op(ArithMathOps.rem(a, b));
1004         }
1005 
1006         public Value and(TypeElement rType, Op.Result r,
1007                          TypeElement aType, Value a,
1008                          TypeElement bType, Value b) {
1009             broadcastConversion(rType, aType, a, bType, b);
1010             a = block.context().getValue(a);
1011             b = block.context().getValue(b);
1012 
1013             return block.op(ArithMathOps.and(a, b));
1014         }
1015 
1016         public Value dot(TensorType rType, Op.Result r,
1017                          TypeElement aType, Value a,
1018                          TypeElement bType, Value b) {
1019             a = block.context().getValue(a);
1020             b = block.context().getValue(b);
1021 
1022             return block.op(TritonOps.dot(rType, a, b));
1023         }
1024 
1025         public Value cdiv(TypeElement rType, Op.Result r,
1026                           TypeElement aType, Value a,
1027                           TypeElement bType, Value b) {
1028             a = block.context().getValue(a);
1029             b = block.context().getValue(b);
1030 
1031             TritonOps.FuncOp cdiv = tritonFunction(Functions.getJavaCodeModel("cdiv"),
1032                     rType, List.of(aType, bType),
1033                     fsymTable);
1034             // @@@ Generalize
1035             List<Value> args = new ArrayList<>();
1036             if (!(aType instanceof ConstantType)) {
1037                 args.add(a);
1038             }
1039             if (!(bType instanceof ConstantType)) {
1040                 args.add(b);
1041             }
1042             return block.op(TritonOps.call(cdiv, args));
1043         }
1044 
1045         public Value conv(TypeElement rType, Op.Result r,
1046                           ConstantType tType, Value t,
1047                           TypeElement aType, Value a) {
1048             a = block.context().getValue(a);
1049 
1050             TypeElement rScalarType;
1051             TypeElement aScalarType;
1052             if (rType instanceof TensorType rTensorType && aType instanceof TensorType aTensorType) {
1053                 rScalarType = rTensorType.eType();
1054                 aScalarType = aTensorType.eType();
1055             } else {
1056                 rScalarType = rType;
1057                 aScalarType = aType;
1058             }
1059 
1060             if (rScalarType.equals(Float16.FLOAT_16_TYPE) && aScalarType.equals(JavaType.FLOAT)) {
1061                 return block.op(ArithMathOps.trunc(rType, a));
1062             } else if (rType.equals(aType)) {
1063                 return a;
1064             } else {
1065                 throw new IllegalStateException();
1066             }
1067         }
1068 
1069         public Value exp(TritonType rType, Op.Result r,
1070                          TritonType aType, Value a) {
1071             return block.op(ArithMathOps.exp(
1072                     block.context().getValue(a)));
1073         }
1074 
1075         public Value compare(TensorType rType, Op.Result r,
1076                              TypeElement aType, Value a,
1077                              TypeElement bType, Value b,
1078                              ConstantType compareType, Value compare) {
1079             Triton.CompareKind ck = (Triton.CompareKind) compareType.value();
1080 
1081             ArithMathOps.CompareOp.CompareKind ack = switch (ck) {
1082                 case LessThan -> ArithMathOps.CompareOp.CompareKind.slt;
1083                 default -> throw new UnsupportedOperationException("Unsupported comparison: " + ck);
1084             };
1085 
1086             broadcastConversion(rType, aType, a, bType, b);
1087             a = block.context().getValue(a);
1088             b = block.context().getValue(b);
1089 
1090             return block.op(ArithMathOps.cmp(ack, a, b));
1091         }
1092 
1093 
1094         public Value max(TypeElement rType, Op.Result r,
1095                          TensorType xType, Value x,
1096                          ConstantType axisType, Value axis) {
1097             TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("max"),
1098                     rType, List.of(rType, rType), fsymTable);
1099             return reduce(rType, r, xType, x, axisType, axis, f);
1100         }
1101 
1102         public Value sum(TypeElement rType, Op.Result r,
1103                          TensorType xType, Value x,
1104                          ConstantType axisType, Value axis) {
1105             TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("sum"),
1106                     rType, List.of(rType, rType), fsymTable);
1107             return reduce(rType, r, xType, x, axisType, axis, f);
1108         }
1109 
1110         Value reduce(TypeElement rType, Op.Result r,
1111                      TensorType xType, Value x,
1112                      ConstantType axisType, Value axis,
1113                      TritonOps.FuncOp f) {
1114             int axisConstant = (int) axisType.value();
1115 
1116             String signature = "reduce_" + f.funcName() + "_" + axisConstant;
1117             TritonOps.FuncOp rf = fsymTable.computeIfAbsent(signature,
1118                     s -> reduce(rType, xType, axisConstant, s, f));
1119 
1120             return block.op(TritonOps.call(rf, block.context().getValue(x)));
1121         }
1122 
1123         static TritonOps.FuncOp reduce(TypeElement elementType,
1124                                        TensorType tensorType,
1125                                        int axisConstant,
1126                                        String name, TritonOps.FuncOp scalarFunc) {
1127             return TritonOps.func(name,
1128                             functionType(elementType, tensorType))
1129                     .body(fblock -> {
1130                         TritonOps.ReduceOp reduceOp = TritonOps.reduce(fblock.parentBody(),
1131                                         axisConstant, fblock.parameters().get(0),
1132                                         functionType(elementType, elementType, elementType))
1133                                 .body(rblock -> {
1134                                     Block.Parameter a = rblock.parameters().get(0);
1135                                     Block.Parameter b = rblock.parameters().get(1);
1136                                     Op.Result _r = rblock.op(TritonOps.call(scalarFunc, a, b));
1137                                     rblock.op(TritonOps.reduceReturn(_r));
1138                                 });
1139 
1140                         Op.Result opr = fblock.op(reduceOp);
1141                         fblock.op(TritonOps.return_(opr));
1142                     });
1143         }
1144 
1145         // @@@ Test
1146         public Value consume(TypeElement rType, Op.Result r,
1147                              TypeElement aType, Value a) {
1148             return block.op(TritonTestOps.consume(block.context().getValue(a)));
1149         }
1150 
1151         void broadcastConversion(TypeElement rType,
1152                                  TypeElement aType, Value a,
1153                                  TypeElement bType, Value b) {
1154             Value ma = block.context().getValue(a);
1155             Value mb = block.context().getValue(b);
1156             if (aType instanceof TensorType at && bType instanceof TensorType bTensorType) {
1157                 TensorType rTensorType = (TensorType) rType;
1158                 if (!at.shape().equals(rTensorType.shape())) {
1159                     ma = block.op(TritonOps.broadcast(rTensorType, ma));
1160                 }
1161                 if (!bTensorType.shape().equals(rTensorType.shape())) {
1162                     if (rTensorType.eType() instanceof PtrType) {
1163                         bTensorType = new TensorType(bType, rTensorType.shape());
1164                         mb = block.op(TritonOps.broadcast(bTensorType, mb));
1165                     } else {
1166                         mb = block.op(TritonOps.broadcast(rTensorType, mb));
1167                     }
1168                 }
1169             } else if (aType instanceof TensorType) {
1170                 TensorType rTensorType = (TensorType) rType;
1171                 if (rTensorType.eType() instanceof PtrType) {
1172                     TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1173                     mb = block.op(TritonOps.splat(bTensorType, mb));
1174                 } else {
1175                     mb = block.op(TritonOps.splat(rTensorType, mb));
1176                 }
1177             } else if (bType instanceof TensorType) {
1178                 TensorType rTensorType = (TensorType) rType;
1179                 ma = block.op(TritonOps.splat(rTensorType, ma));
1180             }
1181             block.context().mapValue(a, ma);
1182             block.context().mapValue(b, mb);
1183         }
1184 
1185         void broadcastConversionRight(TypeElement aType,
1186                                       TypeElement bType, Value b) {
1187             Value mb = block.context().getValue(b);
1188             if (aType instanceof TensorType aTensorType && bType instanceof TensorType bTensorType) {
1189                 if (!bTensorType.shape().equals(aTensorType.shape())) {
1190                     if (aTensorType.eType() instanceof PtrType) {
1191                         bTensorType = new TensorType(bTensorType.eType(), aTensorType.shape());
1192                         mb = block.op(TritonOps.broadcast(bTensorType, mb));
1193                     } else {
1194                         mb = block.op(TritonOps.broadcast(aTensorType, mb));
1195                     }
1196                 }
1197             } else if (aType instanceof TensorType rTensorType) {
1198                 if (rTensorType.eType() instanceof PtrType) {
1199                     TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1200                     mb = block.op(TritonOps.splat(bTensorType, mb));
1201                 } else {
1202                     mb = block.op(TritonOps.splat(rTensorType, mb));
1203                 }
1204             }
1205             block.context().mapValue(b, mb);
1206         }
1207     }
1208 
1209     public static <O extends Op & Op.Invokable> void printTypeMap(
1210             O kernel, Map<Value, TypeElement> valueTypeMap) {
1211         AtomicInteger valueId = new AtomicInteger();
1212         Map<Value, Integer> valueIdMap = new LinkedHashMap<>();
1213         kernel.traverse(null, (o, codeElement) -> {
1214             switch (codeElement) {
1215                 case FuncOp _ -> {
1216                     // Ignore
1217                 }
1218                 case Op op when !op.result().type().equals(JavaType.VOID) -> {
1219                     valueIdMap.put(op.result(), valueId.getAndIncrement());
1220                 }
1221                 case Block block -> {
1222                     for (Block.Parameter parameter : block.parameters()) {
1223                         valueIdMap.put(parameter, valueId.getAndIncrement());
1224                     }
1225                 }
1226                 default -> {
1227                 }
1228             }
1229             return null;
1230         });
1231 
1232         valueIdMap.forEach((value, id) -> {
1233             TypeElement type = valueTypeMap.get(value);
1234             if (type != null) {
1235                 System.out.println("%" + id + " : " + value.type() + " -> " + type);
1236             }
1237         });
1238     }
1239 }