1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.triton;
27
28 import java.lang.constant.ClassDesc;
29 import java.lang.invoke.MethodHandle;
30 import java.lang.invoke.MethodHandles;
31 import java.lang.reflect.Field;
32 import java.lang.reflect.Method;
33 import jdk.incubator.code.*;
34 import jdk.incubator.code.analysis.SSA;
35 import jdk.incubator.code.dialect.core.CoreOp;
36 import jdk.incubator.code.dialect.java.JavaOp;
37 import jdk.incubator.code.dialect.java.JavaType;
38 import jdk.incubator.code.dialect.core.VarType;
39 import java.util.*;
40 import java.util.concurrent.atomic.AtomicInteger;
41 import java.util.stream.Stream;
42
43 import static jdk.incubator.code.dialect.core.CoreOp.*;
44 import static jdk.incubator.code.dialect.core.CoreType.functionType;
45 import static jdk.incubator.code.dialect.java.JavaOp.*;
46
47 public final class TritonTransformer {
// Java types of the classes whose method invocations are intercepted by the type
// checker and the transformer (dispatched by reference type and method name).

/** Type of {@code Triton}; its invocations are handled by the type and builder interpreters. */
static final JavaType TYPE_Triton = JavaType.type(Triton.class);

/**
 * Type of {@code oracle.code.triton.TritonTest}, referenced through a {@link ClassDesc}
 * rather than a class literal (presumably because the test class is not on the main
 * class path -- confirm).
 */
static final JavaType TYPE_Triton_Test = JavaType.type(ClassDesc.of("oracle.code.triton.TritonTest"));

/** Type of {@code Tensor}; only {@code Tensor.type()} is supported. */
static final JavaType TYPE_Tensor = JavaType.type(Tensor.class);

/** Type of {@code java.lang.Math}; only {@code max} and {@code min} are supported. */
static final JavaType TYPE_J_L_MATH = JavaType.type(Math.class);

/** Not instantiable: all entry points are static. */
private TritonTransformer() {
}
57
58 public static <O extends Op & Op.Invokable>
59 TritonOps.ModuleOp tritonModule(O kernel,
60 TypeElement rType,
61 List<? extends TypeElement> argTypes) {
62 Map<String, TritonOps.FuncOp> fsymTable = new LinkedHashMap<>();
63 tritonFunction(kernel, rType, argTypes, fsymTable);
64 return TritonOps.module(fsymTable.values().stream().toList());
65 }
66
67 public static <O extends Op & Op.Invokable>
68 TritonOps.FuncOp tritonFunction(O javaKernel,
69 TypeElement rType,
70 List<? extends TypeElement> argTypes,
71 Map<String, TritonOps.FuncOp> fsymTable) {
72 String name = (javaKernel instanceof FuncOp f) ? f.funcName() : "kernel";
73 String signature = signature(name, rType, argTypes);
74 if (fsymTable.containsKey(signature)) {
75 return fsymTable.get(signature);
76 }
77
78 System.out.println(javaKernel.toText());
79
80 Map<Value, TypeElement> valueTypeMap = new HashMap<>();
81 Map<Op, Object> opData = new HashMap<>();
82 TritonTransformer.typeCheckKernel(javaKernel, argTypes, valueTypeMap, opData);
83 TritonTransformer.printTypeMap(javaKernel, valueTypeMap);
84
85 return TritonTransformer.transformToTritonFunction(javaKernel, signature,
86 rType, valueTypeMap, opData,
87 fsymTable);
88 }
89
90 static String signature(String name, TypeElement rType, List<? extends TypeElement> argTypes) {
91 StringBuilder sb = new StringBuilder(name);
92
93 for (TypeElement argType : argTypes) {
94 sb.append("_");
95 if (argType instanceof ConstantType ct) {
96 sb.append(ct.value());
97 } else {
98 sb.append(argType);
99 }
100 }
101 sb.append("_");
102 sb.append(rType);
103 return sb.toString();
104 }
105
106 public static <O extends Op & Op.Invokable> void typeCheckKernel(
107 O kernel, List<? extends TypeElement> argTypes,
108 Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData) {
109 kernel.traverse(null, CodeElement.opVisitor((o, op) -> {
110 switch (op) {
111 case Op.Invokable fop -> {
112 List<Block.Parameter> parameters = fop.body().entryBlock().parameters();
113 for (int i = 0; i < parameters.size(); i++) {
114 valueTypeMap.put(parameters.get(i), argTypes.get(i));
115 }
116 }
117 case VarOp _, VarAccessOp.VarLoadOp _ -> {
118 Value init = op.operands().get(0);
119 valueTypeMap.put(op.result(), valueTypeMap.get(init));
120 }
121 case VarAccessOp.VarStoreOp _ -> {
122 Value var = op.operands().get(0);
123 TypeElement varType = valueTypeMap.get(var);
124 Value v = op.operands().get(1);
125 TypeElement vType = valueTypeMap.get(v);
126 if (!varType.equals(vType)) {
127 throw new IllegalStateException("Storing to variable with different type: "
128 + varType + " <- " + vType);
129 }
130
131 valueTypeMap.put(op.result(), valueTypeMap.get(var));
132 }
133 case ConstantOp cop -> {
134 valueTypeMap.put(op.result(), new ConstantType(op.result().type(), cop.value()));
135 }
136 case ArithmeticOperation _ -> {
137 TypeElement t = checkWithTypeInterpreter(op, op.externalizeOpName(), valueTypeMap);
138 valueTypeMap.put(op.result(), t);
139 }
140 case FieldAccessOp.FieldLoadOp flop -> {
141 if (!flop.operands().isEmpty()) {
142 throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor());
143 }
144
145 Field f;
146 try {
147 f = flop.fieldDescriptor().resolveToField(MethodHandles.lookup());
148 } catch (ReflectiveOperationException e) {
149 throw new IllegalStateException("Unsupported field load: " + flop.fieldDescriptor(), e);
150 }
151 Object value;
152 try {
153 value = f.get(null);
154 } catch (IllegalAccessException e) {
155 throw new IllegalStateException("Unsupported field load: " + f, e);
156 }
157 valueTypeMap.put(op.result(), new ConstantType(JavaType.type(f.getType()), value));
158 }
159 case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
160 // Box
161 if (iop.invokeDescriptor().name().equals("valueOf")) {
162 Value a = op.operands().get(0);
163 valueTypeMap.put(op.result(), valueTypeMap.get(a));
164 } else {
165 throw new UnsupportedOperationException("Unsupported invocation on Integer: " + iop.invokeDescriptor());
166 }
167 }
168 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
169 String name = iop.invokeDescriptor().name();
170 if (name.equals("max") || name.equals("min")) {
171 Value a = op.operands().get(0);
172 valueTypeMap.put(op.result(), valueTypeMap.get(a));
173 } else {
174 throw new UnsupportedOperationException("Unsupported invocation on Math: " + iop.invokeDescriptor());
175 }
176 }
177 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
178 if (iop.invokeDescriptor().name().equals("type")) {
179 Value a = op.operands().get(0);
180 valueTypeMap.put(op.result(), valueTypeMap.get(a));
181 } else {
182 throw new UnsupportedOperationException("Unsupported invocation on Tensor: " + iop.invokeDescriptor());
183 }
184 }
185 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
186 TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
187 valueTypeMap.put(op.result(), t);
188 }
189 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
190 TypeElement t = checkWithTypeInterpreter(op, iop.invokeDescriptor().name(), valueTypeMap);
191 valueTypeMap.put(op.result(), t);
192 }
193 case JavaOp.ForOp fop -> {
194 SimpleCountedForLoopInfo li = new SimpleCountedForLoopInfo(fop);
195 opData.put(fop, li);
196
197 TypeElement type = fop.init().yieldType();
198 if (type instanceof VarType vt && vt.valueType().equals(JavaType.INT)) {
199 for (Body b : List.of(fop.cond(), fop.update(), fop.loopBody())) {
200 valueTypeMap.put(b.entryBlock().parameters().get(0), JavaType.INT);
201 }
202 } else {
203 throw new IllegalStateException();
204 }
205 }
206 case TestOperation _ -> {
207 }
208 case JavaOp.ContinueOp _ -> {
209 }
210 case CoreOp.YieldOp _ -> {
211 }
212 case ReturnOp _ -> {
213 }
214 default -> throw new UnsupportedOperationException("Unsupported operation: " + op);
215 }
216
217 return null;
218 }));
219 }
220
221 static TypeElement checkWithTypeInterpreter(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
222 // Obtain associated type-based method
223 MethodHandle mh;
224 try {
225 Optional<Method> om = Stream.of(TritonTypeInterpreter.class.getDeclaredMethods())
226 .filter(m -> m.getName().equals(name))
227 .filter(m -> m.isVarArgs() ? m.getParameterCount() <= op.operands().size() : m.getParameterCount() == op.operands().size())
228 .findFirst();
229 mh = MethodHandles.lookup().unreflect(
230 om.orElseThrow(() -> new NoSuchMethodException(name)));
231 } catch (ReflectiveOperationException e) {
232 throw new IllegalStateException(name, e);
233 }
234
235 // Invoke with the values' types
236 List<TypeElement> operandTypes = op.operands().stream().map(valueTypeMap::get).toList();
237 try {
238 return (TypeElement) mh.invokeWithArguments(operandTypes.toArray(Object[]::new));
239 } catch (Throwable e) {
240 throw new IllegalStateException(mh.toString(), e);
241 }
242 }
243
244 // @@@ type check tensor shapes
245 static class TritonTypeInterpreter {
246 private TritonTypeInterpreter() {
247 }
248
249 // int programId(@Constant int axis) {
250 public static JavaType programId(ConstantType axis) {
251 assert axis.cType().equals(JavaType.INT);
252 int axisValue = (int) axis.value();
253 if (axisValue < 0 || axisValue > 3) {
254 throw new IllegalStateException();
255 }
256
257 return JavaType.INT;
258 }
259
260 // Tensor arange(@Constant int start, @Constant int end)
261 public static TensorType arange(ConstantType start, ConstantType end) {
262 assert start.cType().equals(JavaType.INT);
263 assert end.cType().equals(JavaType.INT);
264
265 int startValue = (int) start.value();
266 int endValue = (int) end.value();
267
268 return new TensorType(JavaType.INT, List.of(endValue - startValue));
269 }
270
271 // Tensor expand(Tensor a, int axis) {
272 public static TensorType expand(TensorType a, ConstantType axis) {
273 assert axis.cType().equals(JavaType.INT);
274 int axisValue = (int) axis.value();
275
276 List<Integer> s = new ArrayList<>(a.shape());
277 if (axisValue < s.size()) {
278 s.add(axisValue, 1);
279 } else {
280 for (int i = 0; i <= (axisValue - s.size()); i++) {
281 s.add(1);
282 }
283 }
284 return new TensorType(a.eType(), s);
285 }
286
287 // Tensor load(Tensor ptr, Tensor mask)
288 public static TensorType load(TensorType ptr, TensorType mask) {
289 checkTensorShape(ptr, mask);
290 if (ptr.eType() instanceof PtrType eptr) {
291 return new TensorType(eptr.rType(), ptr.shape());
292 }
293
294 throw new IllegalStateException();
295 }
296
297 // Tensor load(Tensor ptr, Tensor mask, ConstantType other) {
298 public static TensorType load(TensorType ptr, TensorType mask, ConstantType other) {
299 checkTensorShape(ptr, mask);
300 if (ptr.eType() instanceof PtrType eptr) {
301 return new TensorType(eptr.rType(), ptr.shape());
302 }
303
304 throw new IllegalStateException();
305 }
306
307 // void store(Tensor ptr, Tensor value, Tensor mask)
308 public static void store(TensorType ptr, TensorType value, TensorType mask) {
309 if (!(ptr.eType() instanceof PtrType)) {
310 throw new IllegalStateException();
311 }
312 }
313
314 // Tensor zeros(TensorType type)
315 public static TensorType zeros(ConstantType eType, ConstantType... cShape) {
316 List<Integer> shape = Stream.of(cShape).map(s -> (int) s.value()).toList();
317 return new TensorType((TypeElement) eType.value(), shape);
318 }
319
320 // Tensor broadcast(Object o, TensorType type)
321 public static TensorType broadcast(TypeElement o, TensorType type) {
322 if (o instanceof TensorType ot) {
323 // @@@
324 if (ot.shape().size() != type.shape().size()) {
325 throw new IllegalStateException();
326 }
327 o = ot.eType();
328 } if (o instanceof ConstantType oc) {
329 o = oc.cType();
330 }
331 return new TensorType(o, type.shape());
332 }
333
334 public static TensorType joinShape(TensorType a, TensorType b) {
335 return checkTensorTypes(a, b);
336 }
337
338 // Tensor add(Number a, Number b)
339 // Ptr add(Ptr a, int offset)
340 public static TypeElement add(TypeElement a, TypeElement b) {
341 // @@@ Pass additional argument for checking ptr
342 return binary(a, b);
343 }
344
345 public static TypeElement sub(TypeElement a, TypeElement b) {
346 return binary(a, b);
347 }
348
349 public static TypeElement mul(TypeElement a, TypeElement b) {
350 return binary(a, b);
351 }
352
353 public static TypeElement div(TypeElement a, TypeElement b) {
354 return binary(a, b);
355 }
356
357 public static TypeElement mod(TypeElement a, TypeElement b) {
358 return binary(a, b);
359 }
360
361 public static TypeElement and(TypeElement a, TypeElement b) {
362 return binary(a, b);
363 }
364
365 public static TypeElement cdiv(TypeElement a, TypeElement b) {
366 a = reduceScalarType(a);
367 b = reduceScalarType(b);
368 if (!a.equals(JavaType.INT) && !b.equals(JavaType.INT)) {
369 throw new IllegalStateException();
370 }
371 return a;
372 }
373
374 // Number conv(Type t, Number a) {
375 public static TypeElement conv(ConstantType eType, TypeElement a) {
376 return convTypes(eType, a);
377 }
378
379 public static TypeElement convTypes(ConstantType eType, TypeElement a) {
380 if (a instanceof TensorType tb) {
381 TypeElement e = convScalarTypes(eType, tb.eType());
382 return new TensorType(e, tb.shape());
383 } else {
384 return convScalarTypes(eType, a);
385 }
386 }
387
388 public static TypeElement convScalarTypes(ConstantType eType, TypeElement a) {
389 TypeElement t = (TypeElement) eType.value();
390 if (t.equals(Float16.FLOAT_16_TYPE) && a.equals(JavaType.FLOAT)) {
391 return Float16.FLOAT_16_TYPE;
392 } else if (t.equals(a)) {
393 return t;
394 } else {
395 // @@@ Conversions;
396 throw new IllegalStateException();
397 }
398 }
399
400 // Tensor exp(Tensor a)
401 public static TypeElement exp(TypeElement a) {
402 return unary(a);
403 }
404
405 static TypeElement unary(TypeElement a) {
406 return a;
407 }
408
409 // Tensor compare(Number a, Number b, @Constant CompareKind ck) {
410 public static TypeElement compare(TypeElement a, TypeElement b, ConstantType kind) {
411 assert kind.cType().equals(JavaType.type(Triton.CompareKind.class));
412
413 TypeElement t = binary(a, b);
414 if (t instanceof TensorType tt) {
415 return new TensorType(JavaType.BOOLEAN, tt.shape());
416 } else {
417 return t;
418 }
419 }
420
421 // Tensor dot(Tensor a, Tensor b)
422 public static TensorType dot(TensorType a, TensorType b) {
423 if (a.shape().size() != 2 || b.shape().size() != 2) {
424 throw new IllegalStateException();
425 }
426
427 if (!a.shape().get(1).equals(b.shape().get(0))) {
428 throw new IllegalStateException();
429 }
430
431 if (a.eType() != b.eType()) {
432 // @@@ Conversion, type checking
433 throw new IllegalStateException();
434 }
435
436 // Computed result is tensor of floats, regardless of inputs
437 return new TensorType(JavaType.FLOAT, List.of(a.shape().get(0), b.shape().get(1)));
438 }
439
440
441 // Tensor max(Tensor a, @Constant int axis) {
442 public static TypeElement max(TensorType a, ConstantType axis) {
443 return reduce(a, axis);
444 }
445
446 // Tensor sum(Tensor a, @Constant int axis) {
447 public static TypeElement sum(TensorType a, ConstantType axis) {
448 return reduce(a, axis);
449 }
450
451 static TypeElement reduce(TensorType a, ConstantType axis) {
452 assert axis.cType().equals(JavaType.INT);
453 int axisValue = (int) axis.value();
454 if (axisValue < 0 || axisValue > 3) {
455 throw new IllegalStateException();
456 }
457
458 List<Integer> reduceShape = new ArrayList<>();
459 for (int i = 0; i < a.shape().size(); i++) {
460 if (i != axisValue) {
461 reduceShape.add(a.shape().get(i));
462 } else {
463 reduceShape.add(1);
464 }
465 }
466
467 if (reduceShape.size() == 1 && reduceShape.getFirst() == 1) {
468 return a.eType();
469 } else {
470 return new TensorType(a.eType(), reduceShape);
471 }
472 }
473
474 // @@@ Test
475 public static void consume(TypeElement a) {
476 }
477
478
479 static TypeElement binary(TypeElement a, TypeElement b) {
480 if (a instanceof TensorType ta && b instanceof TensorType tb) {
481 return checkTensorTypes(ta, tb);
482 } else if (a instanceof TensorType ta) {
483 return new TensorType(checkScalarTypes(ta.eType(), b), ta.shape());
484 } else if (b instanceof TensorType tb) {
485 return new TensorType(checkScalarTypes(a, tb.eType()), tb.shape());
486 } else {
487 return checkScalarTypes(a, b);
488 }
489 }
490
491 static TensorType checkTensorTypes(TensorType a, TensorType b) {
492 List<Integer> s = checkTensorShape(a, b);
493 TypeElement e = checkScalarTypes(a.eType(), b.eType());
494 return new TensorType(e, s);
495 }
496
497 static List<Integer> checkTensorShape(TensorType a, TensorType b) {
498 if (a.shape().size() != b.shape().size()) {
499 // Shape mismatch
500 throw new IllegalStateException();
501 }
502
503 List<Integer> s = new ArrayList<>();
504 for (int i = 0; i < a.shape().size(); i++) {
505 int ad = a.shape().get(i);
506 int bd = b.shape().get(i);
507
508 // Expand dimensions
509 int d;
510 if (ad == bd) {
511 d = ad;
512 } else {
513 if (ad != 1 && bd == 1) {
514 d = ad;
515 } else if (ad == 1) {
516 d = bd;
517 } else {
518 // Shape mismatch
519 throw new IllegalStateException();
520 }
521 }
522
523 s.add(d);
524 }
525
526 return s;
527 }
528
529 static TypeElement checkScalarTypes(TypeElement a, TypeElement b) {
530 // @@@ Optional ptr checking
531 if (a instanceof PtrType) {
532 if (!b.equals(JavaType.INT)) {
533 throw new IllegalStateException();
534 }
535 } else if (b instanceof PtrType) {
536 // Pointer must be first argument
537 throw new IllegalStateException();
538 } else if (a instanceof ConstantType || b instanceof ConstantType) {
539 return checkScalarTypes(reduceScalarType(a), reduceScalarType(b));
540 } else if (!a.equals(b)) {
541 // @@@ Conversion
542 throw new IllegalStateException();
543 }
544 return a;
545 }
546
547 static TypeElement reduceScalarType(TypeElement a) {
548 return a instanceof ConstantType ct ? ct.cType() : a;
549 }
550 }
551
552 public static <O extends Op & Op.Invokable> TritonOps.FuncOp transformToTritonFunction(
553 O kernel,
554 String signature,
555 TypeElement rType,
556 Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
557 Map<String, TritonOps.FuncOp> fsymTable) {
558 TritonOps.FuncOp ttKernel = TritonOps.func(signature, functionType(rType))
559 .body(fblock -> {
560 // Process kernel parameters
561 List<Value> args = new ArrayList<>();
562 for (Block.Parameter kp : kernel.body().entryBlock().parameters()) {
563 TypeElement type = valueTypeMap.get(kp);
564 if (type instanceof ConstantType ct) {
565 // Constant
566 Op.Result cr = fblock.op(ArithMathOps.constant(
567 ct.cType(), ct.value()));
568 args.add(cr);
569 } else {
570 args.add(fblock.parameter(type));
571 }
572 }
573
574 // Transform kernel body
575 fblock.body(kernel.body(), args, (kblock, op) -> {
576 return transformToTritonOperation(kblock, op, valueTypeMap, opData, fsymTable);
577 });
578 });
579
580 ttKernel = cleanup(ttKernel);
581 fsymTable.put(ttKernel.funcName(), ttKernel);
582 return ttKernel;
583 }
584
585 static Block.Builder transformToTritonOperation(Block.Builder kblock, Op op,
586 Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
587 Map<String, TritonOps.FuncOp> fsymTable) {
588 // @@@ Avoid constructing for each operation -- block builder passed as argument or a scoped value
589 TritonBuilderInterpreter tbi = new TritonBuilderInterpreter(fsymTable, kblock);
590 CopyContext cc = kblock.context();
591 switch (op) {
592 case VarOp varOp -> {
593 // @@@ Cannot copy op because the result type
594 // is derived from init type
595 Value init = cc.getValue(op.operands().get(0));
596 Op.Result r = kblock.op(var(varOp.varName(), init));
597 cc.mapValue(op.result(), r);
598 }
599 case ConstantOp cop -> {
600 TypeElement t = valueTypeMap.get(cop.result());
601 if (t instanceof ConstantType ct) {
602 Op.Result r = kblock.op(ArithMathOps.constant(
603 ct.cType(), ct.value()));
604 cc.mapValue(op.result(), r);
605 } else {
606 kblock.op(op);
607 }
608 }
609 case ArithmeticOperation _ -> {
610 Value result = tbi.build(op, op.externalizeOpName(), valueTypeMap);
611 if (result != null) {
612 cc.mapValue(op.result(), result);
613 }
614 }
615 case InvokeOp iop when iop.invokeDescriptor().refType().equals(JavaType.J_L_INTEGER) -> {
616 // Replace box with its value
617 Value a = cc.getValue(op.operands().get(0));
618 cc.mapValue(op.result(), a);
619 }
620 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_J_L_MATH) -> {
621 String name = iop.invokeDescriptor().name();
622 if (name.equals("max")) {
623 Value a = cc.getValue(op.operands().get(0));
624 Value b = cc.getValue(op.operands().get(1));
625
626 Op.Result result = kblock.op(ArithMathOps.maximum(a, b));
627 cc.mapValue(op.result(), result);
628 } else if (name.equals("min")) {
629 Value a = cc.getValue(op.operands().get(0));
630 Value b = cc.getValue(op.operands().get(1));
631
632 Op.Result result = kblock.op(ArithMathOps.minimum(a, b));
633 cc.mapValue(op.result(), result);
634 }
635 }
636 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Tensor) -> {
637 if (iop.invokeDescriptor().name().equals("type")) {
638 // Replace with constant operation to produce tensor type.
639 // Result may be used, but transitively it will be removed due to no uses
640 // contributing to the computation
641 Value a = op.operands().get(0);
642 TensorType aType = (TensorType) valueTypeMap.get(a);
643 Op.Result result = kblock.op(CoreOp.constant(iop.resultType(), aType));
644 cc.mapValue(op.result(), result);
645 valueTypeMap.put(result, aType);
646 }
647 // Remove
648 }
649 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton) -> {
650 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
651 if (result != null) {
652 cc.mapValue(op.result(), result);
653 }
654 }
655 case InvokeOp iop when iop.invokeDescriptor().refType().equals(TYPE_Triton_Test) -> {
656 Value result = tbi.build(op, iop.invokeDescriptor().name(), valueTypeMap);
657 if (result != null) {
658 cc.mapValue(op.result(), result);
659 }
660 }
661 case JavaOp.ForOp fop -> {
662 transformToSCFFor(cc, kblock, fop, valueTypeMap, opData, fsymTable);
663 }
664 case ReturnOp rop -> {
665 if (rop.operands().isEmpty()) {
666 kblock.op(TritonOps.return_());
667 } else {
668 kblock.op(TritonOps.return_(
669 cc.getValue(rop.returnValue())));
670 }
671 }
672 default -> kblock.op(op);
673 }
674 return kblock;
675 }
676
/**
 * Transforms a Java for loop, previously analyzed as a {@link SimpleCountedForLoopInfo}
 * during type checking, into an SCF for operation appended to {@code kblock}.
 * <p>
 * The start, end and step expressions are transformed ahead of the loop. Vars declared
 * outside the loop body that the body only loads from are loaded once before the loop,
 * and those loads replace the in-body loads. Vars the body stores to become the loop's
 * iteration values: they are loaded before the loop to provide initial values, exposed in
 * the body as fresh vars initialized from the SCF entry block parameters, loaded and
 * yielded at each {@code continue}, and the loop's results are stored back to them after
 * the loop.
 *
 * @param cc           the copy context of the enclosing block
 * @param kblock       the builder the loop is appended to
 * @param fop          the Java for loop
 * @param valueTypeMap value types; types of newly created values are added
 * @param opData       analysis data, containing the loop info for {@code fop}
 * @param fsymTable    the symbol table of generated functions
 * @throws IllegalStateException (via {@code capturedVars}) if the loop body contains an
 *                               operation with nested bodies
 */
static void transformToSCFFor(CopyContext cc, Block.Builder kblock, JavaOp.ForOp fop,
                              Map<Value, TypeElement> valueTypeMap, Map<Op, Object> opData,
                              Map<String, TritonOps.FuncOp> fsymTable) {
    Body body = fop.loopBody();

    // Hoist expressions for start, end, and step
    // Each expression's value is the mapped result of its last operation
    SimpleCountedForLoopInfo li = (SimpleCountedForLoopInfo) opData.get(fop);
    Value start = null;
    for (Op o : li.startExpression()) {
        transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
        start = cc.getValue(o.result());
    }
    Value end = null;
    for (Op o : li.endExpression()) {
        transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
        end = cc.getValue(o.result());
    }
    Value step = null;
    for (Op o : li.stepExpression()) {
        transformToTritonOperation(kblock, o, valueTypeMap, opData, fsymTable);
        step = cc.getValue(o.result());
    }

    // Obtain captured vars
    // true == stores
    // false == loads only
    Map<Boolean, Set<Value>> capturedVars = capturedVars(body);
    Set<Value> capturedAndStoredVars = capturedVars.get(true);

    // Get load values
    // Loaded values are hoisted out of the loop body
    Map<Value, Value> loadValues = new HashMap<>();
    for (Value v : capturedVars.get(false)) {
        Value load = kblock.op(varLoad(cc.getValue(v)));
        valueTypeMap.put(load, valueTypeMap.get(v));
        loadValues.put(v, load);
    }

    // Get iteration values -- represented by captured vars that are stored to in the loop
    // The SCF for operation returns the iteration values of the last loop iteration, which
    // are then to be stored to the iteration variables
    // NOTE(review): unlike the hoisted loads above, these loads are not recorded in valueTypeMap
    List<Value> iterValues = new ArrayList<>();
    for (Value v : capturedAndStoredVars) {
        iterValues.add(kblock.op(varLoad(cc.getValue(v))));
    }

    // @@@ Build in java code model, then transform?
    SCFOps.ForOp scffor = SCFOps.for_(kblock.parentBody(), start, end, step, iterValues)
            // Ensure existing context is used: the body's context is created from cc
            .body(CopyContext.create(cc), builder -> {
                // Create index var initialized from entry block parameter
                Value index = builder.parameters().get(0);
                valueTypeMap.put(index, JavaType.INT);
                Value varIndex = builder.op(var("index", index));
                valueTypeMap.put(varIndex, JavaType.INT);
                builder.context().mapValue(body.entryBlock().parameters().get(0), varIndex);

                // Create iter vars initialized from entry block parameters
                // Parameter 0 is the index, so iteration values start at parameter 1
                int pi = 1;
                for (Value v : capturedAndStoredVars) {
                    TypeElement type = valueTypeMap.get(v);
                    Value iter = builder.parameters().get(pi++);
                    valueTypeMap.put(iter, type);
                    // NOTE(review): pi is already incremented here, so var names start at "2"
                    Value varIter = builder.op(var(Integer.toString(pi), iter));
                    valueTypeMap.put(varIter, type);
                    builder.context().mapValue(v, varIter);
                }

                // Transform the Java for body into the SCF for body
                builder.body(body, List.of(), (block, op) -> {
                    // Yield iter values
                    if (op instanceof JavaOp.ContinueOp) {
                        // Replace with yield of loaded vars
                        List<Value> yieldValues = new ArrayList<>();
                        for (Value value : capturedAndStoredVars) {
                            Value varIter = block.context().getValue(value);
                            Value v = block.op(varLoad(varIter));
                            yieldValues.add(v);
                        }
                        block.op(SCFOps.yield_(yieldValues));
                    } else if (op instanceof VarAccessOp.VarLoadOp) {
                        // Replace with value loaded immediately before loop
                        Value v = op.operands().get(0);
                        if (capturedVars.get(false).contains(v)) {
                            block.context().mapValue(op.result(), loadValues.get(v));
                        } else {
                            block.op(op);
                        }
                    } else {
                        block = transformToTritonOperation(block, op, valueTypeMap, opData, fsymTable);
                    }
                    return block;
                });
            });
    Op.Result forResult = kblock.op(scffor);

    // Assign back result to iter vars
    // A single iteration value is the for result itself; several are unpacked with
    // tupleLoad. With no iteration values nothing is stored.
    if (capturedAndStoredVars.size() == 1) {
        for (Value v : capturedAndStoredVars) {
            kblock.op(varStore(cc.getValue(v), forResult));
        }
    } else {
        int i = 0;
        for (Value v : capturedAndStoredVars) {
            kblock.op(varStore(cc.getValue(v),
                    kblock.op(tupleLoad(forResult, i++))));
        }
    }
}
786
787 static Map<Boolean, Set<Value>> capturedVars(Body body) {
788 Map<Boolean, Set<Value>> capturedValues = new HashMap<>();
789 capturedValues.put(false, new LinkedHashSet<>());
790 capturedValues.put(true, new LinkedHashSet<>());
791
792 capturedVars(capturedValues, new ArrayDeque<>(), body);
793 return capturedValues;
794 }
795
796 static void capturedVars(Map<Boolean, Set<Value>> capturedVars, Deque<Body> bodyStack, Body body) {
797 bodyStack.push(body);
798
799 for (Block b : body.blocks()) {
800 for (Op op : b.ops()) {
801 // @@@ Nested bodies
802 if (!op.bodies().isEmpty()) {
803 throw new IllegalStateException();
804 }
805 // for (Body childBody : op.bodies()) {
806 // capturedAndUpdatedVars(capturedValues, bodyStack, childBody);
807 // }
808
809 if (op instanceof VarAccessOp) {
810 Value v = op.operands().get(0);
811 if (!bodyStack.contains(v.declaringBlock().ancestorBody())) {
812 if (op instanceof VarAccessOp.VarStoreOp) {
813 capturedVars.get(true).add(v);
814 capturedVars.get(false).remove(v);
815 } else if (!capturedVars.get(true).contains(v)) {
816 capturedVars.get(false).add(v);
817 }
818 }
819 }
820 }
821 }
822
823 bodyStack.pop();
824 }
825
/**
 * Controls whether {@code cleanup} converts generated functions to SSA form, which removes
 * var operations. When bound, its value decides; when unbound, SSA conversion is performed.
 */
public static final ScopedValue<Boolean> SV_SSA = ScopedValue.newInstance();
827
828 static TritonOps.FuncOp cleanup(TritonOps.FuncOp f) {
829 // Remove var ops
830 boolean doSSA = SV_SSA.isBound() ? SV_SSA.get() : true;
831 if (doSSA) {
832 f = SSA.transform(f);
833 }
834 // Remove unused ops
835 f = f.transform((fblock, op) -> {
836 if (op instanceof Op.Pure && op.result().uses().isEmpty()) {
837 return fblock;
838 } else if (op instanceof VarAccessOp.VarLoadOp && op.result().uses().isEmpty()) {
839 return fblock;
840 }
841
842 fblock.op(op);
843 return fblock;
844 });
845 return f;
846 }
847
848 static class TritonBuilderInterpreter {
// Symbol table of generated Triton functions, shared with the enclosing transformation
final Map<String, TritonOps.FuncOp> fsymTable;
// Builder to which the generated Triton operations are appended
final Block.Builder block;
851
/**
 * Creates an interpreter that appends Triton operations to {@code block}.
 *
 * @param fsymTable the symbol table of generated functions
 * @param block     the builder to append operations to
 */
TritonBuilderInterpreter(Map<String, TritonOps.FuncOp> fsymTable, Block.Builder block) {
    this.block = block;
    this.fsymTable = fsymTable;
}
856
857 Value build(Op op, String name, Map<Value, TypeElement> valueTypeMap) {
858 // Obtain associated type-based method
859 MethodHandle mh;
860 try {
861 Optional<Method> om = Stream.of(TritonBuilderInterpreter.class.getDeclaredMethods())
862 .filter(m -> m.getName().equals(name))
863 .filter(m -> m.isVarArgs()
864 ? m.getParameterCount() / 2 - 1 <= op.operands().size()
865 : m.getParameterCount() / 2 - 1 == op.operands().size())
866 .findFirst();
867 mh = MethodHandles.lookup().unreflect(
868 om.orElseThrow(() -> new NoSuchMethodException(name)));
869 } catch (ReflectiveOperationException e) {
870 throw new IllegalStateException(e);
871 }
872
873 List<Object> iArgs = new ArrayList<>();
874 iArgs.add(this);
875 iArgs.add(valueTypeMap.get(op.result()));
876 iArgs.add(op.result());
877 for (Value o : op.operands()) {
878 iArgs.add(valueTypeMap.get(o));
879 iArgs.add(o);
880 }
881 try {
882 return (Value) mh.invokeWithArguments(iArgs.toArray(Object[]::new));
883 } catch (Throwable e) {
884 throw new IllegalStateException(e);
885 }
886 }
887
888
889 public Value programId(TypeElement rType, Op.Result r,
890 ConstantType axisType, Value axis) {
891 return block.op(TritonOps.getProgramId(
892 (int) axisType.value()));
893 }
894
895 public Value arange(TensorType rType, Op.Result r,
896 ConstantType startType, Value start,
897 ConstantType endType, Value end) {
898 return block.op(TritonOps.makeRange(
899 (int) startType.value(),
900 (int) endType.value()));
901 }
902
903 public Value expand(TensorType rType, Op.Result r,
904 TensorType aType, Value a,
905 ConstantType axisType, Value axis) {
906 return block.op(TritonOps.expand(
907 (int) axisType.value(),
908 rType,
909 block.context().getValue(a)));
910 }
911
912 public Value zeros(TensorType rType, Op.Result r,
913 ConstantType aType, Value a,
914 Object... constantsAndValues) {
915 Object zero;
916 try {
917 JavaType zeroType = (JavaType) aType.value();
918 zero = MethodHandles.zero((Class<?>) zeroType.resolve(MethodHandles.lookup())).invoke();
919 } catch (Throwable e) {
920 throw new RuntimeException(e);
921 }
922 return block.op(ArithMathOps.constant(rType, zero));
923 }
924
925 public Value load(TensorType rType, Op.Result r,
926 TensorType ptrType, Value ptr,
927 TensorType maskType, Value mask) {
928 broadcastConversionRight(ptrType, maskType, mask);
929 return block.op(TritonOps.load(
930 rType,
931 block.context().getValue(ptr),
932 block.context().getValue(mask)));
933 }
934
935 public Value load(TensorType rType, Op.Result r,
936 TensorType ptrType, Value ptr,
937 TensorType maskType, Value mask,
938 ConstantType otherType, Value other) {
939 broadcastConversionRight(ptrType, maskType, mask);
940 Value mb = block.op(ArithMathOps.constant(rType, (float)otherType.value()));
941 block.context().mapValue(other, mb);
942 return block.op(TritonOps.load(
943 rType,
944 block.context().getValue(ptr),
945 block.context().getValue(mask),
946 block.context().getValue(other)));
947 }
948
949 public Value store(TensorType rType, Op.Result r,
950 TensorType ptrType, Value ptr,
951 TensorType valueType, Value value,
952 TensorType maskType, Value mask) {
953 broadcastConversionRight(ptrType, valueType, value);
954 broadcastConversionRight(ptrType, maskType, mask);
955 return block.op(TritonOps.store(
956 block.context().getValue(ptr),
957 block.context().getValue(value),
958 block.context().getValue(mask)));
959 }
960
961 public Value broadcast(TensorType rType, Op.Result r,
962 TypeElement oType, Value o,
963 TensorType tensorTypeType, Value tensorType) {
964 // @@@ tt.splat with scalar operand, tt.broadcast with tensor operand
965 if (oType instanceof TensorType) {
966 return block.op(TritonOps.broadcast(
967 rType,
968 block.context().getValue(o)));
969 } else {
970 return block.op(TritonOps.splat(
971 rType,
972 block.context().getValue(o)));
973 }
974 }
975
976 public Value joinShape(TensorType rType, Op.Result r,
977 TensorType aType, Value a,
978 TensorType bType, Value b) {
979 // Replace with constant operation to produce tensor type.
980 // Result may be used, but transitively it will be removed due to no uses
981 // contributing to the computation
982 return block.op(CoreOp.constant(JavaType.type(TensorType.class), r.type()));
983 }
984
985
986 public Value add(TypeElement rType, Op.Result r,
987 TypeElement aType, Value a,
988 TypeElement bType, Value b) {
989 broadcastConversion(rType, aType, a, bType, b);
990 a = block.context().getValue(a);
991 b = block.context().getValue(b);
992
993 if (rType instanceof PtrType ||
994 rType instanceof TensorType t && t.eType() instanceof PtrType) {
995 return block.op(TritonOps.addptr(a, b));
996 } else {
997 return block.op(ArithMathOps.add(a, b));
998 }
999 }
1000
1001 public Value sub(TypeElement rType, Op.Result r,
1002 TypeElement aType, Value a,
1003 TypeElement bType, Value b) {
1004 broadcastConversion(rType, aType, a, bType, b);
1005 a = block.context().getValue(a);
1006 b = block.context().getValue(b);
1007
1008 return block.op(ArithMathOps.sub(a, b));
1009 }
1010
1011 public Value mul(TypeElement rType, Op.Result r,
1012 TypeElement aType, Value a,
1013 TypeElement bType, Value b) {
1014 broadcastConversion(rType, aType, a, bType, b);
1015 a = block.context().getValue(a);
1016 b = block.context().getValue(b);
1017
1018 return block.op(ArithMathOps.mul(a, b));
1019 }
1020
1021 public Value div(TypeElement rType, Op.Result r,
1022 TypeElement aType, Value a,
1023 TypeElement bType, Value b) {
1024 broadcastConversion(rType, aType, a, bType, b);
1025 a = block.context().getValue(a);
1026 b = block.context().getValue(b);
1027
1028 return block.op(ArithMathOps.div(a, b));
1029 }
1030
1031 public Value mod(TypeElement rType, Op.Result r,
1032 TypeElement aType, Value a,
1033 TypeElement bType, Value b) {
1034 broadcastConversion(rType, aType, a, bType, b);
1035 a = block.context().getValue(a);
1036 b = block.context().getValue(b);
1037
1038 return block.op(ArithMathOps.rem(a, b));
1039 }
1040
1041 public Value and(TypeElement rType, Op.Result r,
1042 TypeElement aType, Value a,
1043 TypeElement bType, Value b) {
1044 broadcastConversion(rType, aType, a, bType, b);
1045 a = block.context().getValue(a);
1046 b = block.context().getValue(b);
1047
1048 return block.op(ArithMathOps.and(a, b));
1049 }
1050
1051 public Value dot(TensorType rType, Op.Result r,
1052 TypeElement aType, Value a,
1053 TypeElement bType, Value b) {
1054 a = block.context().getValue(a);
1055 b = block.context().getValue(b);
1056 // Computed result is tensor of floats, regardless of inputs
1057 Object zero = 0.0f;
1058 var c = block.op(ArithMathOps.constant(rType, zero));
1059 return block.op(TritonOps.dot(rType, a, b, c));
1060 }
1061
1062 public Value cdiv(TypeElement rType, Op.Result r,
1063 TypeElement aType, Value a,
1064 TypeElement bType, Value b) {
1065 a = block.context().getValue(a);
1066 b = block.context().getValue(b);
1067
1068 TritonOps.FuncOp cdiv = tritonFunction(Functions.getJavaCodeModel("cdiv"),
1069 rType, List.of(aType, bType),
1070 fsymTable);
1071 // @@@ Generalize
1072 List<Value> args = new ArrayList<>();
1073 if (!(aType instanceof ConstantType)) {
1074 args.add(a);
1075 }
1076 if (!(bType instanceof ConstantType)) {
1077 args.add(b);
1078 }
1079 return block.op(TritonOps.call(cdiv, args));
1080 }
1081
1082 public Value conv(TypeElement rType, Op.Result r,
1083 ConstantType tType, Value t,
1084 TypeElement aType, Value a) {
1085 a = block.context().getValue(a);
1086
1087 TypeElement rScalarType;
1088 TypeElement aScalarType;
1089 if (rType instanceof TensorType rTensorType && aType instanceof TensorType aTensorType) {
1090 rScalarType = rTensorType.eType();
1091 aScalarType = aTensorType.eType();
1092 } else {
1093 rScalarType = rType;
1094 aScalarType = aType;
1095 }
1096
1097 if (rScalarType.equals(Float16.FLOAT_16_TYPE) && aScalarType.equals(JavaType.FLOAT)) {
1098 return block.op(ArithMathOps.trunc(rType, a));
1099 } else if (rType.equals(aType)) {
1100 return a;
1101 } else {
1102 throw new IllegalStateException();
1103 }
1104 }
1105
1106 public Value exp(TritonType rType, Op.Result r,
1107 TritonType aType, Value a) {
1108 return block.op(ArithMathOps.exp(
1109 block.context().getValue(a)));
1110 }
1111
1112 public Value compare(TensorType rType, Op.Result r,
1113 TypeElement aType, Value a,
1114 TypeElement bType, Value b,
1115 ConstantType compareType, Value compare) {
1116 Triton.CompareKind ck = (Triton.CompareKind) compareType.value();
1117
1118 ArithMathOps.CompareOp.CompareKind ack = switch (ck) {
1119 case LessThan -> ArithMathOps.CompareOp.CompareKind.slt;
1120 default -> throw new UnsupportedOperationException("Unsupported comparison: " + ck);
1121 };
1122
1123 broadcastConversionRight(aType, bType, b);
1124 a = block.context().getValue(a);
1125 b = block.context().getValue(b);
1126
1127 return block.op(ArithMathOps.cmp(ack, a, b));
1128 }
1129
1130
1131 public Value max(TypeElement rType, Op.Result r,
1132 TensorType xType, Value x,
1133 ConstantType axisType, Value axis) {
1134 TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("max"),
1135 rType, List.of(rType, rType), fsymTable);
1136 return reduce(rType, r, xType, x, axisType, axis, f);
1137 }
1138
1139 public Value sum(TypeElement rType, Op.Result r,
1140 TensorType xType, Value x,
1141 ConstantType axisType, Value axis) {
1142 TritonOps.FuncOp f = tritonFunction(Functions.getJavaCodeModel("sum"),
1143 rType, List.of(rType, rType), fsymTable);
1144 return reduce(rType, r, xType, x, axisType, axis, f);
1145 }
1146
1147 Value reduce(TypeElement rType, Op.Result r,
1148 TensorType xType, Value x,
1149 ConstantType axisType, Value axis,
1150 TritonOps.FuncOp f) {
1151 int axisConstant = (int) axisType.value();
1152
1153 String signature = "reduce_" + f.funcName() + "_" + axisConstant;
1154 TritonOps.FuncOp rf = fsymTable.computeIfAbsent(signature,
1155 s -> reduce(rType, xType, axisConstant, s, f));
1156
1157 return block.op(TritonOps.call(rf, block.context().getValue(x)));
1158 }
1159
1160 static TritonOps.FuncOp reduce(TypeElement elementType,
1161 TensorType tensorType,
1162 int axisConstant,
1163 String name, TritonOps.FuncOp scalarFunc) {
1164 return TritonOps.func(name,
1165 functionType(elementType, tensorType))
1166 .body(fblock -> {
1167 TritonOps.ReduceOp reduceOp = TritonOps.reduce(fblock.parentBody(),
1168 axisConstant, fblock.parameters().get(0),
1169 functionType(elementType, elementType, elementType))
1170 .body(rblock -> {
1171 Block.Parameter a = rblock.parameters().get(0);
1172 Block.Parameter b = rblock.parameters().get(1);
1173 Op.Result _r = rblock.op(TritonOps.call(scalarFunc, a, b));
1174 rblock.op(TritonOps.reduceReturn(_r));
1175 });
1176
1177 Op.Result opr = fblock.op(reduceOp);
1178 fblock.op(TritonOps.return_(opr));
1179 });
1180 }
1181
1182 // @@@ Test
1183 public Value consume(TypeElement rType, Op.Result r,
1184 TypeElement aType, Value a) {
1185 return block.op(TritonTestOps.consume(block.context().getValue(a)));
1186 }
1187
1188 void broadcastConversion(TypeElement rType,
1189 TypeElement aType, Value a,
1190 TypeElement bType, Value b) {
1191 Value ma = block.context().getValue(a);
1192 Value mb = block.context().getValue(b);
1193 if (aType instanceof TensorType at && bType instanceof TensorType bTensorType) {
1194 TensorType rTensorType = (TensorType) rType;
1195 if (!at.shape().equals(rTensorType.shape())) {
1196 ma = block.op(TritonOps.broadcast(rTensorType, ma));
1197 }
1198 if (!bTensorType.shape().equals(rTensorType.shape())) {
1199 if (rTensorType.eType() instanceof PtrType) {
1200 bTensorType = new TensorType(bType, rTensorType.shape());
1201 mb = block.op(TritonOps.broadcast(bTensorType, mb));
1202 } else {
1203 mb = block.op(TritonOps.broadcast(rTensorType, mb));
1204 }
1205 }
1206 } else if (aType instanceof TensorType) {
1207 TensorType rTensorType = (TensorType) rType;
1208 if (rTensorType.eType() instanceof PtrType) {
1209 TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1210 mb = block.op(TritonOps.splat(bTensorType, mb));
1211 } else {
1212 mb = block.op(TritonOps.splat(rTensorType, mb));
1213 }
1214 } else if (bType instanceof TensorType) {
1215 TensorType rTensorType = (TensorType) rType;
1216 ma = block.op(TritonOps.splat(rTensorType, ma));
1217 }
1218 block.context().mapValue(a, ma);
1219 block.context().mapValue(b, mb);
1220 }
1221
1222 void broadcastConversionRight(TypeElement aType,
1223 TypeElement bType, Value b) {
1224 Value mb = block.context().getValue(b);
1225 if (aType instanceof TensorType aTensorType && bType instanceof TensorType bTensorType) {
1226 if (!bTensorType.shape().equals(aTensorType.shape())) {
1227 if (aTensorType.eType() instanceof PtrType) {
1228 bTensorType = new TensorType(bTensorType.eType(), aTensorType.shape());
1229 mb = block.op(TritonOps.broadcast(bTensorType, mb));
1230 } else {
1231 mb = block.op(TritonOps.broadcast(aTensorType, mb));
1232 }
1233 }
1234 } else if (aType instanceof TensorType rTensorType) {
1235 if (rTensorType.eType() instanceof PtrType) {
1236 TensorType bTensorType = new TensorType(bType, rTensorType.shape());
1237 mb = block.op(TritonOps.splat(bTensorType, mb));
1238 } else {
1239 mb = block.op(TritonOps.splat(rTensorType, mb));
1240 }
1241 }
1242 block.context().mapValue(b, mb);
1243 }
1244 }
1245
1246 public static <O extends Op & Op.Invokable> void printTypeMap(
1247 O kernel, Map<Value, TypeElement> valueTypeMap) {
1248 AtomicInteger valueId = new AtomicInteger();
1249 Map<Value, Integer> valueIdMap = new LinkedHashMap<>();
1250 kernel.traverse(null, (o, codeElement) -> {
1251 switch (codeElement) {
1252 case FuncOp _ -> {
1253 // Ignore
1254 }
1255 case Op op when !op.result().type().equals(JavaType.VOID) -> {
1256 valueIdMap.put(op.result(), valueId.getAndIncrement());
1257 }
1258 case Block block -> {
1259 for (Block.Parameter parameter : block.parameters()) {
1260 valueIdMap.put(parameter, valueId.getAndIncrement());
1261 }
1262 }
1263 default -> {
1264 }
1265 }
1266 return null;
1267 });
1268
1269 valueIdMap.forEach((value, id) -> {
1270 TypeElement type = valueTypeMap.get(value);
1271 if (type != null) {
1272 System.out.println("%" + id + " : " + value.type() + " -> " + type);
1273 }
1274 });
1275 }
1276 }