/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package oracle.code.triton;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import jdk.incubator.code.TypeElement;
import jdk.incubator.code.dialect.java.JavaType;
import jdk.incubator.code.CodeReflection;
import java.util.List;

import static oracle.code.triton.Triton.*;
import static oracle.code.triton.Triton.CompareKind.*;

@ExtendWith(TritonTestExtension.class)
public class TestMatrixFp16 {

    @TritonCodeModel("""
            module ()java.type:"void" -> {
                tt.func @"cdiv_int_32_int" (%0 : java.type:"int")java.type:"int" -> {
                    %1 : java.type:"int" = arith.constant @32;
                    %2 : java.type:"int" = arith.addi %0 %1;
                    %3 : java.type:"int" = arith.constant @1;
                    %4 : java.type:"int" = arith.subi %2 %3;
                    %5 : java.type:"int" = arith.divsi %4 %1;
                    tt.return %5;
                };
                tt.func @"cdiv_int_64_int" (%6 : java.type:"int")java.type:"int" -> {
                    %7 : java.type:"int" = arith.constant @64;
                    %8 : java.type:"int" = arith.addi %6 %7;
                    %9 : java.type:"int" = arith.constant @1;
                    %10 : java.type:"int" = arith.subi %8 %9;
                    %11 : java.type:"int" = arith.divsi %10 %7;
                    tt.return %11;
                };
                tt.func @"matmul_kernel_fp16_ptr<java.type.class<oracle.code.triton.Float16, java.type.primitive<void>>>_ptr<java.type.class<oracle.code.triton.Float16, java.type.primitive<void>>>_ptr<java.type.class<oracle.code.triton.Float16, java.type.primitive<void>>>_int_int_int_int_1_int_1_int_1_32_64_32_8_false_void" (%12 : ptr<java.type:"oracle.code.triton.Float16">, %13 : ptr<java.type:"oracle.code.triton.Float16">, %14 : ptr<java.type:"oracle.code.triton.Float16">, %15 : java.type:"int", %16 : java.type:"int", %17 : java.type:"int", %18 : java.type:"int", %19 : java.type:"int", %20 : java.type:"int")java.type:"void" -> {
                    %21 : java.type:"int" = arith.constant @1;
                    %22 : java.type:"int" = arith.constant @1;
                    %23 : java.type:"int" = arith.constant @1;
                    %24 : java.type:"int" = arith.constant @32;
                    %25 : java.type:"int" = arith.constant @64;
                    %26 : java.type:"int" = arith.constant @32;
                    %27 : java.type:"int" = arith.constant @8;
                    %28 : java.type:"int" = tt.get_program_id @axis=0;
                    %29 : java.type:"int" = tt.call %15 @callee="cdiv_int_32_int";
                    %30 : java.type:"int" = tt.call %16 @callee="cdiv_int_64_int";
                    %31 : java.type:"int" = arith.muli %27 %30;
                    %32 : java.type:"int" = arith.divsi %28 %31;
                    %33 : java.type:"int" = arith.muli %32 %27;
                    %34 : java.type:"int" = arith.subi %29 %33;
                    %35 : java.type:"int" = arith.minsi %34 %27;
                    %36 : java.type:"int" = arith.remsi %28 %35;
                    %37 : java.type:"int" = arith.addi %33 %36;
                    %38 : java.type:"int" = arith.remsi %28 %31;
                    %39 : java.type:"int" = arith.divsi %38 %35;
                    %40 : tensor<x32, java.type:"int"> = tt.make_range @start=0 @end=32;
                    %41 : java.type:"int" = arith.muli %37 %24;
                    %42 : tensor<x32, java.type:"int"> = tt.splat %41;
                    %43 : tensor<x32, java.type:"int"> = arith.addi %42 %40;
                    %44 : tensor<x32, java.type:"int"> = tt.splat %15;
                    %45 : tensor<x32, java.type:"int"> = arith.remsi %43 %44;
                    %46 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
                    %47 : java.type:"int" = arith.muli %39 %25;
                    %48 : tensor<x64, java.type:"int"> = tt.splat %47;
                    %49 : tensor<x64, java.type:"int"> = arith.addi %48 %46;
                    %50 : tensor<x64, java.type:"int"> = tt.splat %16;
                    %51 : tensor<x64, java.type:"int"> = arith.remsi %49 %50;
                    %52 : tensor<x32, java.type:"int"> = tt.make_range @start=0 @end=32;
                    %53 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %45 @axis=1;
                    %54 : tensor<x32, x1, java.type:"int"> = tt.splat %18;
                    %55 : tensor<x32, x1, java.type:"int"> = arith.muli %53 %54;
                    %56 : tensor<x1, x32, java.type:"int"> = tt.expand_dims %52 @axis=0;
                    %57 : tensor<x1, x32, java.type:"int"> = tt.splat %21;
                    %58 : tensor<x1, x32, java.type:"int"> = arith.muli %56 %57;
                    %59 : tensor<x32, x32, java.type:"int"> = tt.broadcast %55;
                    %60 : tensor<x32, x32, java.type:"int"> = tt.broadcast %58;
                    %61 : tensor<x32, x32, java.type:"int"> = arith.addi %59 %60;
                    %62 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">> = tt.splat %12;
                    %63 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %62 %61;
                    %64 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %52 @axis=1;
                    %65 : tensor<x32, x1, java.type:"int"> = tt.splat %19;
                    %66 : tensor<x32, x1, java.type:"int"> = arith.muli %64 %65;
                    %67 : tensor<x1, x64, java.type:"int"> = tt.expand_dims %51 @axis=0;
                    %68 : tensor<x1, x64, java.type:"int"> = tt.splat %22;
                    %69 : tensor<x1, x64, java.type:"int"> = arith.muli %67 %68;
                    %70 : tensor<x32, x64, java.type:"int"> = tt.broadcast %66;
                    %71 : tensor<x32, x64, java.type:"int"> = tt.broadcast %69;
                    %72 : tensor<x32, x64, java.type:"int"> = arith.addi %70 %71;
                    %73 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.splat %13;
                    %74 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %73 %72;
                    %75 : tensor<x32, x64, java.type:"float"> = arith.constant @0.0f;
                    %76 : java.type:"int" = arith.constant @0;
                    %77 : java.type:"int" = tt.call %17 @callee="cdiv_int_32_int";
                    %78 : java.type:"int" = arith.constant @1;
                    %79 : Tuple<tensor<x32, x64, java.type:"float">, tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">>, tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">>> = scf.for %76 %77 %78 %75 %63 %74 (%80 : java.type:"int", %81 : tensor<x32, x64, java.type:"float">, %82 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">>, %83 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">>)Tuple<tensor<x32, x64, java.type:"float">, tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">>, tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">>> -> {
                        %84 : tensor<x1, x32, java.type:"int"> = tt.expand_dims %52 @axis=0;
                        %85 : java.type:"int" = arith.muli %80 %26;
                        %86 : java.type:"int" = arith.subi %17 %85;
                        %87 : tensor<x1, x32, java.type:"int"> = tt.splat %86;
                        %88 : tensor<x1, x32, java.type:"boolean"> = arith.cmpi %84 %87 @predicate="slt";
                        %89 : tensor<x32, x32, java.type:"boolean"> = tt.broadcast %88;
                        %90 : tensor<x32, x32, java.type:"oracle.code.triton.Float16"> = arith.constant @0.0f;
                        %91 : tensor<x32, x32, java.type:"oracle.code.triton.Float16"> = tt.load %82 %89 %90;
                        %92 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %52 @axis=1;
                        %93 : java.type:"int" = arith.muli %80 %26;
                        %94 : java.type:"int" = arith.subi %17 %93;
                        %95 : tensor<x32, x1, java.type:"int"> = tt.splat %94;
                        %96 : tensor<x32, x1, java.type:"boolean"> = arith.cmpi %92 %95 @predicate="slt";
                        %97 : tensor<x32, x64, java.type:"boolean"> = tt.broadcast %96;
                        %98 : tensor<x32, x64, java.type:"oracle.code.triton.Float16"> = arith.constant @0.0f;
                        %99 : tensor<x32, x64, java.type:"oracle.code.triton.Float16"> = tt.load %83 %97 %98;
                        %100 : tensor<x32, x64, java.type:"float"> = arith.constant @0.0f;
                        %101 : tensor<x32, x64, java.type:"float"> = tt.dot %91 %99 %100;
                        %102 : tensor<x32, x64, java.type:"float"> = arith.addf %81 %101;
                        %103 : java.type:"int" = arith.muli %26 %21;
                        %104 : tensor<x32, x32, java.type:"int"> = tt.splat %103;
                        %105 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %82 %104;
                        %106 : java.type:"int" = arith.muli %26 %19;
                        %107 : tensor<x32, x64, java.type:"int"> = tt.splat %106;
                        %108 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %83 %107;
                        scf.yield %102 %105 %108;
                    };
                    %109 : tensor<x32, x64, java.type:"float"> = tuple.load %79 @0;
                    %110 : tensor<x32, x32, ptr<java.type:"oracle.code.triton.Float16">> = tuple.load %79 @1;
                    %111 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tuple.load %79 @2;
                    %112 : tensor<x32, x64, java.type:"oracle.code.triton.Float16"> = arith.truncf %109;
                    %113 : java.type:"int" = arith.muli %37 %24;
                    %114 : tensor<x32, java.type:"int"> = tt.splat %113;
                    %115 : tensor<x32, java.type:"int"> = arith.addi %114 %40;
                    %116 : java.type:"int" = arith.muli %39 %25;
                    %117 : tensor<x64, java.type:"int"> = tt.splat %116;
                    %118 : tensor<x64, java.type:"int"> = arith.addi %117 %46;
                    %119 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %115 @axis=1;
                    %120 : tensor<x32, x1, java.type:"int"> = tt.splat %20;
                    %121 : tensor<x32, x1, java.type:"int"> = arith.muli %120 %119;
                    %122 : tensor<x1, x64, java.type:"int"> = tt.expand_dims %118 @axis=0;
                    %123 : tensor<x1, x64, java.type:"int"> = tt.splat %23;
                    %124 : tensor<x1, x64, java.type:"int"> = arith.muli %123 %122;
                    %125 : tensor<x32, x64, java.type:"int"> = tt.broadcast %121;
                    %126 : tensor<x32, x64, java.type:"int"> = tt.broadcast %124;
                    %127 : tensor<x32, x64, java.type:"int"> = arith.addi %125 %126;
                    %128 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.splat %14;
                    %129 : tensor<x32, x64, ptr<java.type:"oracle.code.triton.Float16">> = tt.addptr %128 %127;
                    %130 : tensor<x32, x1, java.type:"int"> = tt.expand_dims %115 @axis=1;
                    %131 : tensor<x32, x1, java.type:"int"> = tt.splat %15;
                    %132 : tensor<x32, x1, java.type:"boolean"> = arith.cmpi %130 %131 @predicate="slt";
                    %133 : tensor<x1, x64, java.type:"int"> = tt.expand_dims %118 @axis=0;
                    %134 : tensor<x1, x64, java.type:"int"> = tt.splat %16;
                    %135 : tensor<x1, x64, java.type:"boolean"> = arith.cmpi %133 %134 @predicate="slt";
                    %136 : tensor<x32, x64, java.type:"boolean"> = tt.broadcast %132;
                    %137 : tensor<x32, x64, java.type:"boolean"> = tt.broadcast %135;
                    %138 : tensor<x32, x64, java.type:"boolean"> = arith.andi %136 %137;
                    tt.store %129 %112 %138;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void matmul_kernel_fp16(
            // Pointers to matrices
            Ptr a_ptr, Ptr b_ptr, Ptr c_ptr,
            // Matrix dimensions
            int M, int N, int K,
            // The stride variables represent how much to increase the ptr by when moving by 1
            // element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
            // by to get the element one row down (A has M rows).
            int stride_am, @Constant int stride_ak,
            int stride_bk, @Constant int stride_bn,
            int stride_cm, @Constant int stride_cn,
            // Meta-parameters
            @Constant int BLOCK_SIZE_M, @Constant int BLOCK_SIZE_N, @Constant int BLOCK_SIZE_K,
            @Constant int GROUP_SIZE_M,
            @Constant boolean ACTIVATION) {

        // """Kernel for computing the matmul C = A x B.
        // A has shape (M, K), B has shape (K, N) and C has shape (M, N)
        // """
        // -----------------------------------------------------------
        // Map program ids `pid` to the block of C it should compute.
        // This is done in a grouped ordering to promote L2 data reuse.
        // See the `L2 Cache Optimizations` section of the Triton matmul
        // tutorial for details.
        var pid = programId(0);
        var num_pid_m = cdiv(M, BLOCK_SIZE_M);
        var num_pid_n = cdiv(N, BLOCK_SIZE_N);
        var num_pid_in_group = GROUP_SIZE_M * num_pid_n;
        var group_id = pid / num_pid_in_group;
        var first_pid_m = group_id * GROUP_SIZE_M;
        var group_size_m = Math.min(num_pid_m - first_pid_m, GROUP_SIZE_M);
        var pid_m = first_pid_m + (pid % group_size_m);
        var pid_n = (pid % num_pid_in_group) / group_size_m;
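
        // Worked example (illustrative, not part of the original tutorial
        // comment): with M = N = 256 and the meta-parameters used by the test
        // below (BLOCK_SIZE_M = 32, BLOCK_SIZE_N = 64, GROUP_SIZE_M = 8),
        // num_pid_m = 8, num_pid_n = 4 and num_pid_in_group = 32. Pids 0..7
        // then map to (pid_m, pid_n) = (0, 0)..(7, 0) and pids 8..15 to
        // (0, 1)..(7, 1), so consecutive pids share the same column block of B
        // and are likely to hit it in L2.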

        // ----------------------------------------------------------
        // Create pointers for the first blocks of A and B.
        // We will advance these pointers as we move in the K direction
        // and accumulate.
        // `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
        // `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
        // See the `Pointer Arithmetics` section of the Triton matmul
        // tutorial for details.
        var offs_m = arange(0, BLOCK_SIZE_M);
        var offs_am = mod(add(pid_m * BLOCK_SIZE_M, offs_m), M);
        var offs_n = arange(0, BLOCK_SIZE_N);
        var offs_bn = mod(add(pid_n * BLOCK_SIZE_N, offs_n), N);
        var offs_k = arange(0, BLOCK_SIZE_K);
        var a_ptrs = add(a_ptr, add(
                mul(expand(offs_am, 1), stride_am),
                mul(expand(offs_k, 0), stride_ak)));
        var b_ptrs = add(b_ptr, add(
                mul(expand(offs_k, 1), stride_bk),
                mul(expand(offs_bn, 0), stride_bn)));
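
        // Element-wise, the expressions above compute, for 0 <= m < BLOCK_SIZE_M
        // and 0 <= k < BLOCK_SIZE_K:
        //   a_ptrs[m][k] = a_ptr + offs_am[m] * stride_am + offs_k[k] * stride_ak
        // `expand(v, 1)` turns a 1D tensor into a column, `expand(v, 0)` into a
        // row, and the addition broadcasts both into a 2D block of pointers.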

        // -----------------------------------------------------------
        // Iterate to compute a block of the C matrix.
        // We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
        // of fp32 values for higher accuracy.
        // `accumulator` will be converted back to fp16 after the loop.
        var accumulator = zeros(float.class, BLOCK_SIZE_M, BLOCK_SIZE_N);
        for (int k = 0; k < cdiv(K, BLOCK_SIZE_K); k++) {
            // Load the next block of A and B, generate a mask by checking the K dimension.
            // If it is out of bounds, set it to 0.
            var a = load(a_ptrs,
                    compare(expand(offs_k, 0), K - k * BLOCK_SIZE_K, LessThan), 0f);
            var b = load(b_ptrs,
                    compare(expand(offs_k, 1), K - k * BLOCK_SIZE_K, LessThan), 0f);
            // We accumulate along the K dimension.
            accumulator = add(accumulator, dot(a, b));
            // Advance the ptrs to the next K block.
            a_ptrs = add(a_ptrs, BLOCK_SIZE_K * stride_ak);
            b_ptrs = add(b_ptrs, BLOCK_SIZE_K * stride_bk);
        }
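
        // `cdiv` is a ceiling division, (K + BLOCK_SIZE_K - 1) / BLOCK_SIZE_K,
        // as modeled by `cdiv_int_32_int` above. For example (illustrative),
        // K = 100 with BLOCK_SIZE_K = 32 runs 4 iterations; in the last one
        // only lanes with offs_k < 100 - 3 * 32 = 4 pass the mask, and the
        // remaining lanes load 0, contributing nothing to the dot product.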

        // You can fuse arbitrary activation functions here
        // while the accumulator is still in FP32!
//        if (ACTIVATION) {
//            // ...
//        }
        var c = Triton.conv(Float16.class, accumulator);

        // -----------------------------------------------------------
        // Write back the block of the output matrix C with masks.
        var offs_cm = add(pid_m * BLOCK_SIZE_M, offs_m);
        var offs_cn = add(pid_n * BLOCK_SIZE_N, offs_n);
        var c_ptrs = add(c_ptr, add(
                mul(stride_cm, expand(offs_cm, 1)),
                mul(stride_cn, expand(offs_cn, 0))));
        var c_mask = and(
                compare(expand(offs_cm, 1), M, LessThan),
                compare(expand(offs_cn, 0), N, LessThan));
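        // c_mask[m][n] = (offs_cm[m] < M) && (offs_cn[n] < N), so lanes that
        // fall outside the M x N bounds of C at the grid edges are not stored.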
        store(c_ptrs, c, c_mask);
    }

    @TritonTestExtension.Kernel("matmul_kernel_fp16")
    @Test
    public void test(TritonTestExtension.TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(Float16.FLOAT_16_TYPE),
                new PtrType(Float16.FLOAT_16_TYPE),
                new PtrType(Float16.FLOAT_16_TYPE),
                JavaType.INT, JavaType.INT, JavaType.INT,
                JavaType.INT, new ConstantType(JavaType.INT, 1),
                JavaType.INT, new ConstantType(JavaType.INT, 1),
                JavaType.INT, new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 32), new ConstantType(JavaType.INT, 64), new ConstantType(JavaType.INT, 32),
                new ConstantType(JavaType.INT, 8),
                new ConstantType(JavaType.BOOLEAN, false));
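
        // These constant arguments mirror the specialization baked into the
        // kernel name in the code model above: unit strides for `stride_ak`,
        // `stride_bn` and `stride_cn`, block sizes 32x64x32, GROUP_SIZE_M = 8,
        // and ACTIVATION = false.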

        t.test(argTypes);
    }

}
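
/*
For reference, ignoring blocking and masking, the kernel computes the plain
matmul sketched below (not part of the test; `matmulReference` is a
hypothetical helper). The kernel's inputs and output are fp16, but Java has no
fp16 array type, so float[] stands in. Row-major layout is assumed, matching
stride_ak = stride_bn = stride_cn = 1 in the test above.

static void matmulReference(float[] a, float[] b, float[] c, int M, int N, int K) {
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float acc = 0f;                      // fp32 accumulator, like the kernel
            for (int k = 0; k < K; k++) {
                acc += a[m * K + k] * b[k * N + n];
            }
            c[m * N + n] = acc;                  // the kernel truncates to fp16 here
        }
    }
}
*/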

/*
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr,  #
        ACTIVATION: tl.constexpr  #
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetics` section for details
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
@triton.jit
def leaky_relu(x):
    x = x + 1
    return tl.where(x >= 0, x, 0.01 * x)
*/

/*

 triton/python/triton/tools/compile.py \
    --kernel-name matmul_kernel \
    --signature "*fp16,*fp16,*fp16,i32,i32,i32,i32,i32,i32,i32,i32,i32,32,64,32,8,0" \
    --grid=1024,1024,1024 \
    03-matrix-multiplication.py

BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = 32
GROUP_SIZE_M = 8
ACTIVATION = 0

module {
  tt.func public @matmul_kernel_01234567891011(
            %arg0: !tt.ptr<f16, 1>, %arg1: !tt.ptr<f16, 1>, %arg2: !tt.ptr<f16, 1>,
            %arg3: i32, %arg4: i32, %arg5: i32,
            %arg6: i32, %arg7: i32, %arg8: i32,
            %arg9: i32, %arg10: i32, %arg11: i32) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %1 = tt.call @cdiv__i32__1cconstexpr_32_(%arg3) : (i32) -> i32
    %2 = tt.call @cdiv__i32__1cconstexpr_64_(%arg4) : (i32) -> i32
    %c8_i32 = arith.constant 8 : i32
    %3 = arith.muli %2, %c8_i32 : i32
    %4 = arith.divsi %0, %3 : i32
    %c8_i32_0 = arith.constant 8 : i32
    %5 = arith.muli %4, %c8_i32_0 : i32
    %6 = arith.subi %1, %5 : i32
    %7 = tt.call @minimum__i32__1cconstexpr_8_(%6) : (i32) -> i32
    %8 = arith.remsi %0, %7 : i32
    %9 = arith.addi %5, %8 : i32
    %10 = arith.remsi %0, %3 : i32
    %11 = arith.divsi %10, %7 : i32
    %c32_i32 = arith.constant 32 : i32
    %12 = arith.muli %9, %c32_i32 : i32
    %13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %14 = tt.splat %12 : (i32) -> tensor<32xi32>
    %15 = arith.addi %14, %13 : tensor<32xi32>
    %16 = tt.splat %arg3 : (i32) -> tensor<32xi32>
    %17 = arith.remsi %15, %16 : tensor<32xi32>
    %c64_i32 = arith.constant 64 : i32
    %18 = arith.muli %11, %c64_i32 : i32
    %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %20 = tt.splat %18 : (i32) -> tensor<64xi32>
    %21 = arith.addi %20, %19 : tensor<64xi32>
    %22 = tt.splat %arg4 : (i32) -> tensor<64xi32>
    %23 = arith.remsi %21, %22 : tensor<64xi32>
    %24 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %25 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
    %26 = tt.splat %arg6 : (i32) -> tensor<32x1xi32>
    %27 = arith.muli %25, %26 : tensor<32x1xi32>
    %28 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
    %29 = tt.splat %arg7 : (i32) -> tensor<1x32xi32>
    %30 = arith.muli %28, %29 : tensor<1x32xi32>
    %31 = tt.broadcast %27 : (tensor<32x1xi32>) -> tensor<32x32xi32>
    %32 = tt.broadcast %30 : (tensor<1x32xi32>) -> tensor<32x32xi32>
    %33 = arith.addi %31, %32 : tensor<32x32xi32>
    %34 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<32x32x!tt.ptr<f16, 1>>
    %35 = tt.addptr %34, %33 : tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x32xi32>
    %36 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
    %37 = tt.splat %arg8 : (i32) -> tensor<32x1xi32>
    %38 = arith.muli %36, %37 : tensor<32x1xi32>
    %39 = tt.expand_dims %23 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
    %40 = tt.splat %arg9 : (i32) -> tensor<1x64xi32>
    %41 = arith.muli %39, %40 : tensor<1x64xi32>
    %42 = tt.broadcast %38 : (tensor<32x1xi32>) -> tensor<32x64xi32>
    %43 = tt.broadcast %41 : (tensor<1x64xi32>) -> tensor<32x64xi32>
    %44 = arith.addi %42, %43 : tensor<32x64xi32>
    %45 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<32x64x!tt.ptr<f16, 1>>
    %46 = tt.addptr %45, %44 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
    %47 = tt.call @"zeros____0cconstexpr_(constexpr_32_, constexpr_64_)__1cconstexpr_fp32_"() : () -> tensor<32x64xf32>
    %48 = tt.call @cdiv__i32__1cconstexpr_32_(%arg5) : (i32) -> i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %49 = arith.bitcast %c0_i32 : i32 to i32
    %50 = arith.bitcast %48 : i32 to i32
    %51 = arith.bitcast %c1_i32 : i32 to i32
    %52 = llvm.mlir.undef : i32
    %53:3 = scf.for %arg12 = %49 to %50 step %51 iter_args(%arg13 = %47, %arg14 = %35, %arg15 = %46) -> (tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>)  : i32 {
      %83 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
      %c32_i32_3 = arith.constant 32 : i32
      %84 = arith.muli %arg12, %c32_i32_3 : i32
      %85 = arith.subi %arg5, %84 : i32
      %86 = tt.splat %85 : (i32) -> tensor<1x32xi32>
      %87 = arith.cmpi slt, %83, %86 : tensor<1x32xi32>
      %cst = arith.constant 0.000000e+00 : f32
      %88 = tt.broadcast %87 : (tensor<1x32xi1>) -> tensor<32x32xi1>
      %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
      %89 = arith.truncf %cst_4 : tensor<32x32xf32> to tensor<32x32xf16>
      %90 = tt.load %arg14, %88, %89 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16>
      %91 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
      %c32_i32_5 = arith.constant 32 : i32
      %92 = arith.muli %arg12, %c32_i32_5 : i32
      %93 = arith.subi %arg5, %92 : i32
      %94 = tt.splat %93 : (i32) -> tensor<32x1xi32>
      %95 = arith.cmpi slt, %91, %94 : tensor<32x1xi32>
      %cst_6 = arith.constant 0.000000e+00 : f32
      %96 = tt.broadcast %95 : (tensor<32x1xi1>) -> tensor<32x64xi1>
      %cst_7 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
      %97 = arith.truncf %cst_7 : tensor<32x64xf32> to tensor<32x64xf16>
      %98 = tt.load %arg15, %96, %97 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
      %cst_8 = arith.constant 0.000000e+00 : f32
      %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
      %99 = tt.dot %90, %98, %cst_9 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32>
      %100 = arith.addf %arg13, %99 : tensor<32x64xf32>
      %c32_i32_10 = arith.constant 32 : i32
      %101 = arith.muli %arg7, %c32_i32_10 : i32
      %102 = tt.splat %101 : (i32) -> tensor<32x32xi32>
      %103 = tt.addptr %arg14, %102 : tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x32xi32>
      %c32_i32_11 = arith.constant 32 : i32
      %104 = arith.muli %arg8, %c32_i32_11 : i32
      %105 = tt.splat %104 : (i32) -> tensor<32x64xi32>
      %106 = tt.addptr %arg15, %105 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
      scf.yield %100, %103, %106 : tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>
    }
    %54 = arith.truncf %53#0 : tensor<32x64xf32> to tensor<32x64xf16>
    %c32_i32_1 = arith.constant 32 : i32
    %55 = arith.muli %9, %c32_i32_1 : i32
    %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %57 = tt.splat %55 : (i32) -> tensor<32xi32>
    %58 = arith.addi %57, %56 : tensor<32xi32>
    %c64_i32_2 = arith.constant 64 : i32
    %59 = arith.muli %11, %c64_i32_2 : i32
    %60 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %61 = tt.splat %59 : (i32) -> tensor<64xi32>
    %62 = arith.addi %61, %60 : tensor<64xi32>
    %63 = tt.expand_dims %58 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
    %64 = tt.splat %arg10 : (i32) -> tensor<32x1xi32>
    %65 = arith.muli %64, %63 : tensor<32x1xi32>
    %66 = tt.splat %arg2 : (!tt.ptr<f16, 1>) -> tensor<32x1x!tt.ptr<f16, 1>>
    %67 = tt.addptr %66, %65 : tensor<32x1x!tt.ptr<f16, 1>>, tensor<32x1xi32>
    %68 = tt.expand_dims %62 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
    %69 = tt.splat %arg11 : (i32) -> tensor<1x64xi32>
    %70 = arith.muli %69, %68 : tensor<1x64xi32>
    %71 = tt.broadcast %67 : (tensor<32x1x!tt.ptr<f16, 1>>) -> tensor<32x64x!tt.ptr<f16, 1>>
    %72 = tt.broadcast %70 : (tensor<1x64xi32>) -> tensor<32x64xi32>
    %73 = tt.addptr %71, %72 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
    %74 = tt.expand_dims %58 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
    %75 = tt.splat %arg3 : (i32) -> tensor<32x1xi32>
    %76 = arith.cmpi slt, %74, %75 : tensor<32x1xi32>
    %77 = tt.expand_dims %62 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
    %78 = tt.splat %arg4 : (i32) -> tensor<1x64xi32>
    %79 = arith.cmpi slt, %77, %78 : tensor<1x64xi32>
    %80 = tt.broadcast %76 : (tensor<32x1xi1>) -> tensor<32x64xi1>
    %81 = tt.broadcast %79 : (tensor<1x64xi1>) -> tensor<32x64xi1>
    %82 = arith.andi %80, %81 : tensor<32x64xi1>
    tt.store %73, %54, %82 {cache = 1 : i32, evict = 1 : i32} : tensor<32x64xf16>
    tt.return
  }
  tt.func private @cdiv__i32__1cconstexpr_32_(%arg0: i32) -> i32 attributes {noinline = false} {
    %c32_i32 = arith.constant 32 : i32
    %0 = arith.addi %arg0, %c32_i32 : i32
    %c1_i32 = arith.constant 1 : i32
    %1 = arith.subi %0, %c1_i32 : i32
    %c32_i32_0 = arith.constant 32 : i32
    %2 = arith.divsi %1, %c32_i32_0 : i32
    tt.return %2 : i32
  }
  tt.func private @cdiv__i32__1cconstexpr_64_(%arg0: i32) -> i32 attributes {noinline = false} {
    %c64_i32 = arith.constant 64 : i32
    %0 = arith.addi %arg0, %c64_i32 : i32
    %c1_i32 = arith.constant 1 : i32
    %1 = arith.subi %0, %c1_i32 : i32
    %c64_i32_0 = arith.constant 64 : i32
    %2 = arith.divsi %1, %c64_i32_0 : i32
    tt.return %2 : i32
  }
  tt.func private @minimum__i32__1cconstexpr_8_(%arg0: i32) -> i32 attributes {noinline = false} {
    %c8_i32 = arith.constant 8 : i32
    %0 = arith.minsi %arg0, %c8_i32 : i32
    tt.return %0 : i32
  }
  tt.func private @"zeros____0cconstexpr_(constexpr_32_, constexpr_64_)__1cconstexpr_fp32_"() -> tensor<32x64xf32> attributes {noinline = false} {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
    tt.return %cst_0 : tensor<32x64xf32>
  }
}
559  */