1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.triton;
 25 
 26 import org.junit.jupiter.api.Test;
 27 import org.junit.jupiter.api.extension.ExtendWith;
 28 
 29 import java.lang.reflect.code.TypeElement;
 30 import java.lang.reflect.code.type.JavaType;
 31 import java.lang.runtime.CodeReflection;
 32 import java.util.List;
 33 
 34 import static oracle.code.triton.Triton.*;
 35 import static oracle.code.triton.Triton.CompareKind.*;
 36 import static oracle.code.triton.Triton.compare;
 37 import static oracle.code.triton.Triton.load;
 38 
 39 @ExtendWith(TritonTestExtension.class)
 40 public class TestMatrix {
 41 
    /**
     * Computes one tile of the matrix product C = A x B, spelling out every shape adjustment
     * with explicit {@code broadcast} calls (scalar to tensor and tensor to tensor).
     * <p>
     * A has shape (M, K), B has shape (K, N) and C has shape (M, N). Each program instance
     * (identified by {@code programId(0)}) computes one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of C,
     * accumulating {@code float} partial products over {@code cdiv(K, BLOCK_SIZE_K)} steps, and
     * stores the {@code float} accumulator to {@code c_ptr} without converting it to a narrower
     * type.
     * <p>
     * The {@code @TritonCodeModel} text is the expected Triton code model for the argument types
     * used by {@code testWithBroadcast}; its function name encodes {@code ptr<float>} pointers and
     * the constants 32, 64, 32, 8, false. NOTE(review): presumably {@code TritonTestExtension}
     * compares that text against the model generated from this method. If so, statement order
     * and operand order in the body are significant, and any rewrite of the body must be
     * mirrored in the expected text.
     *
     * @param a_ptr        base pointer of A
     * @param b_ptr        base pointer of B
     * @param c_ptr        base pointer of C; the result tile is stored through it
     * @param M            number of rows of A and of C
     * @param N            number of columns of B and of C
     * @param K            number of columns of A / rows of B (the reduction dimension)
     * @param stride_am    pointer increment to move one element along M in A
     * @param stride_ak    pointer increment to move one element along K in A
     * @param stride_bk    pointer increment to move one element along K in B
     * @param stride_bn    pointer increment to move one element along N in B
     * @param stride_cm    pointer increment to move one element along M in C
     * @param stride_cn    pointer increment to move one element along N in C
     * @param BLOCK_SIZE_M {@code @Constant} tile size along M (appears as arith.constant in the model)
     * @param BLOCK_SIZE_N {@code @Constant} tile size along N
     * @param BLOCK_SIZE_K {@code @Constant} tile size along K
     * @param GROUP_SIZE_M {@code @Constant} maximum number of M-tiles per program-id group
     * @param ACTIVATION   {@code @Constant}; currently unused because the activation branch below
     *                     is commented out
     */
    @TritonCodeModel("""
            module ()void -> {
                tt.func @"cdiv_int_32_int" (%0 : int)int -> {
                    %1 : int = arith.constant @"32";
                    %2 : int = arith.addi %0 %1;
                    %3 : int = arith.constant @"1";
                    %4 : int = arith.subi %2 %3;
                    %5 : int = arith.divsi %4 %1;
                    tt.return %5;
                };
                tt.func @"cdiv_int_64_int" (%6 : int)int -> {
                    %7 : int = arith.constant @"64";
                    %8 : int = arith.addi %6 %7;
                    %9 : int = arith.constant @"1";
                    %10 : int = arith.subi %8 %9;
                    %11 : int = arith.divsi %10 %7;
                    tt.return %11;
                };
                tt.func @"matmul_kernel_broadcast_ptr<float>_ptr<float>_ptr<float>_int_int_int_int_int_int_int_int_int_32_64_32_8_false_void" (%12 : ptr<float>, %13 : ptr<float>, %14 : ptr<float>, %15 : int, %16 : int, %17 : int, %18 : int, %19 : int, %20 : int, %21 : int, %22 : int, %23 : int)void -> {
                    %24 : int = arith.constant @"32";
                    %25 : int = arith.constant @"64";
                    %26 : int = arith.constant @"32";
                    %27 : int = arith.constant @"8";
                    %28 : int = tt.get_program_id @"0";
                    %29 : int = tt.call %15 @"cdiv_int_32_int";
                    %30 : int = tt.call %16 @"cdiv_int_64_int";
                    %31 : int = arith.muli %27 %30;
                    %32 : int = arith.divsi %28 %31;
                    %33 : int = arith.muli %32 %27;
                    %34 : int = arith.subi %29 %33;
                    %35 : int = arith.minsi %34 %27;
                    %36 : int = arith.remsi %28 %35;
                    %37 : int = arith.addi %33 %36;
                    %38 : int = arith.remsi %28 %31;
                    %39 : int = arith.divsi %38 %35;
                    %40 : tensor<x32, int> = tt.make_range @start="0" @end="32";
                    %41 : int = arith.muli %37 %24;
                    %42 : tensor<x32, int> = tt.splat %41;
                    %43 : tensor<x32, int> = arith.addi %42 %40;
                    %44 : tensor<x32, int> = tt.splat %15;
                    %45 : tensor<x32, int> = arith.remsi %43 %44;
                    %46 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %47 : int = arith.muli %39 %25;
                    %48 : tensor<x64, int> = tt.splat %47;
                    %49 : tensor<x64, int> = arith.addi %48 %46;
                    %50 : tensor<x64, int> = tt.splat %16;
                    %51 : tensor<x64, int> = arith.remsi %49 %50;
                    %52 : tensor<x32, int> = tt.make_range @start="0" @end="32";
                    %53 : tensor<x32, x1, int> = tt.expand_dims %45 @"1";
                    %54 : tensor<x32, x1, int> = tt.splat %18;
                    %55 : tensor<x32, x1, int> = arith.muli %53 %54;
                    %56 : tensor<x1, x32, int> = tt.expand_dims %52 @"0";
                    %57 : tensor<x1, x32, int> = tt.splat %19;
                    %58 : tensor<x1, x32, int> = arith.muli %56 %57;
                    %59 : tensor<x32, x32, ptr<float>> = tt.splat %12;
                    %60 : tensor<x32, x32, int> = tt.broadcast %55;
                    %61 : tensor<x32, x32, int> = tt.broadcast %58;
                    %62 : tensor<x32, x32, int> = arith.addi %60 %61;
                    %63 : tensor<x32, x32, ptr<float>> = tt.addptr %59 %62;
                    %64 : tensor<x32, x1, int> = tt.expand_dims %52 @"1";
                    %65 : tensor<x32, x1, int> = tt.splat %20;
                    %66 : tensor<x32, x1, int> = arith.muli %64 %65;
                    %67 : tensor<x1, x64, int> = tt.expand_dims %51 @"0";
                    %68 : tensor<x1, x64, int> = tt.splat %21;
                    %69 : tensor<x1, x64, int> = arith.muli %67 %68;
                    %70 : tensor<x32, x64, ptr<float>> = tt.splat %13;
                    %71 : tensor<x32, x64, int> = tt.broadcast %66;
                    %72 : tensor<x32, x64, int> = tt.broadcast %69;
                    %73 : tensor<x32, x64, int> = arith.addi %71 %72;
                    %74 : tensor<x32, x64, ptr<float>> = tt.addptr %70 %73;
                    %75 : tensor<x32, x64, float> = arith.constant @"0.0";
                    %76 : int = arith.constant @"0";
                    %77 : int = tt.call %17 @"cdiv_int_32_int";
                    %78 : int = arith.constant @"1";
                    %79 : Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<float>>, tensor<x32, x64, ptr<float>>> = scf.for %76 %77 %78 %75 %63 %74 (%80 : int, %81 : tensor<x32, x64, float>, %82 : tensor<x32, x32, ptr<float>>, %83 : tensor<x32, x64, ptr<float>>)Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<float>>, tensor<x32, x64, ptr<float>>> -> {
                        %84 : tensor<x1, x32, int> = tt.expand_dims %52 @"0";
                        %85 : int = arith.muli %80 %26;
                        %86 : int = arith.subi %17 %85;
                        %87 : tensor<x1, x32, int> = tt.splat %86;
                        %88 : tensor<x1, x32, int> = arith.cmpi %84 %87 @"slt";
                        %89 : tensor<x32, x32, int> = tt.broadcast %88;
                        %90 : tensor<x32, x32, float> = tt.load %82 %89;
                        %91 : tensor<x32, x1, int> = tt.expand_dims %52 @"1";
                        %92 : int = arith.muli %80 %26;
                        %93 : int = arith.subi %17 %92;
                        %94 : tensor<x32, x1, int> = tt.splat %93;
                        %95 : tensor<x32, x1, int> = arith.cmpi %91 %94 @"slt";
                        %96 : tensor<x32, x64, int> = tt.broadcast %95;
                        %97 : tensor<x32, x64, float> = tt.load %83 %96;
                        %98 : tensor<x32, x64, float> = tt.dot %90 %97;
                        %99 : tensor<x32, x64, float> = arith.addf %81 %98;
                        %100 : int = arith.muli %26 %19;
                        %101 : tensor<x32, x32, int> = tt.splat %100;
                        %102 : tensor<x32, x32, ptr<float>> = tt.addptr %82 %101;
                        %103 : int = arith.muli %26 %20;
                        %104 : tensor<x32, x64, int> = tt.splat %103;
                        %105 : tensor<x32, x64, ptr<float>> = tt.addptr %83 %104;
                        scf.yield %99 %102 %105;
                    };
                    %106 : tensor<x32, x64, float> = tuple.load %79 @"0";
                    %107 : tensor<x32, x32, ptr<float>> = tuple.load %79 @"1";
                    %108 : tensor<x32, x64, ptr<float>> = tuple.load %79 @"2";
                    %109 : int = arith.muli %37 %24;
                    %110 : tensor<x32, int> = tt.splat %109;
                    %111 : tensor<x32, int> = arith.addi %110 %40;
                    %112 : int = arith.muli %39 %25;
                    %113 : tensor<x64, int> = tt.splat %112;
                    %114 : tensor<x64, int> = arith.addi %113 %46;
                    %115 : tensor<x32, x1, int> = tt.expand_dims %111 @"1";
                    %116 : tensor<x32, x1, int> = tt.splat %22;
                    %117 : tensor<x32, x1, int> = arith.muli %115 %116;
                    %118 : tensor<x1, x64, int> = tt.expand_dims %114 @"0";
                    %119 : tensor<x1, x64, int> = tt.splat %23;
                    %120 : tensor<x1, x64, int> = arith.muli %118 %119;
                    %121 : tensor<x32, x64, ptr<float>> = tt.splat %14;
                    %122 : tensor<x32, x64, int> = tt.broadcast %117;
                    %123 : tensor<x32, x64, int> = tt.broadcast %120;
                    %124 : tensor<x32, x64, int> = arith.addi %122 %123;
                    %125 : tensor<x32, x64, ptr<float>> = tt.addptr %121 %124;
                    %126 : tensor<x32, x1, int> = tt.expand_dims %111 @"1";
                    %127 : tensor<x32, x1, int> = tt.splat %15;
                    %128 : tensor<x32, x1, int> = arith.cmpi %126 %127 @"slt";
                    %129 : tensor<x1, x64, int> = tt.expand_dims %114 @"0";
                    %130 : tensor<x1, x64, int> = tt.splat %16;
                    %131 : tensor<x1, x64, int> = arith.cmpi %129 %130 @"slt";
                    %132 : tensor<x32, x64, int> = tt.broadcast %128;
                    %133 : tensor<x32, x64, int> = tt.broadcast %131;
                    %134 : tensor<x32, x64, int> = arith.andi %132 %133;
                    tt.store %125 %106 %134;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void matmul_kernel_broadcast(
            // Pointers to matrices
            Ptr a_ptr, Ptr b_ptr, Ptr c_ptr,
            // Matrix dimensions
            int M, int N, int K,
            // The stride variables represent how much to increase the ptr by when moving by 1
            // element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
            // by to get the element one row down (A has M rows).
            int stride_am, int stride_ak,
            int stride_bk, int stride_bn,
            int stride_cm, int stride_cn,
            // Meta-parameters
            @Constant int BLOCK_SIZE_M, @Constant int BLOCK_SIZE_N, @Constant int BLOCK_SIZE_K,
            @Constant int GROUP_SIZE_M,
            @Constant boolean ACTIVATION) {

        // """Kernel for computing the matmul C = A x B.
        // A has shape (M, K), B has shape (K, N) and C has shape (M, N)
        // """
        // -----------------------------------------------------------
        // Map program ids `pid` to the block of C it should compute.
        // This is done in a grouped ordering to promote L2 data reuse: within a group,
        // consecutive pids step through up to GROUP_SIZE_M tile rows (pid_m) for the same
        // pid_n before pid_n advances; the last group may be shorter (group_size_m).
        // See the `L2 Cache Optimizations` section of the Triton matrix-multiplication
        // tutorial this kernel appears to be ported from (no such section exists in this file).
        var pid = programId(0);

        var num_pid_m = cdiv(M, BLOCK_SIZE_M);
        var num_pid_n = cdiv(N, BLOCK_SIZE_N);
        var num_pid_in_group = GROUP_SIZE_M * num_pid_n;
        var group_id = pid / num_pid_in_group;
        var first_pid_m = group_id * GROUP_SIZE_M;
        var group_size_m = Math.min(num_pid_m - first_pid_m, GROUP_SIZE_M);
        var pid_m = first_pid_m + (pid % group_size_m);
        var pid_n = (pid % num_pid_in_group) / group_size_m;

        // ----------------------------------------------------------
        // Create pointers for the first blocks of A and B.
        // We will advance this pointer as we move in the K direction
        // and accumulate
        // `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
        // `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
        // See the `Pointer Arithmetics` section of the Triton matrix-multiplication tutorial
        // (not part of this file) for details.
        var offs_m = arange(0, BLOCK_SIZE_M);
        // Row/column offsets are wrapped modulo M (resp. N) so that the A/B loads in the loop,
        // which are masked only along K, stay inside the matrices. Out-of-range rows/columns
        // of the C tile are discarded by the store mask at the end, which uses unwrapped offsets.
        var offs_am = mod(
                add(broadcast(pid_m * BLOCK_SIZE_M, offs_m.type()), offs_m),
                broadcast(M, offs_m.type()));
        var offs_n = arange(0, BLOCK_SIZE_N);
        var offs_bn = mod(
                add(broadcast(pid_n * BLOCK_SIZE_N, offs_n.type()), offs_n),
                broadcast(N, offs_n.type()));
        var offs_k = arange(0, BLOCK_SIZE_K);

        // a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak,
        // with each operand broadcast explicitly to the joined [BLOCK_SIZE_M, BLOCK_SIZE_K] shape.
        var offs_am_e = expand(offs_am, 1);
        offs_am_e = mul(offs_am_e, broadcast(stride_am, offs_am_e.type()));
        var offs_k_e_0 = expand(offs_k, 0);
        offs_k_e_0 = mul(offs_k_e_0, broadcast(stride_ak, offs_k_e_0.type()));
        TensorType a_ptrs_t = joinShape(offs_am_e.type(), offs_k_e_0.type());
        var a_ptrs = add(broadcast(a_ptr, a_ptrs_t),
                add(broadcast(offs_am_e, a_ptrs_t), broadcast(offs_k_e_0, a_ptrs_t)));

        // b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn,
        // broadcast to [BLOCK_SIZE_K, BLOCK_SIZE_N].
        var offs_k_e_1 = expand(offs_k, 1);
        offs_k_e_1 = mul(offs_k_e_1, broadcast(stride_bk, offs_k_e_1.type()));
        var offs_bn_e = expand(offs_bn, 0);
        offs_bn_e = mul(offs_bn_e, broadcast(stride_bn, offs_bn_e.type()));
        TensorType b_ptrs_t = joinShape(offs_k_e_1.type(), offs_bn_e.type());
        var b_ptrs = add(broadcast(b_ptr, b_ptrs_t),
                add(broadcast(offs_k_e_1, b_ptrs_t), broadcast(offs_bn_e, b_ptrs_t)));

        // -----------------------------------------------------------
        // Iterate to compute a block of the C matrix.
        // We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
        // of fp32 values for higher accuracy.
        // This kernel stores the fp32 `accumulator` as-is; unlike matmul_kernel it does not
        // convert it to fp16 after the loop.
        var accumulator = zeros(float.class, BLOCK_SIZE_M, BLOCK_SIZE_N);
        for (int k = 0; k < cdiv(K, BLOCK_SIZE_K); k++) {
            // Load the next block of A and B, generating a mask from the K dimension so that
            // lanes past K in the final (partial) block are not read.
            // NOTE(review): no fill value is passed to load(); the Triton tutorial uses
            // other=0.0 so masked lanes contribute 0 to the dot product — confirm what
            // Triton.load yields for masked-off lanes.
            // [1, BLOCK_SIZE_K] mask over A's K columns, broadcast across A's rows.
            var offs_k_m_0 = expand(offs_k, 0);
            offs_k_m_0 = compare(offs_k_m_0,
                    broadcast(K - k * BLOCK_SIZE_K, offs_k_m_0.type()),
                    LessThan);
            var a = load(a_ptrs, broadcast(offs_k_m_0, a_ptrs.type()));
            // [BLOCK_SIZE_K, 1] mask over B's K rows, broadcast across B's columns.
            var offs_k_m_1 = expand(offs_k, 1);
            offs_k_m_1 = compare(offs_k_m_1,
                    broadcast(K - k * BLOCK_SIZE_K, offs_k_m_1.type()),
                    LessThan);
            var b = load(b_ptrs, broadcast(offs_k_m_1, b_ptrs.type()));
            // We accumulate along the K dimension.
            accumulator = add(accumulator, dot(a, b));
            // Advance the ptrs to the next K block.
            a_ptrs = add(a_ptrs, broadcast(BLOCK_SIZE_K * stride_ak, a_ptrs.type()));
            b_ptrs = add(b_ptrs, broadcast(BLOCK_SIZE_K * stride_bk, b_ptrs.type()));
        }

        // You can fuse arbitrary activation functions here
        // while the accumulator is still in FP32!
        // (ACTIVATION is currently unused because this branch is commented out.)
//        if (ACTIVATION) {
//            // ...
//        }
        // c = Triton.to(activation, tl.float16)
        // Not done in this kernel: C is stored in fp32.
        var c = accumulator;

        // -----------------------------------------------------------
        // Write back the block of the output matrix C with masks.
        // The C offsets are NOT wrapped modulo M/N, so they can be compared against M/N below.
        var offs_cm = add(broadcast(pid_m * BLOCK_SIZE_M, offs_m.type()), offs_m);
        var offs_cn = add(broadcast(pid_n * BLOCK_SIZE_N, offs_n.type()), offs_n);

        var offs_cm_e = expand(offs_cm, 1);
        offs_cm_e = mul(offs_cm_e, broadcast(stride_cm, offs_cm_e.type()));
        var offs_cn_e = expand(offs_cn, 0);
        offs_cn_e = mul(offs_cn_e, broadcast(stride_cn, offs_cn_e.type()));
        TensorType c_ptrs_t = joinShape(offs_cm_e.type(), offs_cn_e.type());
        var c_ptrs = add(broadcast(c_ptr, c_ptrs_t),
                add(broadcast(offs_cm_e, c_ptrs_t), broadcast(offs_cn_e, c_ptrs_t)));

        // Re-expand the raw offsets: offs_cm_e / offs_cn_e were overwritten above with the
        // stride-scaled values, but the mask must compare element indices against M and N.
        offs_cm_e = expand(offs_cm, 1);
        var c_mask_l = compare(offs_cm_e, broadcast(M, offs_cm_e.type()), LessThan);
        offs_cn_e = expand(offs_cn, 0);
        var c_mask_r = compare(offs_cn_e, broadcast(N, offs_cn_e.type()), LessThan);
        var c_mask = and(broadcast(c_mask_l, c_ptrs_t), broadcast(c_mask_r, c_ptrs_t));

        store(c_ptrs, c, c_mask);
    }
299 
300     @TritonTestExtension.Kernel("matmul_kernel_broadcast")
301     @Test
302     public void testWithBroadcast(TritonTestExtension.TritonTestData t) {
303         List<TypeElement> argTypes = List.of(
304                 new PtrType(JavaType.FLOAT),
305                 new PtrType(JavaType.FLOAT),
306                 new PtrType(JavaType.FLOAT),
307                 JavaType.INT, JavaType.INT, JavaType.INT,
308                 JavaType.INT, JavaType.INT,
309                 JavaType.INT, JavaType.INT,
310                 JavaType.INT, JavaType.INT,
311                 new ConstantType(JavaType.INT, 32), new ConstantType(JavaType.INT, 64), new ConstantType(JavaType.INT, 32),
312                 new ConstantType(JavaType.INT, 8),
313                 new ConstantType(JavaType.INT, false));
314 
315         t.test(argTypes);
316     }
317 
318 
    /**
     * Computes one tile of the matrix product C = A x B and stores it as {@code Float16}.
     * <p>
     * Unlike {@code matmul_kernel_broadcast}, this version passes scalars and lower-rank
     * tensors directly to {@code add}/{@code mul}/{@code mod}/{@code compare}/{@code and}. The
     * generated model then contains the implied {@code tt.splat}/{@code tt.broadcast}
     * operations. The {@code float} accumulator is converted to {@code Float16} with
     * {@code Triton.conv} (an {@code arith.truncf} in the model) before being stored.
     * <p>
     * A has shape (M, K), B has shape (K, N) and C has shape (M, N). Each program instance
     * (identified by {@code programId(0)}) computes one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of C
     * over {@code cdiv(K, BLOCK_SIZE_K)} accumulation steps.
     * <p>
     * The {@code @TritonCodeModel} text is the expected Triton code model for
     * {@code ptr<Float16>} pointers and the constants 32, 64, 32, 8, false (see its function
     * name). NOTE(review): presumably {@code TritonTestExtension} compares that text against the
     * model generated from this method. If so, statement and operand order here are
     * significant: for example, {@code mul(stride_cm, ...)} appears as
     * {@code arith.muli %117 %116}, with the splatted stride first.
     *
     * @param a_ptr        base pointer of A
     * @param b_ptr        base pointer of B
     * @param c_ptr        base pointer of C; the result tile is stored through it
     * @param M            number of rows of A and of C
     * @param N            number of columns of B and of C
     * @param K            number of columns of A / rows of B (the reduction dimension)
     * @param stride_am    pointer increment to move one element along M in A
     * @param stride_ak    pointer increment to move one element along K in A
     * @param stride_bk    pointer increment to move one element along K in B
     * @param stride_bn    pointer increment to move one element along N in B
     * @param stride_cm    pointer increment to move one element along M in C
     * @param stride_cn    pointer increment to move one element along N in C
     * @param BLOCK_SIZE_M {@code @Constant} tile size along M (appears as arith.constant in the model)
     * @param BLOCK_SIZE_N {@code @Constant} tile size along N
     * @param BLOCK_SIZE_K {@code @Constant} tile size along K
     * @param GROUP_SIZE_M {@code @Constant} maximum number of M-tiles per program-id group
     * @param ACTIVATION   {@code @Constant}; currently unused because the activation branch below
     *                     is commented out
     */
    @TritonCodeModel("""
            module ()void -> {
                tt.func @"cdiv_int_32_int" (%0 : int)int -> {
                    %1 : int = arith.constant @"32";
                    %2 : int = arith.addi %0 %1;
                    %3 : int = arith.constant @"1";
                    %4 : int = arith.subi %2 %3;
                    %5 : int = arith.divsi %4 %1;
                    tt.return %5;
                };
                tt.func @"cdiv_int_64_int" (%6 : int)int -> {
                    %7 : int = arith.constant @"64";
                    %8 : int = arith.addi %6 %7;
                    %9 : int = arith.constant @"1";
                    %10 : int = arith.subi %8 %9;
                    %11 : int = arith.divsi %10 %7;
                    tt.return %11;
                };
                tt.func @"matmul_kernel_ptr<oracle.code.triton.Float16>_ptr<oracle.code.triton.Float16>_ptr<oracle.code.triton.Float16>_int_int_int_int_int_int_int_int_int_32_64_32_8_false_void" (%12 : ptr<oracle.code.triton.Float16>, %13 : ptr<oracle.code.triton.Float16>, %14 : ptr<oracle.code.triton.Float16>, %15 : int, %16 : int, %17 : int, %18 : int, %19 : int, %20 : int, %21 : int, %22 : int, %23 : int)void -> {
                    %24 : int = arith.constant @"32";
                    %25 : int = arith.constant @"64";
                    %26 : int = arith.constant @"32";
                    %27 : int = arith.constant @"8";
                    %28 : int = tt.get_program_id @"0";
                    %29 : int = tt.call %15 @"cdiv_int_32_int";
                    %30 : int = tt.call %16 @"cdiv_int_64_int";
                    %31 : int = arith.muli %27 %30;
                    %32 : int = arith.divsi %28 %31;
                    %33 : int = arith.muli %32 %27;
                    %34 : int = arith.subi %29 %33;
                    %35 : int = arith.minsi %34 %27;
                    %36 : int = arith.remsi %28 %35;
                    %37 : int = arith.addi %33 %36;
                    %38 : int = arith.remsi %28 %31;
                    %39 : int = arith.divsi %38 %35;
                    %40 : tensor<x32, int> = tt.make_range @start="0" @end="32";
                    %41 : int = arith.muli %37 %24;
                    %42 : tensor<x32, int> = tt.splat %41;
                    %43 : tensor<x32, int> = arith.addi %42 %40;
                    %44 : tensor<x32, int> = tt.splat %15;
                    %45 : tensor<x32, int> = arith.remsi %43 %44;
                    %46 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %47 : int = arith.muli %39 %25;
                    %48 : tensor<x64, int> = tt.splat %47;
                    %49 : tensor<x64, int> = arith.addi %48 %46;
                    %50 : tensor<x64, int> = tt.splat %16;
                    %51 : tensor<x64, int> = arith.remsi %49 %50;
                    %52 : tensor<x32, int> = tt.make_range @start="0" @end="32";
                    %53 : tensor<x32, x1, int> = tt.expand_dims %45 @"1";
                    %54 : tensor<x32, x1, int> = tt.splat %18;
                    %55 : tensor<x32, x1, int> = arith.muli %53 %54;
                    %56 : tensor<x1, x32, int> = tt.expand_dims %52 @"0";
                    %57 : tensor<x1, x32, int> = tt.splat %19;
                    %58 : tensor<x1, x32, int> = arith.muli %56 %57;
                    %59 : tensor<x32, x32, int> = tt.broadcast %55;
                    %60 : tensor<x32, x32, int> = tt.broadcast %58;
                    %61 : tensor<x32, x32, int> = arith.addi %59 %60;
                    %62 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tt.splat %12;
                    %63 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tt.addptr %62 %61;
                    %64 : tensor<x32, x1, int> = tt.expand_dims %52 @"1";
                    %65 : tensor<x32, x1, int> = tt.splat %20;
                    %66 : tensor<x32, x1, int> = arith.muli %64 %65;
                    %67 : tensor<x1, x64, int> = tt.expand_dims %51 @"0";
                    %68 : tensor<x1, x64, int> = tt.splat %21;
                    %69 : tensor<x1, x64, int> = arith.muli %67 %68;
                    %70 : tensor<x32, x64, int> = tt.broadcast %66;
                    %71 : tensor<x32, x64, int> = tt.broadcast %69;
                    %72 : tensor<x32, x64, int> = arith.addi %70 %71;
                    %73 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.splat %13;
                    %74 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.addptr %73 %72;
                    %75 : tensor<x32, x64, float> = arith.constant @"0.0";
                    %76 : int = arith.constant @"0";
                    %77 : int = tt.call %17 @"cdiv_int_32_int";
                    %78 : int = arith.constant @"1";
                    %79 : Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<oracle.code.triton.Float16>>, tensor<x32, x64, ptr<oracle.code.triton.Float16>>> = scf.for %76 %77 %78 %75 %63 %74 (%80 : int, %81 : tensor<x32, x64, float>, %82 : tensor<x32, x32, ptr<oracle.code.triton.Float16>>, %83 : tensor<x32, x64, ptr<oracle.code.triton.Float16>>)Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<oracle.code.triton.Float16>>, tensor<x32, x64, ptr<oracle.code.triton.Float16>>> -> {
                        %84 : tensor<x1, x32, int> = tt.expand_dims %52 @"0";
                        %85 : int = arith.muli %80 %26;
                        %86 : int = arith.subi %17 %85;
                        %87 : tensor<x1, x32, int> = tt.splat %86;
                        %88 : tensor<x1, x32, int> = arith.cmpi %84 %87 @"slt";
                        %89 : tensor<x32, x32, int> = tt.broadcast %88;
                        %90 : tensor<x32, x32, oracle.code.triton.Float16> = tt.load %82 %89;
                        %91 : tensor<x32, x1, int> = tt.expand_dims %52 @"1";
                        %92 : int = arith.muli %80 %26;
                        %93 : int = arith.subi %17 %92;
                        %94 : tensor<x32, x1, int> = tt.splat %93;
                        %95 : tensor<x32, x1, int> = arith.cmpi %91 %94 @"slt";
                        %96 : tensor<x32, x64, int> = tt.broadcast %95;
                        %97 : tensor<x32, x64, oracle.code.triton.Float16> = tt.load %83 %96;
                        %98 : tensor<x32, x64, float> = tt.dot %90 %97;
                        %99 : tensor<x32, x64, float> = arith.addf %81 %98;
                        %100 : int = arith.muli %26 %19;
                        %101 : tensor<x32, x32, int> = tt.splat %100;
                        %102 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tt.addptr %82 %101;
                        %103 : int = arith.muli %26 %20;
                        %104 : tensor<x32, x64, int> = tt.splat %103;
                        %105 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.addptr %83 %104;
                        scf.yield %99 %102 %105;
                    };
                    %106 : tensor<x32, x64, float> = tuple.load %79 @"0";
                    %107 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tuple.load %79 @"1";
                    %108 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tuple.load %79 @"2";
                    %109 : tensor<x32, x64, oracle.code.triton.Float16> = arith.truncf %106;
                    %110 : int = arith.muli %37 %24;
                    %111 : tensor<x32, int> = tt.splat %110;
                    %112 : tensor<x32, int> = arith.addi %111 %40;
                    %113 : int = arith.muli %39 %25;
                    %114 : tensor<x64, int> = tt.splat %113;
                    %115 : tensor<x64, int> = arith.addi %114 %46;
                    %116 : tensor<x32, x1, int> = tt.expand_dims %112 @"1";
                    %117 : tensor<x32, x1, int> = tt.splat %22;
                    %118 : tensor<x32, x1, int> = arith.muli %117 %116;
                    %119 : tensor<x1, x64, int> = tt.expand_dims %115 @"0";
                    %120 : tensor<x1, x64, int> = tt.splat %23;
                    %121 : tensor<x1, x64, int> = arith.muli %120 %119;
                    %122 : tensor<x32, x64, int> = tt.broadcast %118;
                    %123 : tensor<x32, x64, int> = tt.broadcast %121;
                    %124 : tensor<x32, x64, int> = arith.addi %122 %123;
                    %125 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.splat %14;
                    %126 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.addptr %125 %124;
                    %127 : tensor<x32, x1, int> = tt.expand_dims %112 @"1";
                    %128 : tensor<x32, x1, int> = tt.splat %15;
                    %129 : tensor<x32, x1, int> = arith.cmpi %127 %128 @"slt";
                    %130 : tensor<x1, x64, int> = tt.expand_dims %115 @"0";
                    %131 : tensor<x1, x64, int> = tt.splat %16;
                    %132 : tensor<x1, x64, int> = arith.cmpi %130 %131 @"slt";
                    %133 : tensor<x32, x64, int> = tt.broadcast %129;
                    %134 : tensor<x32, x64, int> = tt.broadcast %132;
                    %135 : tensor<x32, x64, int> = arith.andi %133 %134;
                    tt.store %126 %109 %135;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void matmul_kernel(
            // Pointers to matrices
            Ptr a_ptr, Ptr b_ptr, Ptr c_ptr,
            // Matrix dimensions
            int M, int N, int K,
            // The stride variables represent how much to increase the ptr by when moving by 1
            // element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
            // by to get the element one row down (A has M rows).
            int stride_am, int stride_ak,
            int stride_bk, int stride_bn,
            int stride_cm, int stride_cn,
            // Meta-parameters
            @Constant int BLOCK_SIZE_M, @Constant int BLOCK_SIZE_N, @Constant int BLOCK_SIZE_K,
            @Constant int GROUP_SIZE_M,
            @Constant boolean ACTIVATION) {

        // """Kernel for computing the matmul C = A x B.
        // A has shape (M, K), B has shape (K, N) and C has shape (M, N)
        // """
        // -----------------------------------------------------------
        // Map program ids `pid` to the block of C it should compute.
        // This is done in a grouped ordering to promote L2 data reuse: within a group,
        // consecutive pids step through up to GROUP_SIZE_M tile rows (pid_m) for the same
        // pid_n before pid_n advances; the last group may be shorter (group_size_m).
        // See the `L2 Cache Optimizations` section of the Triton matrix-multiplication
        // tutorial this kernel appears to be ported from (no such section exists in this file).
        var pid = programId(0);
        var num_pid_m = cdiv(M, BLOCK_SIZE_M);
        var num_pid_n = cdiv(N, BLOCK_SIZE_N);
        var num_pid_in_group = GROUP_SIZE_M * num_pid_n;
        var group_id = pid / num_pid_in_group;
        var first_pid_m = group_id * GROUP_SIZE_M;
        var group_size_m = Math.min(num_pid_m - first_pid_m, GROUP_SIZE_M);
        var pid_m = first_pid_m + (pid % group_size_m);
        var pid_n = (pid % num_pid_in_group) / group_size_m;

        // ----------------------------------------------------------
        // Create pointers for the first blocks of A and B.
        // We will advance this pointer as we move in the K direction
        // and accumulate
        // `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
        // `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
        // See the `Pointer Arithmetics` section of the Triton matrix-multiplication tutorial
        // (not part of this file) for details.
        var offs_m = arange(0, BLOCK_SIZE_M);
        // Offsets are wrapped modulo M (resp. N) so that the A/B loads in the loop, which are
        // masked only along K, stay inside the matrices. Out-of-range rows/columns of the C tile
        // are discarded by the store mask at the end, which uses unwrapped offsets.
        var offs_am = mod(add(pid_m * BLOCK_SIZE_M, offs_m), M);
        var offs_n = arange(0, BLOCK_SIZE_N);
        var offs_bn = mod(add(pid_n * BLOCK_SIZE_N, offs_n), N);
        var offs_k = arange(0, BLOCK_SIZE_K);
        // Scalar and lower-rank operands are broadcast implicitly (tt.splat / tt.broadcast).
        var a_ptrs = add(a_ptr, add(
                mul(expand(offs_am, 1), stride_am),
                mul(expand(offs_k, 0), stride_ak)));
        var b_ptrs = add(b_ptr, add(
                        mul(expand(offs_k, 1), stride_bk),
                        mul(expand(offs_bn, 0), stride_bn)));

        // -----------------------------------------------------------
        // Iterate to compute a block of the C matrix.
        // We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
        // of fp32 values for higher accuracy.
        // `accumulator` will be converted back to fp16 after the loop.
        var accumulator = zeros(float.class, BLOCK_SIZE_M, BLOCK_SIZE_N);
        for (int k = 0; k < cdiv(K, BLOCK_SIZE_K); k++) {
            // Load the next block of A and B, generating a mask from the K dimension so that
            // lanes past K in the final (partial) block are not read.
            // NOTE(review): no fill value is passed to load(); the Triton tutorial uses
            // other=0.0 so masked lanes contribute 0 to the dot product — confirm what
            // Triton.load yields for masked-off lanes.
            var a = load(a_ptrs,
                    compare(expand(offs_k, 0), K - k * BLOCK_SIZE_K, LessThan));
            var b = load(b_ptrs,
                    compare(expand(offs_k, 1), K - k * BLOCK_SIZE_K, LessThan));
            // We accumulate along the K dimension.
            accumulator = add(accumulator, dot(a, b));
            // Advance the ptrs to the next K block.
            a_ptrs = add(a_ptrs, BLOCK_SIZE_K * stride_ak);
            b_ptrs = add(b_ptrs, BLOCK_SIZE_K * stride_bk);
        }

        // You can fuse arbitrary activation functions here
        // while the accumulator is still in FP32!
        // (ACTIVATION is currently unused because this branch is commented out.)
//        if (ACTIVATION) {
//            // ...
//        }
        // Narrow the fp32 accumulator to Float16 (arith.truncf in the expected model).
        var c = Triton.conv(Float16.class, accumulator);

        // -----------------------------------------------------------
        // Write back the block of the output matrix C with masks.
        // The C offsets are NOT wrapped modulo M/N, so they can be compared against M/N below.
        var offs_cm = add(pid_m * BLOCK_SIZE_M, offs_m);
        var offs_cn = add(pid_n * BLOCK_SIZE_N, offs_n);
        // The stride is the left operand here, which the expected model reflects as
        // `arith.muli <splat stride> <offsets>`; keep the operand order in sync with it.
        var c_ptrs = add(c_ptr, add(
                        mul(stride_cm, expand(offs_cm, 1)),
                        mul(stride_cn, expand(offs_cn, 0))));
        var c_mask = and(
                compare(expand(offs_cm, 1), M, LessThan),
                compare(expand(offs_cn, 0), N, LessThan));
        store(c_ptrs, c, c_mask);
    }
546 
547     @TritonTestExtension.Kernel("matmul_kernel")
548     @Test
549     public void test(TritonTestExtension.TritonTestData t) {
550         List<TypeElement> argTypes = List.of(
551                 new PtrType(Float16.FLOAT_16_TYPE),
552                 new PtrType(Float16.FLOAT_16_TYPE),
553                 new PtrType(Float16.FLOAT_16_TYPE),
554                 JavaType.INT, JavaType.INT, JavaType.INT,
555                 JavaType.INT, JavaType.INT,
556                 JavaType.INT, JavaType.INT,
557                 JavaType.INT, JavaType.INT,
558                 new ConstantType(JavaType.INT, 32), new ConstantType(JavaType.INT, 64), new ConstantType(JavaType.INT, 32),
559                 new ConstantType(JavaType.INT, 8),
560                 new ConstantType(JavaType.INT, false));
561 
562         t.test(argTypes);
563     }
564 
565 }
566 
567 /*
568 @triton.jit
569 def matmul_kernel(
570         # Pointers to matrices
571         a_ptr, b_ptr, c_ptr,
572         # Matrix dimensions
573         M, N, K,
574         # The stride variables represent how much to increase the ptr by when moving by 1
575         # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
576         # by to get the element one row down (A has M rows).
577         stride_am, stride_ak,  #
578         stride_bk, stride_bn,  #
579         stride_cm, stride_cn,
580         # Meta-parameters
581         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
582         GROUP_SIZE_M: tl.constexpr,  #
583         ACTIVATION: tl.constexpr  #
584 ):
585     """Kernel for computing the matmul C = A x B.
586     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
587     """
588     # -----------------------------------------------------------
589     # Map program ids `pid` to the block of C it should compute.
590     # This is done in a grouped ordering to promote L2 data reuse.
591     # See above `L2 Cache Optimizations` section for details.
592     pid = tl.program_id(axis=0)
593     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
594     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
595     num_pid_in_group = GROUP_SIZE_M * num_pid_n
596     group_id = pid // num_pid_in_group
597     first_pid_m = group_id * GROUP_SIZE_M
598     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
599     pid_m = first_pid_m + (pid % group_size_m)
600     pid_n = (pid % num_pid_in_group) // group_size_m
601 
602     # ----------------------------------------------------------
603     # Create pointers for the first blocks of A and B.
604     # We will advance this pointer as we move in the K direction
605     # and accumulate
606     # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
607     # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
608     # See above `Pointer Arithmetics` section for details
609     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
610     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
611     offs_k = tl.arange(0, BLOCK_SIZE_K)
612     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
613     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
614 
615     # -----------------------------------------------------------
616     # Iterate to compute a block of the C matrix.
617     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
618     # of fp32 values for higher accuracy.
619     # `accumulator` will be converted back to fp16 after the loop.
620     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
621     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
622         # Load the next block of A and B, generate a mask by checking the K dimension.
623         # If it is out of bounds, set it to 0.
624         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
625         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
626         # We accumulate along the K dimension.
627         accumulator += tl.dot(a, b)
628         # Advance the ptrs to the next K block.
629         a_ptrs += BLOCK_SIZE_K * stride_ak
630         b_ptrs += BLOCK_SIZE_K * stride_bk
631     # You can fuse arbitrary activation functions here
632     # while the accumulator is still in FP32!
633     if ACTIVATION == "leaky_relu":
634         accumulator = leaky_relu(accumulator)
635     c = accumulator.to(tl.float16)
636 
637     # -----------------------------------------------------------
638     # Write back the block of the output matrix C with masks.
639     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
640     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
641     c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
642     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
643     tl.store(c_ptrs, c, mask=c_mask)
644 
645 
646 # We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
647 @triton.jit
648 def leaky_relu(x):
649     x = x + 1
650     return tl.where(x >= 0, x, 0.01 * x)
651 */
652 
653 /*
654 
655  triton/python/triton/tools/compile.py \
656     --kernel-name matmul_kernel \
657     --signature "*fp16,*fp16,*fp16,i32,i32,i32,i32,i32,i32,i32,i32,i32,32,64,32,8,0" \
658     --grid=1024,1024,1024 \
659     03-matrix-multiplication.py
660 
661 BLOCK_SIZE_M = 32
662 BLOCK_SIZE_N = 64
663 BLOCK_SIZE_K = 32
664 GROUP_SIZE_M = 8
665 ACTIVATION = 0
666 
667 module {
668   tt.func public @matmul_kernel_01234567891011(
669             %arg0: !tt.ptr<f16, 1>, %arg1: !tt.ptr<f16, 1>, %arg2: !tt.ptr<f16, 1> ,
670             %arg3: i32, %arg4: i32, %arg5: i32 ,
671             %arg6: i32, %arg7: i32, %arg8: i32 ,
            %arg9: i32, %arg10: i32, %arg11: i32 ) attributes {noinline = false} {
673     %0 = tt.get_program_id x : i32
674     %1 = tt.call @cdiv__i32__1cconstexpr_32_(%arg3) : (i32) -> i32
675     %2 = tt.call @cdiv__i32__1cconstexpr_64_(%arg4) : (i32) -> i32
676     %c8_i32 = arith.constant 8 : i32
677     %3 = arith.muli %2, %c8_i32 : i32
678     %4 = arith.divsi %0, %3 : i32
679     %c8_i32_0 = arith.constant 8 : i32
680     %5 = arith.muli %4, %c8_i32_0 : i32
681     %6 = arith.subi %1, %5 : i32
682     %7 = tt.call @minimum__i32__1cconstexpr_8_(%6) : (i32) -> i32
683     %8 = arith.remsi %0, %7 : i32
684     %9 = arith.addi %5, %8 : i32
685     %10 = arith.remsi %0, %3 : i32
686     %11 = arith.divsi %10, %7 : i32
687     %c32_i32 = arith.constant 32 : i32
688     %12 = arith.muli %9, %c32_i32 : i32
689     %13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
690     %14 = tt.splat %12 : (i32) -> tensor<32xi32>
691     %15 = arith.addi %14, %13 : tensor<32xi32>
692     %16 = tt.splat %arg3 : (i32) -> tensor<32xi32>
693     %17 = arith.remsi %15, %16 : tensor<32xi32>
694     %c64_i32 = arith.constant 64 : i32
695     %18 = arith.muli %11, %c64_i32 : i32
696     %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
697     %20 = tt.splat %18 : (i32) -> tensor<64xi32>
698     %21 = arith.addi %20, %19 : tensor<64xi32>
699     %22 = tt.splat %arg4 : (i32) -> tensor<64xi32>
700     %23 = arith.remsi %21, %22 : tensor<64xi32>
701     %24 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
702     %25 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
703     %26 = tt.splat %arg6 : (i32) -> tensor<32x1xi32>
704     %27 = arith.muli %25, %26 : tensor<32x1xi32>
705     %28 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
706     %29 = tt.splat %arg7 : (i32) -> tensor<1x32xi32>
707     %30 = arith.muli %28, %29 : tensor<1x32xi32>
708     %31 = tt.broadcast %27 : (tensor<32x1xi32>) -> tensor<32x32xi32>
709     %32 = tt.broadcast %30 : (tensor<1x32xi32>) -> tensor<32x32xi32>
710     %33 = arith.addi %31, %32 : tensor<32x32xi32>
711     %34 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<32x32x!tt.ptr<f16, 1>>
712     %35 = tt.addptr %34, %33 : tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x32xi32>
713     %36 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
714     %37 = tt.splat %arg8 : (i32) -> tensor<32x1xi32>
715     %38 = arith.muli %36, %37 : tensor<32x1xi32>
716     %39 = tt.expand_dims %23 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
717     %40 = tt.splat %arg9 : (i32) -> tensor<1x64xi32>
718     %41 = arith.muli %39, %40 : tensor<1x64xi32>
719     %42 = tt.broadcast %38 : (tensor<32x1xi32>) -> tensor<32x64xi32>
720     %43 = tt.broadcast %41 : (tensor<1x64xi32>) -> tensor<32x64xi32>
721     %44 = arith.addi %42, %43 : tensor<32x64xi32>
722     %45 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<32x64x!tt.ptr<f16, 1>>
723     %46 = tt.addptr %45, %44 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
724     %47 = tt.call @"zeros____0cconstexpr_(constexpr_32_, constexpr_64_)__1cconstexpr_fp32_"() : () -> tensor<32x64xf32>
725     %48 = tt.call @cdiv__i32__1cconstexpr_32_(%arg5) : (i32) -> i32
726     %c0_i32 = arith.constant 0 : i32
727     %c1_i32 = arith.constant 1 : i32
728     %49 = arith.bitcast %c0_i32 : i32 to i32
729     %50 = arith.bitcast %48 : i32 to i32
730     %51 = arith.bitcast %c1_i32 : i32 to i32
731     %52 = llvm.mlir.undef : i32
732     %53:3 = scf.for %arg12 = %49 to %50 step %51 iter_args(%arg13 = %47, %arg14 = %35, %arg15 = %46) -> (tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>)  : i32 {
733       %83 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
734       %c32_i32_3 = arith.constant 32 : i32
735       %84 = arith.muli %arg12, %c32_i32_3 : i32
736       %85 = arith.subi %arg5, %84 : i32
737       %86 = tt.splat %85 : (i32) -> tensor<1x32xi32>
738       %87 = arith.cmpi slt, %83, %86 : tensor<1x32xi32>
739       %cst = arith.constant 0.000000e+00 : f32
740       %88 = tt.broadcast %87 : (tensor<1x32xi1>) -> tensor<32x32xi1>
741       %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
742       %89 = arith.truncf %cst_4 : tensor<32x32xf32> to tensor<32x32xf16>
743       %90 = tt.load %arg14, %88, %89 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16>
744       %91 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
745       %c32_i32_5 = arith.constant 32 : i32
746       %92 = arith.muli %arg12, %c32_i32_5 : i32
747       %93 = arith.subi %arg5, %92 : i32
748       %94 = tt.splat %93 : (i32) -> tensor<32x1xi32>
749       %95 = arith.cmpi slt, %91, %94 : tensor<32x1xi32>
750       %cst_6 = arith.constant 0.000000e+00 : f32
751       %96 = tt.broadcast %95 : (tensor<32x1xi1>) -> tensor<32x64xi1>
752       %cst_7 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
753       %97 = arith.truncf %cst_7 : tensor<32x64xf32> to tensor<32x64xf16>
754       %98 = tt.load %arg15, %96, %97 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
755       %cst_8 = arith.constant 0.000000e+00 : f32
756       %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
757       %99 = tt.dot %90, %98, %cst_9 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32>
758       %100 = arith.addf %arg13, %99 : tensor<32x64xf32>
759       %c32_i32_10 = arith.constant 32 : i32
760       %101 = arith.muli %arg7, %c32_i32_10 : i32
761       %102 = tt.splat %101 : (i32) -> tensor<32x32xi32>
762       %103 = tt.addptr %arg14, %102 : tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x32xi32>
763       %c32_i32_11 = arith.constant 32 : i32
764       %104 = arith.muli %arg8, %c32_i32_11 : i32
765       %105 = tt.splat %104 : (i32) -> tensor<32x64xi32>
766       %106 = tt.addptr %arg15, %105 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
767       scf.yield %100, %103, %106 : tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>
768     }
769     %54 = arith.truncf %53#0 : tensor<32x64xf32> to tensor<32x64xf16>
770     %c32_i32_1 = arith.constant 32 : i32
771     %55 = arith.muli %9, %c32_i32_1 : i32
772     %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
773     %57 = tt.splat %55 : (i32) -> tensor<32xi32>
774     %58 = arith.addi %57, %56 : tensor<32xi32>
775     %c64_i32_2 = arith.constant 64 : i32
776     %59 = arith.muli %11, %c64_i32_2 : i32
777     %60 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
778     %61 = tt.splat %59 : (i32) -> tensor<64xi32>
779     %62 = arith.addi %61, %60 : tensor<64xi32>
780     %63 = tt.expand_dims %58 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
781     %64 = tt.splat %arg10 : (i32) -> tensor<32x1xi32>
782     %65 = arith.muli %64, %63 : tensor<32x1xi32>
783     %66 = tt.splat %arg2 : (!tt.ptr<f16, 1>) -> tensor<32x1x!tt.ptr<f16, 1>>
784     %67 = tt.addptr %66, %65 : tensor<32x1x!tt.ptr<f16, 1>>, tensor<32x1xi32>
785     %68 = tt.expand_dims %62 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
786     %69 = tt.splat %arg11 : (i32) -> tensor<1x64xi32>
787     %70 = arith.muli %69, %68 : tensor<1x64xi32>
788     %71 = tt.broadcast %67 : (tensor<32x1x!tt.ptr<f16, 1>>) -> tensor<32x64x!tt.ptr<f16, 1>>
789     %72 = tt.broadcast %70 : (tensor<1x64xi32>) -> tensor<32x64xi32>
790     %73 = tt.addptr %71, %72 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
791     %74 = tt.expand_dims %58 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
792     %75 = tt.splat %arg3 : (i32) -> tensor<32x1xi32>
793     %76 = arith.cmpi slt, %74, %75 : tensor<32x1xi32>
794     %77 = tt.expand_dims %62 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
795     %78 = tt.splat %arg4 : (i32) -> tensor<1x64xi32>
796     %79 = arith.cmpi slt, %77, %78 : tensor<1x64xi32>
797     %80 = tt.broadcast %76 : (tensor<32x1xi1>) -> tensor<32x64xi1>
798     %81 = tt.broadcast %79 : (tensor<1x64xi1>) -> tensor<32x64xi1>
799     %82 = arith.andi %80, %81 : tensor<32x64xi1>
800     tt.store %73, %54, %82 {cache = 1 : i32, evict = 1 : i32} : tensor<32x64xf16>
801     tt.return
802   }
803   tt.func private @cdiv__i32__1cconstexpr_32_(%arg0: i32 ) -> i32 attributes {noinline = false} {
804     %c32_i32 = arith.constant 32 : i32
805     %0 = arith.addi %arg0, %c32_i32 : i32
806     %c1_i32 = arith.constant 1 : i32
807     %1 = arith.subi %0, %c1_i32 : i32
808     %c32_i32_0 = arith.constant 32 : i32
809     %2 = arith.divsi %1, %c32_i32_0 : i32
810     tt.return %2 : i32
811   }
812   tt.func private @cdiv__i32__1cconstexpr_64_(%arg0: i32 ) -> i32 attributes {noinline = false} {
813     %c64_i32 = arith.constant 64 : i32
814     %0 = arith.addi %arg0, %c64_i32 : i32
815     %c1_i32 = arith.constant 1 : i32
816     %1 = arith.subi %0, %c1_i32 : i32
817     %c64_i32_0 = arith.constant 64 : i32
818     %2 = arith.divsi %1, %c64_i32_0 : i32
819     tt.return %2 : i32
820   }
821   tt.func private @minimum__i32__1cconstexpr_8_(%arg0: i32 ) -> i32 attributes {noinline = false} {
822     %c8_i32 = arith.constant 8 : i32
823     %0 = arith.minsi %arg0, %c8_i32 : i32
824     tt.return %0 : i32
825   }
826   tt.func private @"zeros____0cconstexpr_(constexpr_32_, constexpr_64_)__1cconstexpr_fp32_"() -> tensor<32x64xf32> attributes {noinline = false} {
827     %cst = arith.constant 0.000000e+00 : f32
828     %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
829     tt.return %cst_0 : tensor<32x64xf32>
830   }
831 }
832  */