1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.triton;
25
26 import oracle.code.triton.TritonTestExtension.TritonTestData;
27 import org.junit.jupiter.api.Test;
28 import org.junit.jupiter.api.extension.ExtendWith;
29
30 import jdk.incubator.code.TypeElement;
31 import jdk.incubator.code.dialect.java.JavaType;
32 import jdk.incubator.code.CodeReflection;
33 import java.util.List;
34
35 import static oracle.code.triton.Triton.*;
36 import static oracle.code.triton.TritonTest.consume;
37
/**
 * Broadcast tests for the Triton code-model support in {@code oracle.code.triton}.
 *
 * Each case in this class is made of two methods that share a name:
 * a static {@code @CodeReflection} kernel annotated with a {@code @TritonCodeModel}
 * text block, and a JUnit {@code @Test} method that passes the kernel's argument
 * types to {@code TritonTestData.test}.
 *
 * NOTE(review): presumably {@code TritonTestExtension} supplies the
 * {@code TritonTestData} parameter and {@code test(argTypes)} checks the model
 * produced for the same-named kernel against the {@code @TritonCodeModel} text —
 * confirm against TritonTestExtension, which is not visible in this file.
 *
 * The kernel bodies are the input under test; keep their statements exactly as
 * written when editing, and only touch the text blocks when the expected model
 * is meant to change.
 */
38 @ExtendWith(TritonTestExtension.class)
39 public class TestBroadcast {
40
    /**
     * Kernel that adds a pointer or an int scalar to a 1-D range tensor,
     * with the scalar on either side and with both implicit and explicit
     * broadcasting.
     *
     * In the paired model text every scalar operand appears as a {@code tt.splat}
     * to {@code tensor<x64, ...>} before the add, the pointer case uses
     * {@code tt.addptr}, and the int cases use {@code arith.addi} with the
     * operand order of the corresponding {@code add} call.
     *
     * @param ptr pointer operand; the paired test passes {@code PtrType(JavaType.INT)}
     * @param a   int scalar operand
     * @param s   constant end of the range; the paired test passes 64, which
     *            matches {@code tt.make_range @start=0 @end=64} in the model text
     */
41 @TritonCodeModel("""
42 module ()java.type:"void" -> {
43 tt.func @"test1_ptr<java.type.primitive<int>>_int_64_void" (%0 : ptr<java.type:"int">, %1 : java.type:"int")java.type:"void" -> {
44 %2 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
45 %3 : tensor<x64, ptr<java.type:"int">> = tt.splat %0;
46 %4 : tensor<x64, ptr<java.type:"int">> = tt.addptr %3 %2;
47 tt.consume %4;
48 %5 : tensor<x64, java.type:"int"> = tt.splat %1;
49 %6 : tensor<x64, java.type:"int"> = arith.addi %5 %2;
50 tt.consume %6;
51 %7 : tensor<x64, java.type:"int"> = tt.splat %1;
52 %8 : tensor<x64, java.type:"int"> = arith.addi %2 %7;
53 tt.consume %8;
54 %9 : tensor<x64, java.type:"int"> = tt.splat %1;
55 %10 : tensor<x64, java.type:"int"> = arith.addi %9 %2;
56 tt.consume %10;
57 %11 : tensor<x64, java.type:"int"> = tt.splat %1;
58 %12 : tensor<x64, java.type:"int"> = arith.addi %2 %11;
59 tt.consume %12;
60 tt.return;
61 };
62 unreachable;
63 };
64 """)
65 @CodeReflection
66 static void test1(Ptr ptr, int a, @Constant int s) {
    // Range tensor that every add below combines with a non-tensor operand.
67 var t = arange(0, s);
    // Pointer + tensor.
68 consume(add(ptr, t));
    // Int scalar + tensor, scalar on the left and then on the right.
69 consume(add(a, t));
70 consume(add(t, a));
    // Same two additions, but with the scalar broadcast explicitly to t's type;
    // the model text expects the same splat/addi pairs as for the implicit form.
71 consume(add(broadcast(a, t.type()), t));
72 consume(add(t, broadcast(a, t.type())));
73 }
74
    /**
     * Supplies the argument types for the {@code test1} kernel: an int pointer,
     * an int, and the int constant 64.
     *
     * @param t test data object received as a method parameter
     *          (presumably resolved by {@code TritonTestExtension} — confirm)
     */
75 @Test
76 public void test1(TritonTestData t) {
77 List<TypeElement> argTypes = List.of(
78 new PtrType(JavaType.INT),
79 JavaType.INT,
80 new ConstantType(JavaType.INT, 64));
81
82 t.test(argTypes);
83 }
84
    /**
     * Kernel that expands two 1-D ranges along different dimensions and adds
     * them, plus one addition of an expanded range and an int scalar.
     *
     * In the paired model text the scalar is splat to {@code tensor<x1, x64, ...>}
     * before the first add, and both expanded tensors go through
     * {@code tt.broadcast} to {@code tensor<x32, x64, ...>} before the second add.
     *
     * @param a int scalar operand
     * @param M constant end of the first range; the paired test passes 64
     * @param N constant end of the second range; the paired test passes 32
     */
85 @TritonCodeModel("""
86 module ()java.type:"void" -> {
87 tt.func @"test2_int_64_32_void" (%1 : java.type:"int")java.type:"void" -> {
88 %2 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
89 %3 : tensor<x1, x64, java.type:"int"> = tt.expand_dims %2 @0;
90 %4 : tensor<x32, java.type:"int"> = tt.make_range @start=0 @end=32;
91 %5 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %4 @1;
92 %6 : tensor<x1, x64, java.type:"int"> = tt.splat %1;
93 %7 : tensor<x1, x64, java.type:"int"> = arith.addi %3 %6;
94 tt.consume %7;
95 %8 : tensor<x32, x64, java.type:"int"> = tt.broadcast %3;
96 %9 : tensor<x32, x64, java.type:"int"> = tt.broadcast %5;
97 %10 : tensor<x32, x64, java.type:"int"> = arith.addi %8 %9;
98 tt.consume %10;
99 tt.return;
100 };
101 unreachable;
102 };
103 """)
104 @CodeReflection
105 static void test2(int a, @Constant int M, @Constant int N) {
    // Range of M expanded at dimension 0 (tensor<x1, x64, ...> in the model text).
106 var m = arange(0, M);
107 var me = expand(m, 0);
108
    // Range of N expanded at dimension 1 (tensor<x32, x1, ...> in the model text).
109 var n = arange(0, N);
110 var ne = expand(n, 1);
111
    // Expanded tensor + int scalar.
112 var t4 = add(me, a);
113 consume(t4);
114
    // Two expanded tensors of different shapes added together.
115 var t3 = add(me, ne);
116 consume(t3);
117 }
118
    /**
     * Supplies the argument types for the {@code test2} kernel: an int and the
     * int constants 64 and 32.
     *
     * @param t test data object received as a method parameter
     *          (presumably resolved by {@code TritonTestExtension} — confirm)
     */
119 @Test
120 public void test2(TritonTestData t) {
121 List<TypeElement> argTypes = List.of(
122 JavaType.INT,
123 new ConstantType(JavaType.INT, 64),
124 new ConstantType(JavaType.INT, 32)
125 );
126
127 t.test(argTypes);
128 }
129 }