1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.triton;
27
28 import java.lang.constant.ClassDesc;
29 import java.lang.invoke.MethodHandle;
30 import java.lang.invoke.MethodHandles;
31 import java.lang.reflect.Field;
32 import java.lang.reflect.Method;
33 import jdk.incubator.code.*;
34 import jdk.incubator.code.dialect.core.CoreOp;
35 import jdk.incubator.code.dialect.core.SSA;
36 import jdk.incubator.code.dialect.core.VarType;
37 import jdk.incubator.code.dialect.java.JavaOp;
38 import jdk.incubator.code.dialect.java.JavaType;
39 import java.util.*;
40 import java.util.concurrent.atomic.AtomicInteger;
41 import java.util.stream.Stream;
42
43 import static jdk.incubator.code.dialect.core.CoreOp.*;
44 import static jdk.incubator.code.dialect.core.CoreType.functionType;
45 import static jdk.incubator.code.dialect.java.JavaOp.*;
46
47 public final class TritonTransformer {
// Utility class holding only static members; not instantiable.
private TritonTransformer() {}

// Type of the Triton DSL class; invocations on it are type checked and lowered by name.
static final JavaType TYPE_Triton = JavaType.type(Triton.class);

// Type of the Triton test helper class, referenced via ClassDesc rather than a class literal
// (presumably so the test class need not be on this file's compile-time classpath -- confirm).
static final JavaType TYPE_Triton_Test = JavaType.type(ClassDesc.of("oracle.code.triton.TritonTest"));

// Type of the Tensor class; only Tensor.type() is supported on it.
static final JavaType TYPE_Tensor = JavaType.type(Tensor.class);

// Type of java.lang.Math; only Math.max and Math.min are supported.
static final JavaType TYPE_J_L_MATH = JavaType.type(Math.class);
57
58 public static <O extends Op & Op.Invokable>
59 TritonOps.ModuleOp tritonModule(O kernel,
60 TypeElement rType,
61 List<? extends TypeElement> argTypes) {
62 Map<String, TritonOps.FuncOp> fsymTable = new LinkedHashMap<>();
63 tritonFunction(kernel, rType, argTypes, fsymTable);
64 return TritonOps.module(fsymTable.values().stream().toList());
65 }
66
67 public static <O extends Op & Op.Invokable>
68 TritonOps.FuncOp tritonFunction(O javaKernel,
69 TypeElement rType,
70 List<? extends TypeElement> argTypes,
71 Map<String, TritonOps.FuncOp> fsymTable) {
72 String name = (javaKernel instanceof FuncOp f) ? f.funcName() : "kernel";
73 String signature = signature(name, rType, argTypes);
74 if (fsymTable.containsKey(signature)) {
75 return fsymTable.get(signature);
76 }
77
78 System.out.println(javaKernel.toText());
79
80 Map<Value, TypeElement> valueTypeMap = new HashMap<>();
81 Map<Op, Object> opData = new HashMap<>();
82 TritonTransformer.typeCheckKernel(javaKernel, argTypes, valueTypeMap, opData);
83 TritonTransformer.printTypeMap(javaKernel, valueTypeMap);
84
85 return TritonTransformer.transformToTritonFunction(javaKernel, signature,
86 rType, valueTypeMap, opData,
87 fsymTable);
88 }
89
90 static String signature(String name, TypeElement rType, List<? extends TypeElement> argTypes) {
91 StringBuilder sb = new StringBuilder(name);
92
93 for (TypeElement argType : argTypes) {
94 sb.append("_");
95 if (argType instanceof ConstantType ct) {
96 sb.append(ct.value());
97 } else {
98 sb.append(argType);
99 }
100 }
101 sb.append("_");
102 sb.append(rType);
103 return sb.toString();
104 }
105
106 public static <O extends Op & Op.Invokable> void typeCheckKernel(
107 O kernel, List<? extends TypeElement> argTypes,
108 Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData) {
109 kernel.elements().forEach(e -> {
110 if (!(e instanceof Op op)) {
111 return;
112 }
113 switch (op) {
114 case Op.Invokable fop -> {
115 List<Block.Parameter> parameters = fop.body().entryBlock().parameters();
116 for (int i = 0; i < parameters.size(); i++) {
117 valueTypeMap.put(parameters.get(i), argTypes.get(i));
118 }
119 }
120 case VarOp _, VarAccessOp.VarLoadOp _ -> {
121 Value init = op.operands().get(0);
122 valueTypeMap.put(op.result(), valueTypeMap.get(init));
123 }
124 case VarAccessOp.VarStoreOp _ -> {
125 Value var = op.operands().get(0);
126 TypeElement varType = valueTypeMap.get(var);
127 Value v = op.operands().get(1);
128 TypeElement vType = valueTypeMap.get(v);
129 if (!varType.equals(vType)) {
130 throw new IllegalStateException("Storing to variable with different type: "
131 + varType + " <- " + vType);
132 }
133
134 valueTypeMap.put(op.result(), valueTypeMap.get(var));
135 }
136 case ConstantOp cop -> {
137 valueTypeMap.put(op.result(), new ConstantType(op.result().type(), cop.value()));
138 }
139 case ArithmeticOperation _ -> {
140 TypeElement t = checkWithTypeInterpreter(op, op.externalizeOpName(), valueTypeMap);
141 valueTypeMap.put(op.result(), t);
142 }
143 case FieldAccessOp.FieldLoadOp flop -> {
144 if (!flop.operands().isEmpty()) {
145 throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor());
146 }
147
148 Field f;
149 try {
150 f = flop.fieldDescriptor().resolveToField(MethodHandles.lookup());
151 } catch (ReflectiveOperationException ex) {
152 throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor(), ex);
153 }
154 Object value;
155 try {
156 value = f.get(null);
157 } catch (IllegalAccessException ex) {
158 throw new IllegalStateException("Unsupported field load: " + f, ex);
159 }
160 valueTypeMap.put(op.result(), new ConstantType(JavaType.type(f.getType()), value));
161 }
162 case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
163 // Box
164 if (iop.invokeDescriptor().name().equals("valueOf")) {
165 Value a = op.operands().get(0);
166 valueTypeMap.put(op.result(), valueTypeMap.get(a));
167 } else {
168 throw new UnsupportedOperationException("Unsupported invocation on Integer: " + iop.invokeDescriptor());
169 }
170 }
171 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
172 String name = iop.invokeDescriptor().name();
173 if (name.equals("max") || name.equals("min")) {
174 Value a = op.operands().get(0);
175 valueTypeMap.put(op.result(), valueTypeMap.get(a));
176 } else {
177 throw new UnsupportedOperationException("Unsupported invocation on Math: " + iop.invokeDescriptor());
178 }
179 }
180 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
181 if (iop.invokeDescriptor().name().equals("type")) {
182 Value a = op.operands().get(0);
183 valueTypeMap.put(op.result(), valueTypeMap.get(a));
184 } else {
185 throw new UnsupportedOperationException("Unsupported invocation on Tensor: " + iop.invokeDescriptor());
186 }
187 }
188 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
189 TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
190 valueTypeMap.put(op.result(), t);
191 }
192 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
193 TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
194 valueTypeMap.put(op.result(), t);
195 }
196 case JavaOp.ForOp fop -> {
197 SimpleCountedForLoopInfo li = new SimpleCountedForLoopInfo(fop);
198 opData.put(fop, li);
199
200 TypeElement type = fop.init().yieldType();
201 if (type instanceof VarType vt && vt.valueType().equals(JavaType.INT)) {
202 for (Body b : List.of(fop.cond(), fop.update(), fop.loopBody())) {
203 valueTypeMap.put(b.entryBlock().parameters().get(0), JavaType.INT);
204 }
205 } else {
206 throw new IllegalStateException();
207 }
208 }
209 case TestOperation _ -> {
210 }
211 case JavaOp.ContinueOp _ -> {
212 }
213 case CoreOp.YieldOp _ -> {
214 }
215 case ReturnOp _ -> {
216 }
217 default -> throw new UnsupportedOperationException("Unsupported operation: " + op);
218 }
219 });
220 }
221
222 static TypeElement checkWithTypeInterpreter(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
223 // Obtain associated type-based method
224 MethodHandle mh;
225 try {
226 Optional<Method> om = Stream.of(TritonTypeInterpreter.class.getDeclaredMethods())
227 .filter(m -> m.getName().equals(name))
228 .filter(m -> m.isVarArgs() ? m.getParameterCount() <= op.operands().size() : m.getParameterCount() == op.operands().size())
229 .findFirst();
230 mh = MethodHandles.lookup().unreflect(
231 om.orElseThrow(() -> new NoSuchMethodException(name)));
232 } catch (ReflectiveOperationException e) {
233 throw new IllegalStateException(name, e);
234 }
235
236 // Invoke with the values' types
237 List<TypeElement> operandTypes = op.operands().stream().map(valueTypeMap::get).toList();
238 try {
239 return (TypeElement) mh.invokeWithArguments(operandTypes.toArray(Object[]::new));
240 } catch (Throwable e) {
241 throw new IllegalStateException(mh.toString(), e);
242 }
243 }
244
245 // @@@ type check tensor shapes
246 static class TritonTypeInterpreter {
247 private TritonTypeInterpreter() {
248 }
249
250 // int programId(@Constant int axis) {
251 public static JavaType programId(ConstantType axis) {
252 assert axis.cType().equals(JavaType.INT);
253 int axisValue = (int) axis.value();
254 if (axisValue < 0 || axisValue > 3) {
255 throw new IllegalStateException();
256 }
257
258 return JavaType.INT;
259 }
260
261 // Tensor arange(@Constant int start, @Constant int end)
262 public static TensorType arange(ConstantType start, ConstantType end) {
263 assert start.cType().equals(JavaType.INT);
264 assert end.cType().equals(JavaType.INT);
265
266 int startValue = (int) start.value();
267 int endValue = (int) end.value();
268
269 return new TensorType(JavaType.INT, List.of(endValue - startValue));
270 }
271
272 // Tensor expand(Tensor a, int axis) {
273 public static TensorType expand(TensorType a, ConstantType axis) {
274 assert axis.cType().equals(JavaType.INT);
275 int axisValue = (int) axis.value();
276
277 List<Integer> s = new ArrayList<>(a.shape());
278 if (axisValue < s.size()) {
279 s.add(axisValue, 1);
280 } else {
281 for (int i = 0; i <= (axisValue - s.size()); i++) {
282 s.add(1);
283 }
284 }
285 return new TensorType(a.eType(), s);
286 }
287
288 // Tensor load(Tensor ptr, Tensor mask)
289 public static TensorType load(TensorType ptr, TensorType mask) {
290 checkTensorShape(ptr, mask);
291 if (ptr.eType() instanceof PtrType eptr) {
292 return new TensorType(eptr.rType(), ptr.shape());
293 }
294
295 throw new IllegalStateException();
296 }
297
298 // Tensor load(Tensor ptr, Tensor mask, ConstantType other) {
299 public static TensorType load(TensorType ptr, TensorType mask, ConstantType other) {
300 checkTensorShape(ptr, mask);
301 if (ptr.eType() instanceof PtrType eptr) {
302 return new TensorType(eptr.rType(), ptr.shape());
303 }
304
305 throw new IllegalStateException();
306 }
307
308 // void store(Tensor ptr, Tensor value, Tensor mask)
309 public static void store(TensorType ptr, TensorType value, TensorType mask) {
310 if (!(ptr.eType() instanceof PtrType)) {
311 throw new IllegalStateException();
312 }
313 }
314
315 // Tensor zeros(TensorType type)
316 public static TensorType zeros(ConstantType eType, ConstantType... cShape) {
317 List<Integer> shape = Stream.of(cShape).map(s -> (int) s.value()).toList();
318 return new TensorType((TypeElement) eType.value(), shape);
319 }
320
321 // Tensor broadcast(Object o, TensorType type)
322 public static TensorType broadcast(TypeElement o, TensorType type) {
323 if (o instanceof TensorType ot) {
324 // @@@
325 if (ot.shape().size() != type.shape().size()) {
326 throw new IllegalStateException();
327 }
328 o = ot.eType();
329 } if (o instanceof ConstantType oc) {
330 o = oc.cType();
331 }
332 return new TensorType(o, type.shape());
333 }
334
335 public static TensorType joinShape(TensorType a, TensorType b) {
336 return checkTensorTypes(a, b);
337 }
338
339 // Tensor add(Number a, Number b)
340 // Ptr add(Ptr a, int offset)
341 public static TypeElement add(TypeElement a, TypeElement b) {
342 // @@@ Pass additional argument for checking ptr
343 return binary(a, b);
344 }
345
346 public static TypeElement sub(TypeElement a, TypeElement b) {
347 return binary(a, b);
348 }
349
350 public static TypeElement mul(TypeElement a, TypeElement b) {
351 return binary(a, b);
352 }
353
354 public static TypeElement div(TypeElement a, TypeElement b) {
355 return binary(a, b);
356 }
357
358 public static TypeElement mod(TypeElement a, TypeElement b) {
359 return binary(a, b);
360 }
361
362 public static TypeElement and(TypeElement a, TypeElement b) {
363 return binary(a, b);
364 }
365
366 public static TypeElement cdiv(TypeElement a, TypeElement b) {
367 a = reduceScalarType(a);
368 b = reduceScalarType(b);
369 if (!a.equals(JavaType.INT) && !b.equals(JavaType.INT)) {
370 throw new IllegalStateException();
371 }
372 return a;
373 }
374
375 // Number conv(Type t, Number a) {
376 public static TypeElement conv(ConstantType eType, TypeElement a) {
377 return convTypes(eType, a);
378 }
379
380 public static TypeElement convTypes(ConstantType eType, TypeElement a) {
381 if (a instanceof TensorType tb) {
382 TypeElement e = convScalarTypes(eType, tb.eType());
383 return new TensorType(e, tb.shape());
384 } else {
385 return convScalarTypes(eType, a);
386 }
387 }
388
389 public static TypeElement convScalarTypes(ConstantType eType, TypeElement a) {
390 TypeElement t = (TypeElement) eType.value();
391 if (t.equals(Float16.FLOAT_16_TYPE) && a.equals(JavaType.FLOAT)) {
392 return Float16.FLOAT_16_TYPE;
393 } else if (t.equals(a)) {
394 return t;
395 } else {
396 // @@@ Conversions;
397 throw new IllegalStateException();
398 }
399 }
400
401 // Tensor exp(Tensor a)
402 public static TypeElement exp(TypeElement a) {
403 return unary(a);
404 }
405
406 static TypeElement unary(TypeElement a) {
407 return a;
408 }
409
410 // Tensor compare(Number a, Number b, @Constant CompareKind ck) {
411 public static TypeElement compare(TypeElement a, TypeElement b, ConstantType kind) {
412 assert kind.cType().equals(JavaType.type(Triton.CompareKind.class));
413
414 TypeElement t = binary(a, b);
415 if (t instanceof TensorType tt) {
416 return new TensorType(JavaType.BOOLEAN, tt.shape());
417 } else {
418 return t;
419 }
420 }
421
422 // Tensor dot(Tensor a, Tensor b)
423 public static TensorType dot(TensorType a, TensorType b) {
424 if (a.shape().size() != 2 || b.shape().size() != 2) {
425 throw new IllegalStateException();
426 }
427
428 if (!a.shape().get(1).equals(b.shape().get(0))) {
429 throw new IllegalStateException();
430 }
431
432 if (a.eType() != b.eType()) {
433 // @@@ Conversion, type checking
434 throw new IllegalStateException();
435 }
436
437 // Computed result is tensor of floats, regardless of inputs
438 return new TensorType(JavaType.FLOAT, List.of(a.shape().get(0), b.shape().get(1)));
439 }
440
441
442 // Tensor max(Tensor a, @Constant int axis) {
443 public static TypeElement max(TensorType a, ConstantType axis) {
444 return reduce(a, axis);
445 }
446
447 // Tensor sum(Tensor a, @Constant int axis) {
448 public static TypeElement sum(TensorType a, ConstantType axis) {
449 return reduce(a, axis);
450 }
451
452 static TypeElement reduce(TensorType a, ConstantType axis) {
453 assert axis.cType().equals(JavaType.INT);
454 int axisValue = (int) axis.value();
455 if (axisValue < 0 || axisValue > 3) {
456 throw new IllegalStateException();
457 }
458
459 List<Integer> reduceShape = new ArrayList<>();
460 for (int i = 0; i < a.shape().size(); i++) {
461 if (i != axisValue) {
462 reduceShape.add(a.shape().get(i));
463 } else {
464 reduceShape.add(1);
465 }
466 }
467
468 if (reduceShape.size() == 1 && reduceShape.getFirst() == 1) {
469 return a.eType();
470 } else {
471 return new TensorType(a.eType(), reduceShape);
472 }
473 }
474
475 // @@@ Test
476 public static void consume(TypeElement a) {
477 }
478
479
480 static TypeElement binary(TypeElement a, TypeElement b) {
481 if (a instanceof TensorType ta && b instanceof TensorType tb) {
482 return checkTensorTypes(ta, tb);
483 } else if (a instanceof TensorType ta) {
484 return new TensorType(checkScalarTypes(ta.eType(), b), ta.shape());
485 } else if (b instanceof TensorType tb) {
486 return new TensorType(checkScalarTypes(a, tb.eType()), tb.shape());
487 } else {
488 return checkScalarTypes(a, b);
489 }
490 }
491
492 static TensorType checkTensorTypes(TensorType a, TensorType b) {
493 List<Integer> s = checkTensorShape(a, b);
494 TypeElement e = checkScalarTypes(a.eType(), b.eType());
495 return new TensorType(e, s);
496 }
497
498 static List<Integer> checkTensorShape(TensorType a, TensorType b) {
499 if (a.shape().size() != b.shape().size()) {
500 // Shape mismatch
501 throw new IllegalStateException();
502 }
503
504 List<Integer> s = new ArrayList<>();
505 for (int i = 0; i < a.shape().size(); i++) {
506 int ad = a.shape().get(i);
507 int bd = b.shape().get(i);
508
509 // Expand dimensions
510 int d;
511 if (ad == bd) {
512 d = ad;
513 } else {
514 if (ad != 1 && bd == 1) {
515 d = ad;
516 } else if (ad == 1) {
517 d = bd;
518 } else {
519 // Shape mismatch
520 throw new IllegalStateException();
521 }
522 }
523
524 s.add(d);
525 }
526
527 return s;
528 }
529
530 static TypeElement checkScalarTypes(TypeElement a, TypeElement b) {
531 // @@@ Optional ptr checking
532 if (a instanceof PtrType) {
533 if (!b.equals(JavaType.INT)) {
534 throw new IllegalStateException();
535 }
536 } else if (b instanceof PtrType) {
537 // Pointer must be first argument
538 throw new IllegalStateException();
539 } else if (a instanceof ConstantType || b instanceof ConstantType) {
540 return checkScalarTypes(reduceScalarType(a), reduceScalarType(b));
541 } else if (!a.equals(b)) {
542 // @@@ Conversion
543 throw new IllegalStateException();
544 }
545 return a;
546 }
547
548 static TypeElement reduceScalarType(TypeElement a) {
549 return a instanceof ConstantType ct ? ct.cType() : a;
550 }
551 }
552
553 public static <O extends Op & Op.Invokable> TritonOps.FuncOp transformToTritonFunction(
554 O kernel,
555 String signature,
556 TypeElement rType,
557 Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
558 Map<String, TritonOps.FuncOp> fsymTable) {
559 TritonOps.FuncOp ttKernel = TritonOps.func(signature, functionType(rType))
560 .body(fblock -> {
561 // Process kernel parameters
562 List<Value> args = new ArrayList<>();
563 for (Block.Parameter kp : kernel.body().entryBlock().parameters()) {
564 TypeElement type = valueTypeMap.get(kp);
565 if (type instanceof ConstantType ct) {
566 // Constant
567 Op.Result cr = fblock.op(ArithMathOps.constant(
568 ct.cType(), ct.value()));
569 args.add(cr);
570 } else {
571 args.add(fblock.parameter(type));
572 }
573 }
574
575 // Transform kernel body
576 fblock.body(kernel.body(), args, (kblock, op) -> {
577 return transformToTritonOperation(kblock, op, valueTypeMap, opData, fsymTable);
578 });
579 });
580
581 ttKernel = cleanup(ttKernel);
582 fsymTable.put(ttKernel.funcName(), ttKernel);
583 return ttKernel;
584 }
585
/**
 * Lowers one operation of the Java kernel into {@code kblock}, recording in the builder's
 * context how original values map to lowered values.
 *
 * <p>Arithmetic and Triton DSL invocations are delegated by name to
 * {@link TritonBuilderInterpreter}, using the types recorded by {@code typeCheckKernel}.
 * Operations not matched by any case are copied unchanged.
 *
 * @param kblock builder receiving the lowered operations
 * @param op the original Java operation
 * @param valueTypeMap inferred value types; may be extended with newly created values
 * @param opData per-operation analysis data (e.g. for-loop info)
 * @param fsymTable symbol table of lowered functions, passed to the builder interpreter
 * @return {@code kblock}
 */
static Block.Builder transformToTritonOperation(Block.Builder kblock, Op op,
        Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
        Map<String, TritonOps.FuncOp> fsymTable) {
    // @@@ Avoid constructing for each operation -- block builder passed as argument or a scoped value
    TritonBuilderInterpreter tbi = new TritonBuilderInterpreter(fsymTable, kblock);
    CodeContext cc = kblock.context();
    switch (op) {
        case VarOp varOp -> {
            // @@@ Cannot copy op because the result type
            // is derived from init type
            Value init = cc.getValue(op.operands().get(0));
            Op.Result r = kblock.op(var(varOp.varName(), init));
            cc.mapValue(op.result(), r);
        }
        case ConstantOp cop -> {
            // Constants typed as ConstantType are re-emitted as arith constants;
            // anything else is copied as-is
            TypeElement t = valueTypeMap.get(cop.result());
            if (t instanceof ConstantType ct) {
                Op.Result r = kblock.op(ArithMathOps.constant(
                        ct.cType(), ct.value()));
                cc.mapValue(op.result(), r);
            } else {
                kblock.op(op);
            }
        }
        case ArithmeticOperation _ -> {
            // Dispatched to the builder method named after the op's externalized name
            Value result = tbi.build(op, op.externalizeOpName(), valueTypeMap);
            if (result != null) {
                cc.mapValue(op.result(), result);
            }
        }
        case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
            // Replace box with its value
            Value a = cc.getValue(op.operands().get(0));
            cc.mapValue(op.result(), a);
        }
        case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
            // Math.max/min lower to arith maximum/minimum; type checking rejects other Math methods
            String name = iop.invokeDescriptor().name();
            if (name.equals("max")) {
                Value a = cc.getValue(op.operands().get(0));
                Value b = cc.getValue(op.operands().get(1));

                Op.Result result = kblock.op(ArithMathOps.maximum(a, b));
                cc.mapValue(op.result(), result);
            } else if (name.equals("min")) {
                Value a = cc.getValue(op.operands().get(0));
                Value b = cc.getValue(op.operands().get(1));

                Op.Result result = kblock.op(ArithMathOps.minimum(a, b));
                cc.mapValue(op.result(), result);
            }
        }
        case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
            if (iop.invokeDescriptor().name().equals("type")) {
                // Replace with constant operation to produce tensor type.
                // Result may be used, but transitively it will be removed due to no uses
                // contributing to the computation
                Value a = op.operands().get(0);
                TensorType aType = (TensorType) valueTypeMap.get(a);
                Op.Result result = kblock.op(CoreOp.constant(iop.resultType(), aType));
                cc.mapValue(op.result(), result);
                valueTypeMap.put(result, aType);
            }
            // Remove
        }
        case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
            Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
            if (result != null) {
                cc.mapValue(op.result(), result);
            }
        }
        case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
            Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
            if (result != null) {
                cc.mapValue(op.result(), result);
            }
        }
        case JavaOp.ForOp fop -> {
            // Java counted for-loops become SCF for-loops
            transformToSCFFor(cc, kblock, fop, valueTypeMap, opData, fsymTable);
        }
        case ReturnOp rop -> {
            if (rop.operands().isEmpty()) {
                kblock.op(TritonOps.return_());
            } else {
                kblock.op(TritonOps.return_(
                        cc.getValue(rop.returnValue())));
            }
        }
        default -> kblock.op(op);
    }
    return kblock;
}
677
/**
 * Lowers a Java counted for-loop (analysed as {@link SimpleCountedForLoopInfo}) to an SCF for-loop.
 *
 * <p>Steps:
 * <ol>
 * <li>The start, end and step expressions are emitted before the loop.</li>
 * <li>Captured variables that are only loaded in the body are loaded once before the loop.</li>
 * <li>Captured variables that are stored to become SCF iteration values. They are yielded at
 * each {@code continue} and written back to the variables after the loop.</li>
 * </ol>
 *
 * @param cc the code context of {@code kblock}
 * @param kblock builder receiving the lowered loop
 * @param fop the Java for-loop
 * @param valueTypeMap inferred value types; extended with values created here
 * @param opData per-operation analysis data; must contain the loop info for {@code fop}
 * @param fsymTable symbol table of lowered functions
 */
static void transformToSCFFor(CodeContext cc, Block.Builder kblock, JavaOp.ForOp fop,
        Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
        Map<String, TritonOps.FuncOp> fsymTable) {
    Body body = fop.loopBody();

    // Hoist expressions for start, end, and step
    // (each bound is the result of the last op of its expression; it stays null if the
    // expression list is empty)
    SimpleCountedForLoopInfo li = (SimpleCountedForLoopInfo) opData.get(fop);
    Value start = null;
    for (Op o : li.startExpression()) {
        transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
        start = cc.getValue(o.result());
    }
    Value end = null;
    for (Op o : li.endExpression()) {
        transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
        end = cc.getValue(o.result());
    }
    Value step = null;
    for (Op o : li.stepExpression()) {
        transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
        step = cc.getValue(o.result());
    }

    // Obtain captured vars
    // true == stores
    // false == loads only
    Map<Boolean, Set<Value>> capturedVars = capturedVars(body);
    Set<Value> capturedAndStoredVars = capturedVars.get(true);

    // Get load values
    // Loaded values are hoisted out of the loop body
    Map<Value, Value> loadValues = new HashMap<>();
    for (Value v : capturedVars.get(false)) {
        Value load = kblock.op(varLoad(cc.getValue(v)));
        valueTypeMap.put(load, valueTypeMap.get(v));
        loadValues.put(v, load);
    }

    // Get iteration values -- represented by captured vars that are stored to in the loop
    // The SCF for operation returns the iteration values of the last loop iteration, which
    // are then to be stored to the iteration variables
    // NOTE(review): unlike the hoisted loads above, these initial loads are not recorded in
    // valueTypeMap -- confirm nothing later looks up their types
    List<Value> iterValues = new ArrayList<>();
    for (Value v : capturedAndStoredVars) {
        iterValues.add(kblock.op(varLoad(cc.getValue(v))));
    }

    // @@@ Build in java code model, then transform?
    SCFOps.ForOp scffor = SCFOps.for_(kblock.parentBody(), start, end, step, iterValues)
            // Ensure existing context is used
            // (a child context created from cc)
            .body(CodeContext.create(cc), builder -> {
                // Create index var initialized from entry block parameter
                Value index = builder.parameters().get(0);
                valueTypeMap.put(index, JavaType.INT);
                Value varIndex = builder.op(var("index", index));
                valueTypeMap.put(varIndex, JavaType.INT);
                builder.context().mapValue(body.entryBlock().parameters().get(0), varIndex);

                // Create iter vars initialized from entry block parameters
                // (parameter 0 is the index, so iteration values start at parameter 1;
                // var names are taken after the increment and so start at "2")
                int pi = 1;
                for (Value v : capturedAndStoredVars) {
                    TypeElement type = valueTypeMap.get(v);
                    Value iter = builder.parameters().get(pi++);
                    valueTypeMap.put(iter, type);
                    Value varIter = builder.op(var(Integer.toString(pi), iter));
                    valueTypeMap.put(varIter, type);
                    builder.context().mapValue(v, varIter);
                }

                // Transform the Java for body into the SCF for body
                builder.body(body, List.of(), (block, op) -> {
                    // Yield iter values
                    if (op instanceof JavaOp.ContinueOp) {
                        // Replace with yield of loaded vars
                        List<Value> yieldValues = new ArrayList<>();
                        for (Value value : capturedAndStoredVars) {
                            Value varIter = block.context().getValue(value);
                            Value v = block.op(varLoad(varIter));
                            yieldValues.add(v);
                        }
                        block.op(SCFOps.yield_(yieldValues));
                    } else if (op instanceof VarAccessOp.VarLoadOp) {
                        // Replace with value loaded immediately before loop
                        Value v = op.operands().get(0);
                        if (capturedVars.get(false).contains(v)) {
                            block.context().mapValue(op.result(), loadValues.get(v));
                        } else {
                            block.op(op);
                        }
                    } else {
                        block = transformToTritonOperation(block, op, valueTypeMap, opData, fsymTable);
                    }
                    return block;
                });
            });
    Op.Result forResult = kblock.op(scffor);

    // Assign back result to iter vars
    // (a single iteration value is the for result itself; otherwise the result is a tuple)
    if (capturedAndStoredVars.size() == 1) {
        for (Value v : capturedAndStoredVars) {
            kblock.op(varStore(cc.getValue(v), forResult));
        }
    } else {
        int i = 0;
        for (Value v : capturedAndStoredVars) {
            kblock.op(varStore(cc.getValue(v),
                    kblock.op(tupleLoad(forResult, i++))));
        }
    }
}
787
788 static Map<Boolean, Set<Value>> capturedVars(Body body) {
789 Map<Boolean, Set<Value>> capturedValues = new HashMap<>();
790 capturedValues.put(false, new LinkedHashSet<>());
791 capturedValues.put(true, new LinkedHashSet<>());
792
793 capturedVars(capturedValues, new ArrayDeque<>(), body);
794 return capturedValues;
795 }
796
797 static void capturedVars(Map<Boolean, Set<Value>> capturedVars, Deque<Body> bodyStack, Body body) {
798 bodyStack.push(body);
799
800 for (Block b : body.blocks()) {
801 for (Op op : b.ops()) {
802 // @@@ Nested bodies
803 if (!op.bodies().isEmpty()) {
804 throw new IllegalStateException();
805 }
806 // for (Body childBody : op.bodies()) {
807 // capturedAndUpdatedVars(capturedValues, bodyStack, childBody);
808 // }
809
810 if (op instanceof VarAccessOp) {
811 Value v = op.operands().get(0);
812 if (!bodyStack.contains(v.declaringBlock().ancestorBody())) {
813 if (op instanceof VarAccessOp.VarStoreOp) {
814 capturedVars.get(true).add(v);
815 capturedVars.get(false).remove(v);
816 } else if (!capturedVars.get(true).contains(v)) {
817 capturedVars.get(false).add(v);
818 }
819 }
820 }
821 }
822 }
823
824 bodyStack.pop();
825 }
826
// Controls whether cleanup(...) converts the lowered function to SSA form.
// When unbound, SSA conversion is applied; bind it to false to keep var ops.
public static final ScopedValue<Boolean> SV_SSA = ScopedValue.newInstance();
828
829 static TritonOps.FuncOp cleanup(TritonOps.FuncOp f) {
830 // Remove var ops
831 boolean doSSA = SV_SSA.isBound() ? SV_SSA.get() : true;
832 if (doSSA) {
833 f = SSA.transform(f);
834 }
835 // Remove unused ops
836 f = f.transform((fblock, op) -> {
837 if (op instanceof Op.Pure && op.result().uses().isEmpty()) {
838 return fblock;
839 } else if (op instanceof VarAccessOp.VarLoadOp && op.result().uses().isEmpty()) {
840 return fblock;
841 }
842
843 fblock.op(op);
844 return fblock;
845 });
846 return f;
847 }
848
849 static class TritonBuilderInterpreter {
// Symbol table of lowered functions, keyed by function name
final Map<String, TritonOps.FuncOp> fsymTable;
// Builder that lowered operations are appended to
final Block.Builder block;

TritonBuilderInterpreter(Map<String, TritonOps.FuncOp> fsymTable, Block.Builder block) {
    this.block = block;
    this.fsymTable = fsymTable;
}
857
858 Value build(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
859 // Obtain associated type-based method
860 MethodHandle mh;
861 try {
862 Optional<Method> om = Stream.of(TritonBuilderInterpreter.class.getDeclaredMethods())
863 .filter(m -> m.getName().equals(name))
864 .filter(m -> m.isVarArgs()
865 ? m.getParameterCount() / 2 - 1 <= op.operands().size()
866 : m.getParameterCount() / 2 - 1 == op.operands().size())
867 .findFirst();
868 mh = MethodHandles.lookup().unreflect(
869 om.orElseThrow(() -> new NoSuchMethodException(name)));
870 } catch (ReflectiveOperationException e) {
871 throw new IllegalStateException(e);
872 }
873
874 List<Object> iArgs = new ArrayList<>();
875 iArgs.add(this);
876 iArgs.add(valueTypeMap.get(op.result()));
877 iArgs.add(op.result());
878 for (Value o : op.operands()) {
879 iArgs.add(valueTypeMap.get(o));
880 iArgs.add(o);
881 }
882 try {
883 return (Value) mh.invokeWithArguments(iArgs.toArray(Object[]::new));
884 } catch (Throwable e) {
885 throw new IllegalStateException(e);
886 }
887 }
888
889
890 public Value programId(TypeElement rType, Op.Result r,
891 ConstantType axisType, Value axis) {
892 return block.op(TritonOps.getProgramId(
893 (int) axisType.value()));
894 }
895
896 public Value arange(TensorType rType, Op.Result r,
897 ConstantType startType, Value start,
898 ConstantType endType, Value end) {
899 return block.op(TritonOps.makeRange(
900 (int) startType.value(),
901 (int) endType.value()));
902 }
903
904 public Value expand(TensorType rType, Op.Result r,
905 TensorType aType, Value a,
906 ConstantType axisType, Value axis) {
907 return block.op(TritonOps.expand(
908 (int) axisType.value(),
909 rType,
910 block.context().getValue(a)));
911 }
912
913 public Value zeros(TensorType rType, Op.Result r,
914 ConstantType aType, Value a,
915 Object... constantsAndValues) {
916 Object zero;
917 try {
918 JavaType zeroType = (JavaType) aType.value();
919 zero = MethodHandles.zero((Class<?>) zeroType.resolve(MethodHandles.lookup())).invoke();
920 } catch (Throwable e) {
921 throw new RuntimeException(e);
922 }
923 return block.op(ArithMathOps.constant(rType, zero));
924 }
925
926 public Value load(TensorType rType, Op.Result r,
927 TensorType ptrType, Value ptr,
928 TensorType maskType, Value mask) {
929 broadcastConversionRight(ptrType, maskType, mask);
930 return block.op(TritonOps.load(
931 rType,
932 block.context().getValue(ptr),
933 block.context().getValue(mask)));
934 }
935
936 public Value load(TensorType rType, Op.Result r,
937 TensorType ptrType, Value ptr,
938 TensorType maskType, Value mask,
939 ConstantType otherType, Value other) {
940 broadcastConversionRight(ptrType, maskType, mask);
941 Value mb = block.op(ArithMathOps.constant(rType, (float)otherType.value()));
942 block.context().mapValue(other, mb);
943 return block.op(TritonOps.load(
944 rType,
945 block.context().getValue(ptr),
946 block.context().getValue(mask),
947 block.context().getValue(other)));
948 }
949
950 public Value store(TensorType rType, Op.Result r,
951 TensorType ptrType, Value ptr,
952 TensorType valueType, Value value,
953 TensorType maskType, Value mask) {
954 broadcastConversionRight(ptrType, valueType, value);
955 broadcastConversionRight(ptrType, maskType, mask);
956 return block.op(TritonOps.store(
957 block.context().getValue(ptr),
958 block.context().getValue(value),
959 block.context().getValue(mask)));
960 }
961
962 public Value broadcast(TensorType rType, Op.Result r,
963 TypeElement oType, Value o,
964 TensorType tensorTypeType, Value tensorType) {
965 // @@@ tt.splat with scalar operand, tt.broadcast with tensor operand
966 if (oType instanceof TensorType) {
967 return block.op(TritonOps.broadcast(
968 rType,
969 block.context().getValue(o)));
970 } else {
971 return block.op(TritonOps.splat(
972 rType,
973 block.context().getValue(o)));
974 }
975 }
976
977 public Value joinShape(TensorType rType, Op.Result r,
978 TensorType aType, Value a,
979 TensorType bType, Value b) {
980 // Replace with constant operation to produce tensor type.
981 // Result may be used, but transitively it will be removed due to no uses
982 // contributing to the computation
983 return block.op(CoreOp.constant(JavaType.type(TensorType.class), r.type()));
984 }
985
986
987 public Value add(TypeElement rType, Op.Result r,
988 TypeElement aType, Value a,
989 TypeElement bType, Value b) {
990 broadcastConversion(rType, aType, a, bType, b);
991 a = block.context().getValue(a);
992 b = block.context().getValue(b);
993
994 if (rType instanceof PtrType ||
995 rType instanceof TensorType t && t.eType() instanceof PtrType) {
996 return block.op(TritonOps.addptr(a, b));
997 } else {
998 return block.op(ArithMathOps.add(a, b));
999 }
1000 }
1001
1002 public Value sub(TypeElement rType, Op.Result r,
1003 TypeElement aType, Value a,
1004 TypeElement bType, Value b) {
1005 broadcastConversion(rType, aType, a, bType, b);
1006 a = block.context().getValue(a);
1007 b = block.context().getValue(b);
1008
1009 return block.op(ArithMathOps.sub(a, b));
1010 }
1011
1012 public Value mul(TypeElement rType, Op.Result r,
1013 TypeElement aType, Value a,
1014 TypeElement bType, Value b) {
1015 broadcastConversion(rType, aType, a, bType, b);
1016 a = block.context().getValue(a);
1017 b = block.context().getValue(b);
1018
1019 return block.op(ArithMathOps.mul(a, b));
1020 }
1021
1022 public Value div(TypeElement rType, Op.Result r,
1023 TypeElement aType, Value a,
1024 TypeElement bType, Value b) {
1025 broadcastConversion(rType, aType, a, bType, b);
1026 a = block.context().getValue(a);
1027 b = block.context().getValue(b);
1028
1029 return block.op(ArithMathOps.div(a, b));
1030 }
1031
1032 public Value mod(TypeElement rType, Op.Result r,
1033 TypeElement aType, Value a,
1034 TypeElement bType, Value b) {
1035 broadcastConversion(rType, aType, a, bType, b);
1036 a = block.context().getValue(a);
1037 b = block.context().getValue(b);
1038
1039 return block.op(ArithMathOps.rem(a, b));
1040 }
1041
1042 public Value and(TypeElement rType, Op.Result r,
1043 TypeElement aType, Value a,
1044 TypeElement bType, Value b) {
1045 broadcastConversion(rType, aType, a, bType, b);
1046 a = block.context().getValue(a);
1047 b = block.context().getValue(b);
1048
1049 return block.op(ArithMathOps.and(a, b));
1050 }
1051
1052 public Value dot(TensorType rType, Op.Result r,
1053 TypeElement aType, Value a,
1054 TypeElement bType, Value b) {
1055 a = block.context().getValue(a);
1056 b = block.context().getValue(b);
1057 // Computed result is tensor of floats, regardless of inputs
1058 Object zero = 0.0f;
1059 var c = block.op(ArithMathOps.constant(rType, zero));
1060 return block.op(TritonOps.dot(rType, a, b, c));
1061 }
1062
1063 public Value cdiv(TypeElement rType, Op.Result r,
1064 TypeElement aType, Value a,
1065 TypeElement bType, Value b) {
1066 a = block.context().getValue(a);
1067 b = block.context().getValue(b);
1068
1069 TritonOps.FuncOp cdiv = tritonFunction(Functions.getJavaCodeModel("cdiv"),
1070 rType, List.of(aType, bType),
1071 fsymTable);
1072 // @@@ Generalize
1073 List<Value> args = new ArrayList<>();
1074 if (!(aType instanceof ConstantType)) {
1075 args.add(a);
1076 }
1077 if (!(bType instanceof ConstantType)) {
1078 args.add(b);
1079 }
1080 return block.op(TritonOps.call(cdiv, args));
1081 }
1082
1083 public Value conv(TypeElement rType, Op.Result r,
1084 ConstantType tType, Value t,
1085 TypeElement aType, Value a) {
1086 a = block.context().getValue(a);
1087
1088 TypeElement rScalarType;
1089 TypeElement aScalarType;
1090 if (rType instanceof TensorType rTensorType && aType instanceof TensorType aTensorType) {
1091 rScalarType = rTensorType.eType();
1092 aScalarType = aTensorType.eType();
1093 } else {
1094 rScalarType = rType;
1095 aScalarType = aType;
1096 }
1097
1098 if (rScalarType.equals(Float16.FLOAT_16_TYPE) && aScalarType.equals(JavaType.FLOAT)) {
1099 return block.op(ArithMathOps.trunc(rType, a));
1100 } else if (rType.equals(aType)) {
1101 return a;
1102 } else {
1103 throw new IllegalStateException();
1104 }
1105 }
1106
1107 public Value exp(TritonType rType, Op.Result r,
1108 TritonType aType, Value a) {
1109 return block.op(ArithMathOps.exp(
1110 block.context().getValue(a)));
1111 }
1112
1113 public Value compare(TensorType rType, Op.Result r,
1114 TypeElement aType, Value a,
1115 TypeElement bType, Value b,
1116 ConstantType compareType, Value compare) {
1117 Triton.CompareKind ck = (Triton.CompareKind) compareType.value();
1118
1119 ArithMathOps.CompareOp.CompareKind ack = switch (ck) {
1120 case LessThan -> ArithMathOps.CompareOp.CompareKind.slt;
1121 default -> throw new UnsupportedOperationException("Unsupported comparison: " + ck);
1122 };
1123
1124 broadcastConversionRight(aType, bType, b);
1125 a = block.context().getValue(a);
1126 b = block.context().getValue(b);
1127
1128 return block.op(ArithMathOps.cmp(ack, a, b));
1129 }
1130
1131
1132 public Value max(TypeElement rType, Op.Result r,
1133 TensorType xType, Value x,
1134 ConstantType axisType, Value axis) {
1135 TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("max"),
1136 rType, List.of(rType, rType), fsymTable);
1137 return reduce(rType, r, xType, x, axisType, axis, f);
1138 }
1139
1140 public Value sum(TypeElement rType, Op.Result r,
1141 TensorType xType, Value x,
1142 ConstantType axisType, Value axis) {
1143 TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("sum"),
1144 rType, List.of(rType, rType), fsymTable);
1145 return reduce(rType, r, xType, x, axisType, axis, f);
1146 }
1147
1148 Value reduce(TypeElement rType, Op.Result r,
1149 TensorType xType, Value x,
1150 ConstantType axisType, Value axis,
1151 TritonOps.FuncOp f) {
1152 int axisConstant = (int) axisType.value();
1153
1154 String signature = "reduce_" + f.funcName() + "_" + axisConstant;
1155 TritonOps.FuncOp rf = fsymTable.computeIfAbsent(signature,
1156 s -> reduce(rType, xType, axisConstant, s, f));
1157
1158 return block.op(TritonOps.call(rf, block.context().getValue(x)));
1159 }
1160
1161 static TritonOps.FuncOp reduce(TypeElement elementType,
1162 TensorType tensorType,
1163 int axisConstant,
1164 String name, TritonOps.FuncOp scalarFunc) {
1165 return TritonOps.func(name,
1166 functionType(elementType, tensorType))
1167 .body(fblock -> {
1168 TritonOps.ReduceOp reduceOp = TritonOps.reduce(fblock.parentBody(),
1169 axisConstant, fblock.parameters().get(0),
1170 functionType(elementType, elementType, elementType))
1171 .body(rblock -> {
1172 Block.Parameter a = rblock.parameters().get(0);
1173 Block.Parameter b = rblock.parameters().get(1);
1174 Op.Result _r = rblock.op(TritonOps.call(scalarFunc, a, b));
1175 rblock.op(TritonOps.reduceReturn(_r));
1176 });
1177
1178 Op.Result opr = fblock.op(reduceOp);
1179 fblock.op(TritonOps.return_(opr));
1180 });
1181 }
1182
1183 // @@@ Test
1184 public Value consume(TypeElement rType, Op.Result r,
1185 TypeElement aType, Value a) {
1186 return block.op(TritonTestOps.consume(block.context().getValue(a)));
1187 }
1188
1189 void broadcastConversion(TypeElement rType,
1190 TypeElement aType, Value a,
1191 TypeElement bType, Value b) {
1192 Value ma = block.context().getValue(a);
1193 Value mb = block.context().getValue(b);
1194 if (aType instanceof TensorType at && bType instanceof TensorType bTensorType) {
1195 TensorType rTensorType = (TensorType) rType;
1196 if (!at.shape().equals(rTensorType.shape())) {
1197 ma = block.op(TritonOps.broadcast(rTensorType, ma));
1198 }
1199 if (!bTensorType.shape().equals(rTensorType.shape())) {
1200 if (rTensorType.eType() instanceof PtrType) {
1201 bTensorType = new TensorType(bType, rTensorType.shape());
1202 mb = block.op(TritonOps.broadcast(bTensorType, mb));
1203 } else {
1204 mb = block.op(TritonOps.broadcast(rTensorType, mb));
1205 }
1206 }
1207 } else if (aType instanceof TensorType) {
1208 TensorType rTensorType = (TensorType) rType;
1209 if (rTensorType.eType() instanceof PtrType) {
1210 TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1211 mb = block.op(TritonOps.splat(bTensorType, mb));
1212 } else {
1213 mb = block.op(TritonOps.splat(rTensorType, mb));
1214 }
1215 } else if (bType instanceof TensorType) {
1216 TensorType rTensorType = (TensorType) rType;
1217 ma = block.op(TritonOps.splat(rTensorType, ma));
1218 }
1219 block.context().mapValue(a, ma);
1220 block.context().mapValue(b, mb);
1221 }
1222
1223 void broadcastConversionRight(TypeElement aType,
1224 TypeElement bType, Value b) {
1225 Value mb = block.context().getValue(b);
1226 if (aType instanceof TensorType aTensorType && bType instanceof TensorType bTensorType) {
1227 if (!bTensorType.shape().equals(aTensorType.shape())) {
1228 if (aTensorType.eType() instanceof PtrType) {
1229 bTensorType = new TensorType(bTensorType.eType(), aTensorType.shape());
1230 mb = block.op(TritonOps.broadcast(bTensorType, mb));
1231 } else {
1232 mb = block.op(TritonOps.broadcast(aTensorType, mb));
1233 }
1234 }
1235 } else if (aType instanceof TensorType rTensorType) {
1236 if (rTensorType.eType() instanceof PtrType) {
1237 TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1238 mb = block.op(TritonOps.splat(bTensorType, mb));
1239 } else {
1240 mb = block.op(TritonOps.splat(rTensorType, mb));
1241 }
1242 }
1243 block.context().mapValue(b, mb);
1244 }
1245 }
1246
1247 public static <O extends Op & Op.Invokable> void printTypeMap(
1248 O kernel, Map<Value, TypeElement> valueTypeMap) {
1249 AtomicInteger valueId = new AtomicInteger();
1250 Map<Value, Integer> valueIdMap = new LinkedHashMap<>();
1251 kernel.elements().forEach(codeElement -> {
1252 switch (codeElement) {
1253 case FuncOp _ -> {
1254 // Ignore
1255 }
1256 case Op op when !op.result().type().equals(JavaType.VOID) -> {
1257 valueIdMap.put(op.result(), valueId.getAndIncrement());
1258 }
1259 case Block block -> {
1260 for (Block.Parameter parameter : block.parameters()) {
1261 valueIdMap.put(parameter, valueId.getAndIncrement());
1262 }
1263 }
1264 default -> {
1265 }
1266 }
1267 });
1268
1269 valueIdMap.forEach((value, id) -> {
1270 TypeElement type = valueTypeMap.get(value);
1271 if (type != null) {
1272 System.out.println("%" + id + " : " + value.type() + " -> " + type);
1273 }
1274 });
1275 }
1276 }