1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.triton;
25
26 import org.junit.jupiter.api.Test;
27 import org.junit.jupiter.api.extension.ExtendWith;
28
29 import jdk.incubator.code.TypeElement;
30 import jdk.incubator.code.dialect.java.JavaType;
31 import jdk.incubator.code.CodeReflection;
32 import java.util.List;
33
34 import static oracle.code.triton.Triton.*;
35 import static oracle.code.triton.Triton.CompareKind.*;
36
37 @ExtendWith(TritonTestExtension.class)
38 public class TestMatrixFp16 {
39
40 @TritonCodeModel("""
41 module ()java.type:"void" -> {
42 tt.func @"cdiv_int_32_int" (%0 : java.type:"int")java.type:"int" -> {
43 %1 : java.type:"int" = arith.constant @32;
44 %2 : java.type:"int" = arith.addi %0 %1;
45 %3 : java.type:"int" = arith.constant @1;
46 %4 : java.type:"int" = arith.subi %2 %3;
47 %5 : java.type:"int" = arith.divsi %4 %1;
48 tt.return %5;
49 };
50 tt.func @"cdiv_int_64_int" (%6 : java.type:"int")java.type:"int" -> {
51 %7 : java.type:"int" = arith.constant @64;
52 %8 : java.type:"int" = arith.addi %6 %7;
53 %9 : java.type:"int" = arith.constant @1;
54 %10 : java.type:"int" = arith.subi %8 %9;
55 %11 : java.type:"int" = arith.divsi %10 %7;
56 tt.return %11;
57 };
58 tt.func @sym_name="matmul_kernel_fp16_ptr<java.type.class<oracle.code.triton.Float16, java.type.primitive<void>>>_ptr<java.type.class<oracle.code.triton.Float16, java.type.primitive<void>>>_ptr<java.type.class<oracle.code.triton.Float16, java.type.primitive<void>>>_int_int_int_int_1_int_1_int_1_32_64_32_8_false_void" (%12 : ptr<java.type:"oracle.code.triton.Float16">, %13 : ptr<java.type:"oracle.code.triton.Float16">, %14 : ptr<java.type:"oracle.code.triton.Float16">, %15 : java.type:"int", %16 : java.type:"int", %17 : java.type:"int", %18 : java.type:"int", %19 : java.type:"int", %20 : java.type:"int")java.type:"void" -> {
59 %21 : java.type:"int" = arith.constant @value=1;
60 %22 : java.type:"int" = arith.constant @value=1;
61 %23 : java.type:"int" = arith.constant @value=1;
62 %24 : java.type:"int" = arith.constant @value=32;
63 %25 : java.type:"int" = arith.constant @value=64;
64 %26 : java.type:"int" = arith.constant @value=32;
65 %27 : java.type:"int" = arith.constant @value=8;
66 %28 : java.type:"int" = tt.get_program_id @axis=0;
67 %29 : java.type:"int" = tt.call %15 @callee="cdiv_int_32_int";
68 %30 : java.type:"int" = tt.call %16 @callee="cdiv_int_64_int";
69 %31 : java.type:"int" = arith.muli %27 %30;
70 %32 : java.type:"int" = arith.divsi %28 %31;
71 %33 : java.type:"int" = arith.muli %32 %27;
72 %34 : java.type:"int" = arith.subi %29 %33;
73 %35 : java.type:"int" = arith.minsi %34 %27;
74 %36 : java.type:"int" = arith.remsi %28 %35;
75 %37 : java.type:"int" = arith.addi %33 %36;
76 %38 : java.type:"int" = arith.remsi %28 %31;
77 %39 : java.type:"int" = arith.divsi %38 %35;
78 %40 : tensor<x32, java.type:"int"> = tt.make_range @start=0 @end=32;
79 %41 : java.type:"int" = arith.muli %37 %24;
80 %42 : tensor<x32, java.type:"int"> = tt.splat %41;
81 %43 : tensor<x32, java.type:"int"> = arith.addi %42 %40;
82 %44 : tensor<x32, java.type:"int"> = tt.splat %15;
83 %45 : tensor<x32, java.type:"int"> = arith.remsi %43 %44;
84 %46 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
85 %47 : java.type:"int" = arith.muli %39 %25;
86 %48 : tensor<x64, java.type:"int"> = tt.splat %47;
87 %49 : tensor<x64, java.type:"int"> = arith.addi %48 %46;
88 %50 : tensor<x64, java.type:"int"> = tt.splat %16;
89 %51 : tensor<x64, java.type:"int"> = arith.remsi %49 %50;
90 %52 : tensor<x32, java.type:"int"> = tt.make_range @start=0 @end=32;
91 %53 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %45 @axis=1;
92 %54 : tensor<x32, x1, java.type:"int"> = tt.splat %18;
93 %55 : tensor<x32, x1, java.type:"int"> = arith.muli %53 %54;
94 %56 : tensor<x1, x32, java.type:"int"> = tt.expand_dims %52 @axis=0;
95 %57 : tensor<x1, x32, java.type:"int"> = tt.splat %21;
96 %58 : tensor<x1, x32, java.type:"int"> = arith.muli %56 %57;
97 %59 : tensor<x32, x32, java.type:"int"> = tt.broadcast %55;
98 %60 : tensor<x32, x32, java.type:"int"> = tt.broadcast %58;
99 %61 : tensor<x32, x32, java.type:"int"> = arith.addi %59 %60;
100 %62 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">> = tt.splat %12;
101 %63 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %62 %61;
102 %64 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %52 @axis=1;
103 %65 : tensor<x32, x1, java.type:"int"> = tt.splat %19;
104 %66 : tensor<x32, x1, java.type:"int"> = arith.muli %64 %65;
105 %67 : tensor<x1, x64, java.type:"int"> = tt.expand_dims %51 @axis=0;
106 %68 : tensor<x1, x64, java.type:"int"> = tt.splat %22;
107 %69 : tensor<x1, x64, java.type:"int"> = arith.muli %67 %68;
108 %70 : tensor<x32, x64, java.type:"int"> = tt.broadcast %66;
109 %71 : tensor<x32, x64, java.type:"int"> = tt.broadcast %69;
110 %72 : tensor<x32, x64, java.type:"int"> = arith.addi %70 %71;
111 %73 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.splat %13;
112 %74 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %73 %72;
113 %75 : tensor<x32, x64, java.type:"float"> = arith.constant @value=0.0f;
114 %76 : java.type:"int" = arith.constant @value=0;
115 %77 : java.type:"int" = tt.call %17 @callee="cdiv_int_32_int";
116 %78 : java.type:"int" = arith.constant @value=1;
117 %79 : Tuple<tensor<x32, x64, java.type:"float">, tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">>, tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">>> = scf.for %76 %77 %78 %75 %63 %74 (%80 : java.type:"int", %81 : tensor<x32, x64, java.type:"float">, %82 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">>, %83 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">>)Tuple<tensor<x32, x64, java.type:"float">, tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">>, tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">>> -> {
118 %84 : tensor<x1, x32, java.type:"int"> = tt.expand_dims %52 @axis=0;
119 %85 : java.type:"int" = arith.muli %80 %26;
120 %86 : java.type:"int" = arith.subi %17 %85;
121 %87 : tensor<x1, x32, java.type:"int"> = tt.splat %86;
122 %88 : tensor<x1, x32, java.type:"boolean"> = arith.cmpi %84 %87 @predicate="slt";
123 %89 : tensor<x32, x32, java.type:"boolean"> = tt.broadcast %88;
124 %90 : tensor<x32, x32, java.type:"oracle.code.triton.Float16"> = arith.constant @value=0.0f;
125 %91 : tensor<x32, x32, java.type:"oracle.code.triton.Float16"> = tt.load %82 %89 %90;
126 %92 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %52 @axis=1;
127 %93 : java.type:"int" = arith.muli %80 %26;
128 %94 : java.type:"int" = arith.subi %17 %93;
129 %95 : tensor<x32, x1, java.type:"int"> = tt.splat %94;
130 %96 : tensor<x32, x1, java.type:"boolean"> = arith.cmpi %92 %95 @predicate="slt";
131 %97 : tensor<x32, x64, java.type:"boolean"> = tt.broadcast %96;
132 %98 : tensor<x32, x64, java.type:"oracle.code.triton.Float16"> = arith.constant @value=0.0f;
133 %99 : tensor<x32, x64, java.type:"oracle.code.triton.Float16"> = tt.load %83 %97 %98;
134 %100 : tensor<x32, x64, java.type:"float"> = arith.constant @value=0.0f;
135 %101 : tensor<x32, x64, java.type:"float"> = tt.dot %91 %99 %100;
136 %102 : tensor<x32, x64, java.type:"float"> = arith.addf %81 %101;
137 %103 : java.type:"int" = arith.muli %26 %21;
138 %104 : tensor<x32, x32, java.type:"int"> = tt.splat %103;
139 %105 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %82 %104;
140 %106 : java.type:"int" = arith.muli %26 %19;
141 %107 : tensor<x32, x64, java.type:"int"> = tt.splat %106;
142 %108 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %83 %107;
143 scf.yield %102 %105 %108;
144 };
145 %109 : tensor<x32, x64, java.type:"float"> = tuple.load %79 @0;
146 %110 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">> = tuple.load %79 @1;
147 %111 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tuple.load %79 @2;
148 %112 : tensor<x32, x64, java.type:"oracle.code.triton.Float16"> = arith.truncf %109;
149 %113 : java.type:"int" = arith.muli %37 %24;
150 %114 : tensor<x32, java.type:"int"> = tt.splat %113;
151 %115 : tensor<x32, java.type:"int"> = arith.addi %114 %40;
152 %116 : java.type:"int" = arith.muli %39 %25;
153 %117 : tensor<x64, java.type:"int"> = tt.splat %116;
154 %118 : tensor<x64, java.type:"int"> = arith.addi %117 %46;
155 %119 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %115 @axis=1;
156 %120 : tensor<x32, x1, java.type:"int"> = tt.splat %20;
157 %121 : tensor<x32, x1, java.type:"int"> = arith.muli %120 %119;
158 %122 : tensor<x1, x64, java.type:"int"> = tt.expand_dims %118 @axis=0;
159 %123 : tensor<x1, x64, java.type:"int"> = tt.splat %23;
160 %124 : tensor<x1, x64, java.type:"int"> = arith.muli %123 %122;
161 %125 : tensor<x32, x64, java.type:"int"> = tt.broadcast %121;
162 %126 : tensor<x32, x64, java.type:"int"> = tt.broadcast %124;
163 %127 : tensor<x32, x64, java.type:"int"> = arith.addi %125 %126;
164 %128 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.splat %14;
165 %129 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %128 %127;
166 %130 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %115 @axis=1;
167 %131 : tensor<x32, x1, java.type:"int"> = tt.splat %15;
168 %132 : tensor<x32, x1, java.type:"boolean"> = arith.cmpi %130 %131 @predicate="slt";
169 %133 : tensor<x1, x64, java.type:"int"> = tt.expand_dims %118 @axis=0;
170 %134 : tensor<x1, x64, java.type:"int"> = tt.splat %16;
171 %135 : tensor<x1, x64, java.type:"boolean"> = arith.cmpi %133 %134 @predicate="slt";
172 %136 : tensor<x32, x64, java.type:"boolean"> = tt.broadcast %132;
173 %137 : tensor<x32, x64, java.type:"boolean"> = tt.broadcast %135;
174 %138 : tensor<x32, x64, java.type:"boolean"> = arith.andi %136 %137;
175 tt.store %129 %112 %138;
176 tt.return;
177 };
178 unreachable;
179 };
180 """)
181 @CodeReflection
182 static void matmul_kernel_fp16(
183 // Pointers to matrices
184 Ptr a_ptr, Ptr b_ptr, Ptr c_ptr,
185 // Matrix dimensions
186 int M, int N, int K,
187 // The stride variables represent how much to increase the ptr by when moving by 1
188 // element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
189 // by to get the element one row down (A has M rows).
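            // (For example, if A is stored row-major with M rows and K columns,
            // stride_am would be K and stride_ak would be 1; here stride_ak, stride_bn
            // and stride_cn are declared @Constant and specialized to 1 by the test below.)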
190 int stride_am, @Constant int stride_ak,
191 int stride_bk, @Constant int stride_bn,
192 int stride_cm, @Constant int stride_cn,
193 // Meta-parameters
194 @Constant int BLOCK_SIZE_M, @Constant int BLOCK_SIZE_N, @Constant int BLOCK_SIZE_K,
195 @Constant int GROUP_SIZE_M,
196 @Constant boolean ACTIVATION) {
197
198 // """Kernel for computing the matmul C = A x B.
199 // A has shape (M, K), B has shape (K, N) and C has shape (M, N)
200 // """
201 // -----------------------------------------------------------
202 // Map program ids `pid` to the block of C it should compute.
203 // This is done in a grouped ordering to promote L2 data reuse.
        // See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
205 var pid = programId(0);
206 var num_pid_m = cdiv(M, BLOCK_SIZE_M);
207 var num_pid_n = cdiv(N, BLOCK_SIZE_N);
208 var num_pid_in_group = GROUP_SIZE_M * num_pid_n;
209 var group_id = pid / num_pid_in_group;
210 var first_pid_m = group_id * GROUP_SIZE_M;
211 var group_size_m = Math.min(num_pid_m - first_pid_m, GROUP_SIZE_M);
212 var pid_m = first_pid_m + (pid % group_size_m);
213 var pid_n = (pid % num_pid_in_group) / group_size_m;
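        // Illustrative example (hypothetical sizes, not taken from this test): with
        // num_pid_m = 4, num_pid_n = 6 and GROUP_SIZE_M = 8, a group spans
        // num_pid_in_group = 48 program ids and group_size_m = min(4, 8) = 4, so
        // pids 0..3 map to (pid_m, pid_n) = (0,0)..(3,0), pids 4..7 to (0,1)..(3,1),
        // and so on: consecutive programs walk down a column of C blocks, sharing the
        // same column block of B and a small set of row blocks of A in L2.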
214
215 // ----------------------------------------------------------
216 // Create pointers for the first blocks of A and B.
        // We will advance these pointers as we move in the K direction
        // and accumulate.
219 // `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
220 // `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
        // See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
222 var offs_m = arange(0, BLOCK_SIZE_M);
223 var offs_am = mod(add(pid_m * BLOCK_SIZE_M, offs_m), M);
224 var offs_n = arange(0, BLOCK_SIZE_N);
225 var offs_bn = mod(add(pid_n * BLOCK_SIZE_N, offs_n), N);
226 var offs_k = arange(0, BLOCK_SIZE_K);
227 var a_ptrs = add(a_ptr, add(
228 mul(expand(offs_am, 1), stride_am),
229 mul(expand(offs_k, 0), stride_ak)));
230 var b_ptrs = add(b_ptr, add(
231 mul(expand(offs_k, 1), stride_bk),
232 mul(expand(offs_bn, 0), stride_bn)));
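        // Element-wise, this builds a 2D block of pointers (a sketch of the
        // broadcasting above, not extra kernel code):
        //   a_ptrs[i][j] = a_ptr + offs_am[i] * stride_am + offs_k[j]  * stride_ak
        //   b_ptrs[i][j] = b_ptr + offs_k[i]  * stride_bk + offs_bn[j] * stride_bn
        // expand(t, axis) inserts a length-1 axis and add/mul broadcast the operands
        // to a common shape (the tt.broadcast ops in the code model above), mirroring
        // the Python `offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak`.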
233
234 // -----------------------------------------------------------
235 // Iterate to compute a block of the C matrix.
236 // We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
237 // of fp32 values for higher accuracy.
238 // `accumulator` will be converted back to fp16 after the loop.
239 var accumulator = zeros(float.class, BLOCK_SIZE_M, BLOCK_SIZE_N);
240 for (int k = 0; k < cdiv(K, BLOCK_SIZE_K); k++) {
241 // Load the next block of A and B, generate a mask by checking the K dimension.
242 // If it is out of bounds, set it to 0.
243 var a = load(a_ptrs,
244 compare(expand(offs_k, 0), K - k * BLOCK_SIZE_K, LessThan), 0f);
245 var b = load(b_ptrs,
246 compare(expand(offs_k, 1), K - k * BLOCK_SIZE_K, LessThan), 0f);
247 // We accumulate along the K dimension.
248 accumulator = add(accumulator, dot(a, b));
249 // Advance the ptrs to the next K block.
250 a_ptrs = add(a_ptrs, BLOCK_SIZE_K * stride_ak);
251 b_ptrs = add(b_ptrs, BLOCK_SIZE_K * stride_bk);
252 }
253
254 // You can fuse arbitrary activation functions here
255 // while the accumulator is still in FP32!
256 // if (ACTIVATION) {
257 // // ...
258 // }
259 var c = Triton.conv(Float16.class, accumulator);
260
261 // -----------------------------------------------------------
262 // Write back the block of the output matrix C with masks.
263 var offs_cm = add(pid_m * BLOCK_SIZE_M, offs_m);
264 var offs_cn = add(pid_n * BLOCK_SIZE_N, offs_n);
265 var c_ptrs = add(c_ptr, add(
266 mul(stride_cm, expand(offs_cm, 1)),
267 mul(stride_cn, expand(offs_cn, 0))));
268 var c_mask = and(
269 compare(expand(offs_cm, 1), M, LessThan),
270 compare(expand(offs_cn, 0), N, LessThan));
271 store(c_ptrs, c, c_mask);
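        // Unlike the K-dimension load masks above, this mask bounds-checks both M and N:
        // offs_cm and offs_cn are not wrapped with mod(), so edge blocks only store the
        // elements that actually fall inside C.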
272 }
273
274 @TritonTestExtension.Kernel("matmul_kernel_fp16")
275 @Test
276 public void test(TritonTestExtension.TritonTestData t) {
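        // The argument types mirror the kernel signature: each @Constant parameter is
        // passed as a ConstantType, so stride_ak, stride_bn and stride_cn become 1,
        // the block/group sizes are 32/64/32/8, and ACTIVATION is false, matching the
        // constants encoded in the specialized function name in @TritonCodeModel above.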
277 List<TypeElement> argTypes = List.of(
278 new PtrType(Float16.FLOAT_16_TYPE),
279 new PtrType(Float16.FLOAT_16_TYPE),
280 new PtrType(Float16.FLOAT_16_TYPE),
281 JavaType.INT, JavaType.INT, JavaType.INT,
282 JavaType.INT, new ConstantType(JavaType.INT, 1),
283 JavaType.INT, new ConstantType(JavaType.INT, 1),
284 JavaType.INT, new ConstantType(JavaType.INT, 1),
285 new ConstantType(JavaType.INT, 32), new ConstantType(JavaType.INT, 64), new ConstantType(JavaType.INT, 32),
286 new ConstantType(JavaType.INT, 8),
                new ConstantType(JavaType.BOOLEAN, false));
288
289 t.test(argTypes);
290 }
291
292 }
293
294 /*
295 @triton.jit
296 def matmul_kernel(
297 # Pointers to matrices
298 a_ptr, b_ptr, c_ptr,
299 # Matrix dimensions
300 M, N, K,
301 # The stride variables represent how much to increase the ptr by when moving by 1
302 # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
303 # by to get the element one row down (A has M rows).
304 stride_am, stride_ak, #
305 stride_bk, stride_bn, #
306 stride_cm, stride_cn,
307 # Meta-parameters
308 BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
309 GROUP_SIZE_M: tl.constexpr, #
310 ACTIVATION: tl.constexpr #
311 ):
312 """Kernel for computing the matmul C = A x B.
313 A has shape (M, K), B has shape (K, N) and C has shape (M, N)
314 """
315 # -----------------------------------------------------------
316 # Map program ids `pid` to the block of C it should compute.
317 # This is done in a grouped ordering to promote L2 data reuse.
318 # See above `L2 Cache Optimizations` section for details.
319 pid = tl.program_id(axis=0)
320 num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
321 num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
322 num_pid_in_group = GROUP_SIZE_M * num_pid_n
323 group_id = pid // num_pid_in_group
324 first_pid_m = group_id * GROUP_SIZE_M
325 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
326 pid_m = first_pid_m + (pid % group_size_m)
327 pid_n = (pid % num_pid_in_group) // group_size_m
328
329 # ----------------------------------------------------------
330 # Create pointers for the first blocks of A and B.
331 # We will advance this pointer as we move in the K direction
332 # and accumulate
333 # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
334 # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
335 # See above `Pointer Arithmetics` section for details
336 offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
337 offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
338 offs_k = tl.arange(0, BLOCK_SIZE_K)
339 a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
340 b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
341
342 # -----------------------------------------------------------
343 # Iterate to compute a block of the C matrix.
344 # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
345 # of fp32 values for higher accuracy.
346 # `accumulator` will be converted back to fp16 after the loop.
347 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
348 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
349 # Load the next block of A and B, generate a mask by checking the K dimension.
350 # If it is out of bounds, set it to 0.
351 a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
352 b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
353 # We accumulate along the K dimension.
354 accumulator += tl.dot(a, b)
355 # Advance the ptrs to the next K block.
356 a_ptrs += BLOCK_SIZE_K * stride_ak
357 b_ptrs += BLOCK_SIZE_K * stride_bk
358 # You can fuse arbitrary activation functions here
359 # while the accumulator is still in FP32!
360 if ACTIVATION == "leaky_relu":
361 accumulator = leaky_relu(accumulator)
362 c = accumulator.to(tl.float16)
363
364 # -----------------------------------------------------------
365 # Write back the block of the output matrix C with masks.
366 offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
367 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
368 c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
369 c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
370 tl.store(c_ptrs, c, mask=c_mask)
371
372
373 # We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
374 @triton.jit
375 def leaky_relu(x):
376 x = x + 1
377 return tl.where(x >= 0, x, 0.01 * x)
378 */
379
380 /*
381
382 triton/python/triton/tools/compile.py \
383 --kernel-name matmul_kernel \
384 --signature "*fp16,*fp16,*fp16,i32,i32,i32,i32,i32,i32,i32,i32,i32,32,64,32,8,0" \
385 --grid=1024,1024,1024 \
386 03-matrix-multiplication.py
387
388 BLOCK_SIZE_M = 32
389 BLOCK_SIZE_N = 64
390 BLOCK_SIZE_K = 32
391 GROUP_SIZE_M = 8
392 ACTIVATION = 0
393
394 module {
395 tt.func public @matmul_kernel_01234567891011(
      %arg0: !tt.ptr<f16, 1>, %arg1: !tt.ptr<f16, 1>, %arg2: !tt.ptr<f16, 1>,
      %arg3: i32, %arg4: i32, %arg5: i32,
      %arg6: i32, %arg7: i32, %arg8: i32,
      %arg9: i32, %arg10: i32, %arg11: i32) attributes {noinline = false} {
400 %0 = tt.get_program_id x : i32
401 %1 = tt.call @cdiv__i32__1cconstexpr_32_(%arg3) : (i32) -> i32
402 %2 = tt.call @cdiv__i32__1cconstexpr_64_(%arg4) : (i32) -> i32
403 %c8_i32 = arith.constant 8 : i32
404 %3 = arith.muli %2, %c8_i32 : i32
405 %4 = arith.divsi %0, %3 : i32
406 %c8_i32_0 = arith.constant 8 : i32
407 %5 = arith.muli %4, %c8_i32_0 : i32
408 %6 = arith.subi %1, %5 : i32
409 %7 = tt.call @minimum__i32__1cconstexpr_8_(%6) : (i32) -> i32
410 %8 = arith.remsi %0, %7 : i32
411 %9 = arith.addi %5, %8 : i32
412 %10 = arith.remsi %0, %3 : i32
413 %11 = arith.divsi %10, %7 : i32
414 %c32_i32 = arith.constant 32 : i32
415 %12 = arith.muli %9, %c32_i32 : i32
416 %13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
417 %14 = tt.splat %12 : (i32) -> tensor<32xi32>
418 %15 = arith.addi %14, %13 : tensor<32xi32>
419 %16 = tt.splat %arg3 : (i32) -> tensor<32xi32>
420 %17 = arith.remsi %15, %16 : tensor<32xi32>
421 %c64_i32 = arith.constant 64 : i32
422 %18 = arith.muli %11, %c64_i32 : i32
423 %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
424 %20 = tt.splat %18 : (i32) -> tensor<64xi32>
425 %21 = arith.addi %20, %19 : tensor<64xi32>
426 %22 = tt.splat %arg4 : (i32) -> tensor<64xi32>
427 %23 = arith.remsi %21, %22 : tensor<64xi32>
428 %24 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
429 %25 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
430 %26 = tt.splat %arg6 : (i32) -> tensor<32x1xi32>
431 %27 = arith.muli %25, %26 : tensor<32x1xi32>
432 %28 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
433 %29 = tt.splat %arg7 : (i32) -> tensor<1x32xi32>
434 %30 = arith.muli %28, %29 : tensor<1x32xi32>
435 %31 = tt.broadcast %27 : (tensor<32x1xi32>) -> tensor<32x32xi32>
436 %32 = tt.broadcast %30 : (tensor<1x32xi32>) -> tensor<32x32xi32>
437 %33 = arith.addi %31, %32 : tensor<32x32xi32>
438 %34 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<32x32x!tt.ptr<f16, 1>>
439 %35 = tt.addptr %34, %33 : tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x32xi32>
440 %36 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
441 %37 = tt.splat %arg8 : (i32) -> tensor<32x1xi32>
442 %38 = arith.muli %36, %37 : tensor<32x1xi32>
443 %39 = tt.expand_dims %23 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
444 %40 = tt.splat %arg9 : (i32) -> tensor<1x64xi32>
445 %41 = arith.muli %39, %40 : tensor<1x64xi32>
446 %42 = tt.broadcast %38 : (tensor<32x1xi32>) -> tensor<32x64xi32>
447 %43 = tt.broadcast %41 : (tensor<1x64xi32>) -> tensor<32x64xi32>
448 %44 = arith.addi %42, %43 : tensor<32x64xi32>
449 %45 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<32x64x!tt.ptr<f16, 1>>
450 %46 = tt.addptr %45, %44 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
451 %47 = tt.call @"zeros____0cconstexpr_(constexpr_32_, constexpr_64_)__1cconstexpr_fp32_"() : () -> tensor<32x64xf32>
452 %48 = tt.call @cdiv__i32__1cconstexpr_32_(%arg5) : (i32) -> i32
453 %c0_i32 = arith.constant 0 : i32
454 %c1_i32 = arith.constant 1 : i32
455 %49 = arith.bitcast %c0_i32 : i32 to i32
456 %50 = arith.bitcast %48 : i32 to i32
457 %51 = arith.bitcast %c1_i32 : i32 to i32
458 %52 = llvm.mlir.undef : i32
459 %53:3 = scf.for %arg12 = %49 to %50 step %51 iter_args(%arg13 = %47, %arg14 = %35, %arg15 = %46) -> (tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>) : i32 {
460 %83 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
461 %c32_i32_3 = arith.constant 32 : i32
462 %84 = arith.muli %arg12, %c32_i32_3 : i32
463 %85 = arith.subi %arg5, %84 : i32
464 %86 = tt.splat %85 : (i32) -> tensor<1x32xi32>
465 %87 = arith.cmpi slt, %83, %86 : tensor<1x32xi32>
466 %cst = arith.constant 0.000000e+00 : f32
467 %88 = tt.broadcast %87 : (tensor<1x32xi1>) -> tensor<32x32xi1>
468 %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
469 %89 = arith.truncf %cst_4 : tensor<32x32xf32> to tensor<32x32xf16>
470 %90 = tt.load %arg14, %88, %89 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16>
471 %91 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
472 %c32_i32_5 = arith.constant 32 : i32
473 %92 = arith.muli %arg12, %c32_i32_5 : i32
474 %93 = arith.subi %arg5, %92 : i32
475 %94 = tt.splat %93 : (i32) -> tensor<32x1xi32>
476 %95 = arith.cmpi slt, %91, %94 : tensor<32x1xi32>
477 %cst_6 = arith.constant 0.000000e+00 : f32
478 %96 = tt.broadcast %95 : (tensor<32x1xi1>) -> tensor<32x64xi1>
479 %cst_7 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
480 %97 = arith.truncf %cst_7 : tensor<32x64xf32> to tensor<32x64xf16>
481 %98 = tt.load %arg15, %96, %97 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
482 %cst_8 = arith.constant 0.000000e+00 : f32
483 %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
484 %99 = tt.dot %90, %98, %cst_9 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32>
485 %100 = arith.addf %arg13, %99 : tensor<32x64xf32>
486 %c32_i32_10 = arith.constant 32 : i32
487 %101 = arith.muli %arg7, %c32_i32_10 : i32
488 %102 = tt.splat %101 : (i32) -> tensor<32x32xi32>
489 %103 = tt.addptr %arg14, %102 : tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x32xi32>
490 %c32_i32_11 = arith.constant 32 : i32
491 %104 = arith.muli %arg8, %c32_i32_11 : i32
492 %105 = tt.splat %104 : (i32) -> tensor<32x64xi32>
493 %106 = tt.addptr %arg15, %105 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
494 scf.yield %100, %103, %106 : tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>
495 }
496 %54 = arith.truncf %53#0 : tensor<32x64xf32> to tensor<32x64xf16>
497 %c32_i32_1 = arith.constant 32 : i32
498 %55 = arith.muli %9, %c32_i32_1 : i32
499 %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
500 %57 = tt.splat %55 : (i32) -> tensor<32xi32>
501 %58 = arith.addi %57, %56 : tensor<32xi32>
502 %c64_i32_2 = arith.constant 64 : i32
503 %59 = arith.muli %11, %c64_i32_2 : i32
504 %60 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
505 %61 = tt.splat %59 : (i32) -> tensor<64xi32>
506 %62 = arith.addi %61, %60 : tensor<64xi32>
507 %63 = tt.expand_dims %58 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
508 %64 = tt.splat %arg10 : (i32) -> tensor<32x1xi32>
509 %65 = arith.muli %64, %63 : tensor<32x1xi32>
510 %66 = tt.splat %arg2 : (!tt.ptr<f16, 1>) -> tensor<32x1x!tt.ptr<f16, 1>>
511 %67 = tt.addptr %66, %65 : tensor<32x1x!tt.ptr<f16, 1>>, tensor<32x1xi32>
512 %68 = tt.expand_dims %62 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
513 %69 = tt.splat %arg11 : (i32) -> tensor<1x64xi32>
514 %70 = arith.muli %69, %68 : tensor<1x64xi32>
515 %71 = tt.broadcast %67 : (tensor<32x1x!tt.ptr<f16, 1>>) -> tensor<32x64x!tt.ptr<f16, 1>>
516 %72 = tt.broadcast %70 : (tensor<1x64xi32>) -> tensor<32x64xi32>
517 %73 = tt.addptr %71, %72 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
518 %74 = tt.expand_dims %58 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
519 %75 = tt.splat %arg3 : (i32) -> tensor<32x1xi32>
520 %76 = arith.cmpi slt, %74, %75 : tensor<32x1xi32>
521 %77 = tt.expand_dims %62 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
522 %78 = tt.splat %arg4 : (i32) -> tensor<1x64xi32>
523 %79 = arith.cmpi slt, %77, %78 : tensor<1x64xi32>
524 %80 = tt.broadcast %76 : (tensor<32x1xi1>) -> tensor<32x64xi1>
525 %81 = tt.broadcast %79 : (tensor<1x64xi1>) -> tensor<32x64xi1>
526 %82 = arith.andi %80, %81 : tensor<32x64xi1>
527 tt.store %73, %54, %82 {cache = 1 : i32, evict = 1 : i32} : tensor<32x64xf16>
528 tt.return
529 }
530 tt.func private @cdiv__i32__1cconstexpr_32_(%arg0: i32 ) -> i32 attributes {noinline = false} {
531 %c32_i32 = arith.constant 32 : i32
532 %0 = arith.addi %arg0, %c32_i32 : i32
533 %c1_i32 = arith.constant 1 : i32
534 %1 = arith.subi %0, %c1_i32 : i32
535 %c32_i32_0 = arith.constant 32 : i32
536 %2 = arith.divsi %1, %c32_i32_0 : i32
537 tt.return %2 : i32
538 }
539 tt.func private @cdiv__i32__1cconstexpr_64_(%arg0: i32 ) -> i32 attributes {noinline = false} {
540 %c64_i32 = arith.constant 64 : i32
541 %0 = arith.addi %arg0, %c64_i32 : i32
542 %c1_i32 = arith.constant 1 : i32
543 %1 = arith.subi %0, %c1_i32 : i32
544 %c64_i32_0 = arith.constant 64 : i32
545 %2 = arith.divsi %1, %c64_i32_0 : i32
546 tt.return %2 : i32
547 }
548 tt.func private @minimum__i32__1cconstexpr_8_(%arg0: i32 ) -> i32 attributes {noinline = false} {
549 %c8_i32 = arith.constant 8 : i32
550 %0 = arith.minsi %arg0, %c8_i32 : i32
551 tt.return %0 : i32
552 }
553 tt.func private @"zeros____0cconstexpr_(constexpr_32_, constexpr_64_)__1cconstexpr_fp32_"() -> tensor<32x64xf32> attributes {noinline = false} {
554 %cst = arith.constant 0.000000e+00 : f32
555 %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
556 tt.return %cst_0 : tensor<32x64xf32>
557 }
558 }
559 */