/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package oracle.code.triton;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import java.lang.reflect.code.TypeElement;
import java.lang.reflect.code.type.JavaType;
import java.lang.runtime.CodeReflection;
import java.util.List;

import static oracle.code.triton.Triton.*;
import static oracle.code.triton.Triton.CompareKind.*;
import static oracle.code.triton.Triton.compare;
import static oracle.code.triton.Triton.load;

@ExtendWith(TritonTestExtension.class)
public class TestMatrix {

    @TritonCodeModel("""
            module ()void -> {
                tt.func @"cdiv_int_32_int" (%0 : int)int -> {
                    %1 : int = arith.constant @"32";
                    %2 : int = arith.addi %0 %1;
                    %3 : int = arith.constant @"1";
                    %4 : int = arith.subi %2 %3;
                    %5 : int = arith.divsi %4 %1;
                    tt.return %5;
                };
                tt.func @"cdiv_int_64_int" (%6 : int)int -> {
                    %7 : int = arith.constant @"64";
                    %8 : int = arith.addi %6 %7;
                    %9 : int = arith.constant @"1";
                    %10 : int = arith.subi %8 %9;
                    %11 : int = arith.divsi %10 %7;
                    tt.return %11;
                };
                tt.func @"matmul_kernel_broadcast_ptr<float>_ptr<float>_ptr<float>_int_int_int_int_int_int_int_int_int_32_64_32_8_false_void" (%12 : ptr<float>, %13 : ptr<float>, %14 : ptr<float>, %15 : int, %16 : int, %17 : int, %18 : int, %19 : int, %20 : int, %21 : int, %22 : int, %23 : int)void -> {
                    %24 : int = arith.constant @"32";
                    %25 : int = arith.constant @"64";
                    %26 : int = arith.constant @"32";
                    %27 : int = arith.constant @"8";
                    %28 : int = tt.get_program_id @"0";
                    %29 : int = tt.call %15 @"cdiv_int_32_int";
                    %30 : int = tt.call %16 @"cdiv_int_64_int";
                    %31 : int = arith.muli %27 %30;
                    %32 : int = arith.divsi %28 %31;
                    %33 : int = arith.muli %32 %27;
                    %34 : int = arith.subi %29 %33;
                    %35 : int = arith.minsi %34 %27;
                    %36 : int = arith.remsi %28 %35;
                    %37 : int = arith.addi %33 %36;
                    %38 : int = arith.remsi %28 %31;
                    %39 : int = arith.divsi %38 %35;
                    %40 : tensor<x32, int> = tt.make_range @start="0" @end="32";
                    %41 : int = arith.muli %37 %24;
                    %42 : tensor<x32, int> = tt.splat %41;
                    %43 : tensor<x32, int> = arith.addi %42 %40;
                    %44 : tensor<x32, int> = tt.splat %15;
                    %45 : tensor<x32, int> = arith.remsi %43 %44;
                    %46 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %47 : int = arith.muli %39 %25;
                    %48 : tensor<x64, int> = tt.splat %47;
                    %49 : tensor<x64, int> = arith.addi %48 %46;
                    %50 : tensor<x64, int> = tt.splat %16;
                    %51 : tensor<x64, int> = arith.remsi %49 %50;
                    %52 : tensor<x32, int> = tt.make_range @start="0" @end="32";
                    %53 : tensor<x32, x1, int> = tt.expand_dims %45 @"1";
                    %54 : tensor<x32, x1, int> = tt.splat %18;
                    %55 : tensor<x32, x1, int> = arith.muli %53 %54;
                    %56 : tensor<x1, x32, int> = tt.expand_dims %52 @"0";
                    %57 : tensor<x1, x32, int> = tt.splat %19;
                    %58 : tensor<x1, x32, int> = arith.muli %56 %57;
                    %59 : tensor<x32, x32, ptr<float>> = tt.splat %12;
                    %60 : tensor<x32, x32, int> = tt.broadcast %55;
                    %61 : tensor<x32, x32, int> = tt.broadcast %58;
                    %62 : tensor<x32, x32, int> = arith.addi %60 %61;
                    %63 : tensor<x32, x32, ptr<float>> = tt.addptr %59 %62;
                    %64 : tensor<x32, x1, int> = tt.expand_dims %52 @"1";
                    %65 : tensor<x32, x1, int> = tt.splat %20;
                    %66 : tensor<x32, x1, int> = arith.muli %64 %65;
                    %67 : tensor<x1, x64, int> = tt.expand_dims %51 @"0";
                    %68 : tensor<x1, x64, int> = tt.splat %21;
                    %69 : tensor<x1, x64, int> = arith.muli %67 %68;
                    %70 : tensor<x32, x64, ptr<float>> = tt.splat %13;
                    %71 : tensor<x32, x64, int> = tt.broadcast %66;
                    %72 : tensor<x32, x64, int> = tt.broadcast %69;
                    %73 : tensor<x32, x64, int> = arith.addi %71 %72;
                    %74 : tensor<x32, x64, ptr<float>> = tt.addptr %70 %73;
                    %75 : tensor<x32, x64, float> = arith.constant @"0.0";
                    %76 : int = arith.constant @"0";
                    %77 : int = tt.call %17 @"cdiv_int_32_int";
                    %78 : int = arith.constant @"1";
                    %79 : Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<float>>, tensor<x32, x64, ptr<float>>> = scf.for %76 %77 %78 %75 %63 %74 (%80 : int, %81 : tensor<x32, x64, float>, %82 : tensor<x32, x32, ptr<float>>, %83 : tensor<x32, x64, ptr<float>>)Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<float>>, tensor<x32, x64, ptr<float>>> -> {
                        %84 : tensor<x1, x32, int> = tt.expand_dims %52 @"0";
                        %85 : int = arith.muli %80 %26;
                        %86 : int = arith.subi %17 %85;
                        %87 : tensor<x1, x32, int> = tt.splat %86;
                        %88 : tensor<x1, x32, int> = arith.cmpi %84 %87 @"slt";
                        %89 : tensor<x32, x32, int> = tt.broadcast %88;
                        %90 : tensor<x32, x32, float> = tt.load %82 %89;
                        %91 : tensor<x32, x1, int> = tt.expand_dims %52 @"1";
                        %92 : int = arith.muli %80 %26;
                        %93 : int = arith.subi %17 %92;
                        %94 : tensor<x32, x1, int> = tt.splat %93;
                        %95 : tensor<x32, x1, int> = arith.cmpi %91 %94 @"slt";
                        %96 : tensor<x32, x64, int> = tt.broadcast %95;
                        %97 : tensor<x32, x64, float> = tt.load %83 %96;
                        %98 : tensor<x32, x64, float> = tt.dot %90 %97;
                        %99 : tensor<x32, x64, float> = arith.addf %81 %98;
                        %100 : int = arith.muli %26 %19;
                        %101 : tensor<x32, x32, int> = tt.splat %100;
                        %102 : tensor<x32, x32, ptr<float>> = tt.addptr %82 %101;
                        %103 : int = arith.muli %26 %20;
                        %104 : tensor<x32, x64, int> = tt.splat %103;
                        %105 : tensor<x32, x64, ptr<float>> = tt.addptr %83 %104;
                        scf.yield %99 %102 %105;
                    };
                    %106 : tensor<x32, x64, float> = tuple.load %79 @"0";
                    %107 : tensor<x32, x32, ptr<float>> = tuple.load %79 @"1";
                    %108 : tensor<x32, x64, ptr<float>> = tuple.load %79 @"2";
                    %109 : int = arith.muli %37 %24;
                    %110 : tensor<x32, int> = tt.splat %109;
                    %111 : tensor<x32, int> = arith.addi %110 %40;
                    %112 : int = arith.muli %39 %25;
                    %113 : tensor<x64, int> = tt.splat %112;
                    %114 : tensor<x64, int> = arith.addi %113 %46;
                    %115 : tensor<x32, x1, int> = tt.expand_dims %111 @"1";
                    %116 : tensor<x32, x1, int> = tt.splat %22;
                    %117 : tensor<x32, x1, int> = arith.muli %115 %116;
                    %118 : tensor<x1, x64, int> = tt.expand_dims %114 @"0";
                    %119 : tensor<x1, x64, int> = tt.splat %23;
                    %120 : tensor<x1, x64, int> = arith.muli %118 %119;
                    %121 : tensor<x32, x64, ptr<float>> = tt.splat %14;
                    %122 : tensor<x32, x64, int> = tt.broadcast %117;
                    %123 : tensor<x32, x64, int> = tt.broadcast %120;
                    %124 : tensor<x32, x64, int> = arith.addi %122 %123;
                    %125 : tensor<x32, x64, ptr<float>> = tt.addptr %121 %124;
                    %126 : tensor<x32, x1, int> = tt.expand_dims %111 @"1";
                    %127 : tensor<x32, x1, int> = tt.splat %15;
                    %128 : tensor<x32, x1, int> = arith.cmpi %126 %127 @"slt";
                    %129 : tensor<x1, x64, int> = tt.expand_dims %114 @"0";
                    %130 : tensor<x1, x64, int> = tt.splat %16;
                    %131 : tensor<x1, x64, int> = arith.cmpi %129 %130 @"slt";
                    %132 : tensor<x32, x64, int> = tt.broadcast %128;
                    %133 : tensor<x32, x64, int> = tt.broadcast %131;
                    %134 : tensor<x32, x64, int> = arith.andi %132 %133;
                    tt.store %125 %106 %134;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void matmul_kernel_broadcast(
            // Pointers to matrices
            Ptr a_ptr, Ptr b_ptr, Ptr c_ptr,
            // Matrix dimensions
            int M, int N, int K,
            // The stride variables represent how much to increase the ptr by when moving by 1
            // element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
            // by to get the element one row down (A has M rows).
            int stride_am, int stride_ak,
            int stride_bk, int stride_bn,
            int stride_cm, int stride_cn,
            // Meta-parameters
            @Constant int BLOCK_SIZE_M, @Constant int BLOCK_SIZE_N, @Constant int BLOCK_SIZE_K,
            @Constant int GROUP_SIZE_M,
            @Constant boolean ACTIVATION) {

        // """Kernel for computing the matmul C = A x B.
        // A has shape (M, K), B has shape (K, N) and C has shape (M, N)
        // """
        // -----------------------------------------------------------
        // Map program ids `pid` to the block of C it should compute.
        // This is done in a grouped ordering to promote L2 data reuse.
        // See above `L2 Cache Optimizations` section for details.
        var pid = programId(0);

        var num_pid_m = cdiv(M, BLOCK_SIZE_M);
        var num_pid_n = cdiv(N, BLOCK_SIZE_N);
        var num_pid_in_group = GROUP_SIZE_M * num_pid_n;
        var group_id = pid / num_pid_in_group;
        var first_pid_m = group_id * GROUP_SIZE_M;
        var group_size_m = Math.min(num_pid_m - first_pid_m, GROUP_SIZE_M);
        var pid_m = first_pid_m + (pid % group_size_m);
        var pid_n = (pid % num_pid_in_group) / group_size_m;
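
        // Worked example (added for illustration; the concrete sizes are assumptions):
        // with GROUP_SIZE_M = 8 and, say, num_pid_m = 16 and num_pid_n = 4,
        // num_pid_in_group = 8 * 4 = 32. For pid = 35: group_id = 35 / 32 = 1,
        // first_pid_m = 8, group_size_m = min(16 - 8, 8) = 8, so
        // pid_m = 8 + (35 % 8) = 11 and pid_n = (35 % 32) / 8 = 0.
        // Consecutive pids in a group walk down a column of C blocks, so they
        // reuse the same block column of B, which is what promotes L2 reuse.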

        // ----------------------------------------------------------
        // Create pointers for the first blocks of A and B.
        // We will advance this pointer as we move in the K direction
        // and accumulate
        // `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
        // `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
        // See above `Pointer Arithmetics` section for details
        var offs_m = arange(0, BLOCK_SIZE_M);
        var offs_am = mod(
                add(broadcast(pid_m * BLOCK_SIZE_M, offs_m.type()), offs_m),
                broadcast(M, offs_m.type()));
        var offs_n = arange(0, BLOCK_SIZE_N);
        var offs_bn = mod(
                add(broadcast(pid_n * BLOCK_SIZE_N, offs_n.type()), offs_n),
                broadcast(N, offs_n.type()));
        var offs_k = arange(0, BLOCK_SIZE_K);

        var offs_am_e = expand(offs_am, 1);
        offs_am_e = mul(offs_am_e, broadcast(stride_am, offs_am_e.type()));
        var offs_k_e_0 = expand(offs_k, 0);
        offs_k_e_0 = mul(offs_k_e_0, broadcast(stride_ak, offs_k_e_0.type()));
        TensorType a_ptrs_t = joinShape(offs_am_e.type(), offs_k_e_0.type());
        var a_ptrs = add(broadcast(a_ptr, a_ptrs_t),
                add(broadcast(offs_am_e, a_ptrs_t), broadcast(offs_k_e_0, a_ptrs_t)));

        var offs_k_e_1 = expand(offs_k, 1);
        offs_k_e_1 = mul(offs_k_e_1, broadcast(stride_bk, offs_k_e_1.type()));
        var offs_bn_e = expand(offs_bn, 0);
        offs_bn_e = mul(offs_bn_e, broadcast(stride_bn, offs_bn_e.type()));
        TensorType b_ptrs_t = joinShape(offs_k_e_1.type(), offs_bn_e.type());
        var b_ptrs = add(broadcast(b_ptr, b_ptrs_t),
                add(broadcast(offs_k_e_1, b_ptrs_t), broadcast(offs_bn_e, b_ptrs_t)));

        // -----------------------------------------------------------
        // Iterate to compute a block of the C matrix.
        // We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
        // of fp32 values for higher accuracy.
        // `accumulator` will be converted back to fp16 after the loop.
        var accumulator = zeros(float.class, BLOCK_SIZE_M, BLOCK_SIZE_N);
        for (int k = 0; k < cdiv(K, BLOCK_SIZE_K); k++) {
            // Load the next block of A and B, generate a mask by checking the K dimension.
            // If it is out of bounds, set it to 0.
            var offs_k_m_0 = expand(offs_k, 0);
            offs_k_m_0 = compare(offs_k_m_0,
                    broadcast(K - k * BLOCK_SIZE_K, offs_k_m_0.type()),
                    LessThan);
            var a = load(a_ptrs, broadcast(offs_k_m_0, a_ptrs.type()));
            var offs_k_m_1 = expand(offs_k, 1);
            offs_k_m_1 = compare(offs_k_m_1,
                    broadcast(K - k * BLOCK_SIZE_K, offs_k_m_1.type()),
                    LessThan);
            var b = load(b_ptrs, broadcast(offs_k_m_1, b_ptrs.type()));
            // We accumulate along the K dimension.
            accumulator = add(accumulator, dot(a, b));
            // Advance the ptrs to the next K block.
            a_ptrs = add(a_ptrs, broadcast(BLOCK_SIZE_K * stride_ak, a_ptrs.type()));
            b_ptrs = add(b_ptrs, broadcast(BLOCK_SIZE_K * stride_bk, b_ptrs.type()));
        }
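
        // Illustrative note (added; the concrete K is an assumption): the mask trims
        // the final K tile. For K = 70 and BLOCK_SIZE_K = 32 the loop runs
        // cdiv(70, 32) = 3 times; on the last iteration K - k * BLOCK_SIZE_K = 6,
        // so only lanes with offs_k < 6 load real data and the masked-out lanes
        // contribute 0 to the dot product.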

        // You can fuse arbitrary activation functions here
        // while the accumulator is still in FP32!
        // if (ACTIVATION) {
        // // ...
        // }
        // c = Triton.to(activation, tl.float16)
        var c = accumulator;

        // -----------------------------------------------------------
        // Write back the block of the output matrix C with masks.
        var offs_cm = add(broadcast(pid_m * BLOCK_SIZE_M, offs_m.type()), offs_m);
        var offs_cn = add(broadcast(pid_n * BLOCK_SIZE_N, offs_n.type()), offs_n);

        var offs_cm_e = expand(offs_cm, 1);
        offs_cm_e = mul(offs_cm_e, broadcast(stride_cm, offs_cm_e.type()));
        var offs_cn_e = expand(offs_cn, 0);
        offs_cn_e = mul(offs_cn_e, broadcast(stride_cn, offs_cn_e.type()));
        TensorType c_ptrs_t = joinShape(offs_cm_e.type(), offs_cn_e.type());
        var c_ptrs = add(broadcast(c_ptr, c_ptrs_t),
                add(broadcast(offs_cm_e, c_ptrs_t), broadcast(offs_cn_e, c_ptrs_t)));

        offs_cm_e = expand(offs_cm, 1);
        var c_mask_l = compare(offs_cm_e, broadcast(M, offs_cm_e.type()), LessThan);
        offs_cn_e = expand(offs_cn, 0);
        var c_mask_r = compare(offs_cn_e, broadcast(N, offs_cn_e.type()), LessThan);
        var c_mask = and(broadcast(c_mask_l, c_ptrs_t), broadcast(c_mask_r, c_ptrs_t));

        store(c_ptrs, c, c_mask);
    }

    @TritonTestExtension.Kernel("matmul_kernel_broadcast")
    @Test
    public void testWithBroadcast(TritonTestExtension.TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                JavaType.INT, JavaType.INT, JavaType.INT,
                JavaType.INT, JavaType.INT,
                JavaType.INT, JavaType.INT,
                JavaType.INT, JavaType.INT,
                new ConstantType(JavaType.INT, 32), new ConstantType(JavaType.INT, 64), new ConstantType(JavaType.INT, 32),
                new ConstantType(JavaType.INT, 8),
                new ConstantType(JavaType.BOOLEAN, false));

        t.test(argTypes);
    }
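
    // Note (added): the @Constant arguments are specialized away rather than passed at
    // runtime, which is why the code model above names the function
    // "matmul_kernel_broadcast_..._32_64_32_8_false_void" and why 32, 64, 32 and 8
    // appear as arith.constant ops while the tt.func takes only the twelve runtime
    // pointer and int parameters.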

    @TritonCodeModel("""
            module ()void -> {
                tt.func @"cdiv_int_32_int" (%0 : int)int -> {
                    %1 : int = arith.constant @"32";
                    %2 : int = arith.addi %0 %1;
                    %3 : int = arith.constant @"1";
                    %4 : int = arith.subi %2 %3;
                    %5 : int = arith.divsi %4 %1;
                    tt.return %5;
                };
                tt.func @"cdiv_int_64_int" (%6 : int)int -> {
                    %7 : int = arith.constant @"64";
                    %8 : int = arith.addi %6 %7;
                    %9 : int = arith.constant @"1";
                    %10 : int = arith.subi %8 %9;
                    %11 : int = arith.divsi %10 %7;
                    tt.return %11;
                };
                tt.func @"matmul_kernel_ptr<oracle.code.triton.Float16>_ptr<oracle.code.triton.Float16>_ptr<oracle.code.triton.Float16>_int_int_int_int_int_int_int_int_int_32_64_32_8_false_void" (%12 : ptr<oracle.code.triton.Float16>, %13 : ptr<oracle.code.triton.Float16>, %14 : ptr<oracle.code.triton.Float16>, %15 : int, %16 : int, %17 : int, %18 : int, %19 : int, %20 : int, %21 : int, %22 : int, %23 : int)void -> {
                    %24 : int = arith.constant @"32";
                    %25 : int = arith.constant @"64";
                    %26 : int = arith.constant @"32";
                    %27 : int = arith.constant @"8";
                    %28 : int = tt.get_program_id @"0";
                    %29 : int = tt.call %15 @"cdiv_int_32_int";
                    %30 : int = tt.call %16 @"cdiv_int_64_int";
                    %31 : int = arith.muli %27 %30;
                    %32 : int = arith.divsi %28 %31;
                    %33 : int = arith.muli %32 %27;
                    %34 : int = arith.subi %29 %33;
                    %35 : int = arith.minsi %34 %27;
                    %36 : int = arith.remsi %28 %35;
                    %37 : int = arith.addi %33 %36;
                    %38 : int = arith.remsi %28 %31;
                    %39 : int = arith.divsi %38 %35;
                    %40 : tensor<x32, int> = tt.make_range @start="0" @end="32";
                    %41 : int = arith.muli %37 %24;
                    %42 : tensor<x32, int> = tt.splat %41;
                    %43 : tensor<x32, int> = arith.addi %42 %40;
                    %44 : tensor<x32, int> = tt.splat %15;
                    %45 : tensor<x32, int> = arith.remsi %43 %44;
                    %46 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %47 : int = arith.muli %39 %25;
                    %48 : tensor<x64, int> = tt.splat %47;
                    %49 : tensor<x64, int> = arith.addi %48 %46;
                    %50 : tensor<x64, int> = tt.splat %16;
                    %51 : tensor<x64, int> = arith.remsi %49 %50;
                    %52 : tensor<x32, int> = tt.make_range @start="0" @end="32";
                    %53 : tensor<x32, x1, int> = tt.expand_dims %45 @"1";
                    %54 : tensor<x32, x1, int> = tt.splat %18;
                    %55 : tensor<x32, x1, int> = arith.muli %53 %54;
                    %56 : tensor<x1, x32, int> = tt.expand_dims %52 @"0";
                    %57 : tensor<x1, x32, int> = tt.splat %19;
                    %58 : tensor<x1, x32, int> = arith.muli %56 %57;
                    %59 : tensor<x32, x32, int> = tt.broadcast %55;
                    %60 : tensor<x32, x32, int> = tt.broadcast %58;
                    %61 : tensor<x32, x32, int> = arith.addi %59 %60;
                    %62 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tt.splat %12;
                    %63 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tt.addptr %62 %61;
                    %64 : tensor<x32, x1, int> = tt.expand_dims %52 @"1";
                    %65 : tensor<x32, x1, int> = tt.splat %20;
                    %66 : tensor<x32, x1, int> = arith.muli %64 %65;
                    %67 : tensor<x1, x64, int> = tt.expand_dims %51 @"0";
                    %68 : tensor<x1, x64, int> = tt.splat %21;
                    %69 : tensor<x1, x64, int> = arith.muli %67 %68;
                    %70 : tensor<x32, x64, int> = tt.broadcast %66;
                    %71 : tensor<x32, x64, int> = tt.broadcast %69;
                    %72 : tensor<x32, x64, int> = arith.addi %70 %71;
                    %73 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.splat %13;
                    %74 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.addptr %73 %72;
                    %75 : tensor<x32, x64, float> = arith.constant @"0.0";
                    %76 : int = arith.constant @"0";
                    %77 : int = tt.call %17 @"cdiv_int_32_int";
                    %78 : int = arith.constant @"1";
                    %79 : Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<oracle.code.triton.Float16>>, tensor<x32, x64, ptr<oracle.code.triton.Float16>>> = scf.for %76 %77 %78 %75 %63 %74 (%80 : int, %81 : tensor<x32, x64, float>, %82 : tensor<x32, x32, ptr<oracle.code.triton.Float16>>, %83 : tensor<x32, x64, ptr<oracle.code.triton.Float16>>)Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<oracle.code.triton.Float16>>, tensor<x32, x64, ptr<oracle.code.triton.Float16>>> -> {
                        %84 : tensor<x1, x32, int> = tt.expand_dims %52 @"0";
                        %85 : int = arith.muli %80 %26;
                        %86 : int = arith.subi %17 %85;
                        %87 : tensor<x1, x32, int> = tt.splat %86;
                        %88 : tensor<x1, x32, int> = arith.cmpi %84 %87 @"slt";
                        %89 : tensor<x32, x32, int> = tt.broadcast %88;
                        %90 : tensor<x32, x32, oracle.code.triton.Float16> = tt.load %82 %89;
                        %91 : tensor<x32, x1, int> = tt.expand_dims %52 @"1";
                        %92 : int = arith.muli %80 %26;
                        %93 : int = arith.subi %17 %92;
                        %94 : tensor<x32, x1, int> = tt.splat %93;
                        %95 : tensor<x32, x1, int> = arith.cmpi %91 %94 @"slt";
                        %96 : tensor<x32, x64, int> = tt.broadcast %95;
                        %97 : tensor<x32, x64, oracle.code.triton.Float16> = tt.load %83 %96;
                        %98 : tensor<x32, x64, float> = tt.dot %90 %97;
                        %99 : tensor<x32, x64, float> = arith.addf %81 %98;
                        %100 : int = arith.muli %26 %19;
                        %101 : tensor<x32, x32, int> = tt.splat %100;
                        %102 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tt.addptr %82 %101;
                        %103 : int = arith.muli %26 %20;
                        %104 : tensor<x32, x64, int> = tt.splat %103;
                        %105 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.addptr %83 %104;
                        scf.yield %99 %102 %105;
                    };
                    %106 : tensor<x32, x64, float> = tuple.load %79 @"0";
                    %107 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tuple.load %79 @"1";
                    %108 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tuple.load %79 @"2";
                    %109 : tensor<x32, x64, oracle.code.triton.Float16> = arith.truncf %106;
                    %110 : int = arith.muli %37 %24;
                    %111 : tensor<x32, int> = tt.splat %110;
                    %112 : tensor<x32, int> = arith.addi %111 %40;
                    %113 : int = arith.muli %39 %25;
                    %114 : tensor<x64, int> = tt.splat %113;
                    %115 : tensor<x64, int> = arith.addi %114 %46;
                    %116 : tensor<x32, x1, int> = tt.expand_dims %112 @"1";
                    %117 : tensor<x32, x1, int> = tt.splat %22;
                    %118 : tensor<x32, x1, int> = arith.muli %117 %116;
                    %119 : tensor<x1, x64, int> = tt.expand_dims %115 @"0";
                    %120 : tensor<x1, x64, int> = tt.splat %23;
                    %121 : tensor<x1, x64, int> = arith.muli %120 %119;
                    %122 : tensor<x32, x64, int> = tt.broadcast %118;
                    %123 : tensor<x32, x64, int> = tt.broadcast %121;
                    %124 : tensor<x32, x64, int> = arith.addi %122 %123;
                    %125 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.splat %14;
                    %126 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.addptr %125 %124;
                    %127 : tensor<x32, x1, int> = tt.expand_dims %112 @"1";
                    %128 : tensor<x32, x1, int> = tt.splat %15;
                    %129 : tensor<x32, x1, int> = arith.cmpi %127 %128 @"slt";
                    %130 : tensor<x1, x64, int> = tt.expand_dims %115 @"0";
                    %131 : tensor<x1, x64, int> = tt.splat %16;
                    %132 : tensor<x1, x64, int> = arith.cmpi %130 %131 @"slt";
                    %133 : tensor<x32, x64, int> = tt.broadcast %129;
                    %134 : tensor<x32, x64, int> = tt.broadcast %132;
                    %135 : tensor<x32, x64, int> = arith.andi %133 %134;
                    tt.store %126 %109 %135;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void matmul_kernel(
            // Pointers to matrices
            Ptr a_ptr, Ptr b_ptr, Ptr c_ptr,
            // Matrix dimensions
            int M, int N, int K,
            // The stride variables represent how much to increase the ptr by when moving by 1
            // element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
            // by to get the element one row down (A has M rows).
            int stride_am, int stride_ak,
            int stride_bk, int stride_bn,
            int stride_cm, int stride_cn,
            // Meta-parameters
            @Constant int BLOCK_SIZE_M, @Constant int BLOCK_SIZE_N, @Constant int BLOCK_SIZE_K,
            @Constant int GROUP_SIZE_M,
            @Constant boolean ACTIVATION) {

        // """Kernel for computing the matmul C = A x B.
        // A has shape (M, K), B has shape (K, N) and C has shape (M, N)
        // """
        // -----------------------------------------------------------
        // Map program ids `pid` to the block of C it should compute.
        // This is done in a grouped ordering to promote L2 data reuse.
        // See above `L2 Cache Optimizations` section for details.
        var pid = programId(0);
        var num_pid_m = cdiv(M, BLOCK_SIZE_M);
        var num_pid_n = cdiv(N, BLOCK_SIZE_N);
        var num_pid_in_group = GROUP_SIZE_M * num_pid_n;
        var group_id = pid / num_pid_in_group;
        var first_pid_m = group_id * GROUP_SIZE_M;
        var group_size_m = Math.min(num_pid_m - first_pid_m, GROUP_SIZE_M);
        var pid_m = first_pid_m + (pid % group_size_m);
        var pid_n = (pid % num_pid_in_group) / group_size_m;
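
        // Note (added): unlike matmul_kernel_broadcast above, this variant passes
        // scalars straight to add/mul/mod/compare and lets the Triton API overloads
        // splat them to the tensor shape; both variants lower to equivalent
        // tt.splat/tt.broadcast ops, as the two code models show.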

        // ----------------------------------------------------------
        // Create pointers for the first blocks of A and B.
        // We will advance this pointer as we move in the K direction
        // and accumulate
        // `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
        // `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
        // See above `Pointer Arithmetics` section for details
        var offs_m = arange(0, BLOCK_SIZE_M);
        var offs_am = mod(add(pid_m * BLOCK_SIZE_M, offs_m), M);
        var offs_n = arange(0, BLOCK_SIZE_N);
        var offs_bn = mod(add(pid_n * BLOCK_SIZE_N, offs_n), N);
        var offs_k = arange(0, BLOCK_SIZE_K);
        var a_ptrs = add(a_ptr, add(
                mul(expand(offs_am, 1), stride_am),
                mul(expand(offs_k, 0), stride_ak)));
        var b_ptrs = add(b_ptr, add(
                mul(expand(offs_k, 1), stride_bk),
                mul(expand(offs_bn, 0), stride_bn)));

        // -----------------------------------------------------------
        // Iterate to compute a block of the C matrix.
        // We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
        // of fp32 values for higher accuracy.
        // `accumulator` will be converted back to fp16 after the loop.
        var accumulator = zeros(float.class, BLOCK_SIZE_M, BLOCK_SIZE_N);
        for (int k = 0; k < cdiv(K, BLOCK_SIZE_K); k++) {
            // Load the next block of A and B, generate a mask by checking the K dimension.
            // If it is out of bounds, set it to 0.
            var a = load(a_ptrs,
                    compare(expand(offs_k, 0), K - k * BLOCK_SIZE_K, LessThan));
            var b = load(b_ptrs,
                    compare(expand(offs_k, 1), K - k * BLOCK_SIZE_K, LessThan));
            // We accumulate along the K dimension.
            accumulator = add(accumulator, dot(a, b));
            // Advance the ptrs to the next K block.
            a_ptrs = add(a_ptrs, BLOCK_SIZE_K * stride_ak);
            b_ptrs = add(b_ptrs, BLOCK_SIZE_K * stride_bk);
        }

        // You can fuse arbitrary activation functions here
        // while the accumulator is still in FP32!
        // if (ACTIVATION) {
        // // ...
        // }
        var c = Triton.conv(Float16.class, accumulator);

        // -----------------------------------------------------------
        // Write back the block of the output matrix C with masks.
        var offs_cm = add(pid_m * BLOCK_SIZE_M, offs_m);
        var offs_cn = add(pid_n * BLOCK_SIZE_N, offs_n);
        var c_ptrs = add(c_ptr, add(
                mul(stride_cm, expand(offs_cm, 1)),
                mul(stride_cn, expand(offs_cn, 0))));
        var c_mask = and(
                compare(expand(offs_cm, 1), M, LessThan),
                compare(expand(offs_cn, 0), N, LessThan));
        store(c_ptrs, c, c_mask);
    }
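
    // A minimal host-side sketch (added for illustration; this helper is hypothetical
    // and not used by the test harness): the number of programs the kernel expects on
    // axis 0, mirroring the kernel's own cdiv-based decomposition of C into
    // BLOCK_SIZE_M x BLOCK_SIZE_N blocks.
    static int gridSize0(int M, int N, int blockSizeM, int blockSizeN) {
        int numPidM = (M + blockSizeM - 1) / blockSizeM; // cdiv(M, BLOCK_SIZE_M)
        int numPidN = (N + blockSizeN - 1) / blockSizeN; // cdiv(N, BLOCK_SIZE_N)
        return numPidM * numPidN;                        // one program per C block
    }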

    @TritonTestExtension.Kernel("matmul_kernel")
    @Test
    public void test(TritonTestExtension.TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(Float16.FLOAT_16_TYPE),
                new PtrType(Float16.FLOAT_16_TYPE),
                new PtrType(Float16.FLOAT_16_TYPE),
                JavaType.INT, JavaType.INT, JavaType.INT,
                JavaType.INT, JavaType.INT,
                JavaType.INT, JavaType.INT,
                JavaType.INT, JavaType.INT,
                new ConstantType(JavaType.INT, 32), new ConstantType(JavaType.INT, 64), new ConstantType(JavaType.INT, 32),
                new ConstantType(JavaType.INT, 8),
                new ConstantType(JavaType.BOOLEAN, false));

        t.test(argTypes);
    }

}

/*
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr,  #
        ACTIVATION: tl.constexpr  #
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetics` section for details
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
@triton.jit
def leaky_relu(x):
    x = x + 1
    return tl.where(x >= 0, x, 0.01 * x)
*/

/*

triton/python/triton/tools/compile.py \
    --kernel-name matmul_kernel \
    --signature "*fp16,*fp16,*fp16,i32,i32,i32,i32,i32,i32,i32,i32,i32,32,64,32,8,0" \
    --grid=1024,1024,1024 \
    03-matrix-multiplication.py

BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = 32
GROUP_SIZE_M = 8
ACTIVATION = 0

module {
  tt.func public @matmul_kernel_01234567891011(
        %arg0: !tt.ptr<f16, 1>, %arg1: !tt.ptr<f16, 1>, %arg2: !tt.ptr<f16, 1>,
        %arg3: i32, %arg4: i32, %arg5: i32,
        %arg6: i32, %arg7: i32, %arg8: i32,
        %arg9: i32, %arg10: i32, %arg11: i32) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %1 = tt.call @cdiv__i32__1cconstexpr_32_(%arg3) : (i32) -> i32
    %2 = tt.call @cdiv__i32__1cconstexpr_64_(%arg4) : (i32) -> i32
    %c8_i32 = arith.constant 8 : i32
    %3 = arith.muli %2, %c8_i32 : i32
    %4 = arith.divsi %0, %3 : i32
    %c8_i32_0 = arith.constant 8 : i32
    %5 = arith.muli %4, %c8_i32_0 : i32
    %6 = arith.subi %1, %5 : i32
    %7 = tt.call @minimum__i32__1cconstexpr_8_(%6) : (i32) -> i32
    %8 = arith.remsi %0, %7 : i32
    %9 = arith.addi %5, %8 : i32
    %10 = arith.remsi %0, %3 : i32
    %11 = arith.divsi %10, %7 : i32
    %c32_i32 = arith.constant 32 : i32
    %12 = arith.muli %9, %c32_i32 : i32
    %13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %14 = tt.splat %12 : (i32) -> tensor<32xi32>
    %15 = arith.addi %14, %13 : tensor<32xi32>
    %16 = tt.splat %arg3 : (i32) -> tensor<32xi32>
    %17 = arith.remsi %15, %16 : tensor<32xi32>
    %c64_i32 = arith.constant 64 : i32
    %18 = arith.muli %11, %c64_i32 : i32
    %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %20 = tt.splat %18 : (i32) -> tensor<64xi32>
    %21 = arith.addi %20, %19 : tensor<64xi32>
    %22 = tt.splat %arg4 : (i32) -> tensor<64xi32>
    %23 = arith.remsi %21, %22 : tensor<64xi32>
    %24 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %25 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
    %26 = tt.splat %arg6 : (i32) -> tensor<32x1xi32>
    %27 = arith.muli %25, %26 : tensor<32x1xi32>
    %28 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
    %29 = tt.splat %arg7 : (i32) -> tensor<1x32xi32>
    %30 = arith.muli %28, %29 : tensor<1x32xi32>
    %31 = tt.broadcast %27 : (tensor<32x1xi32>) -> tensor<32x32xi32>
    %32 = tt.broadcast %30 : (tensor<1x32xi32>) -> tensor<32x32xi32>
    %33 = arith.addi %31, %32 : tensor<32x32xi32>
    %34 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<32x32x!tt.ptr<f16, 1>>
    %35 = tt.addptr %34, %33 : tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x32xi32>
    %36 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
    %37 = tt.splat %arg8 : (i32) -> tensor<32x1xi32>
    %38 = arith.muli %36, %37 : tensor<32x1xi32>
    %39 = tt.expand_dims %23 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
    %40 = tt.splat %arg9 : (i32) -> tensor<1x64xi32>
    %41 = arith.muli %39, %40 : tensor<1x64xi32>
    %42 = tt.broadcast %38 : (tensor<32x1xi32>) -> tensor<32x64xi32>
    %43 = tt.broadcast %41 : (tensor<1x64xi32>) -> tensor<32x64xi32>
    %44 = arith.addi %42, %43 : tensor<32x64xi32>
    %45 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<32x64x!tt.ptr<f16, 1>>
    %46 = tt.addptr %45, %44 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
    %47 = tt.call @"zeros____0cconstexpr_(constexpr_32_, constexpr_64_)__1cconstexpr_fp32_"() : () -> tensor<32x64xf32>
    %48 = tt.call @cdiv__i32__1cconstexpr_32_(%arg5) : (i32) -> i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %49 = arith.bitcast %c0_i32 : i32 to i32
    %50 = arith.bitcast %48 : i32 to i32
    %51 = arith.bitcast %c1_i32 : i32 to i32
    %52 = llvm.mlir.undef : i32
    %53:3 = scf.for %arg12 = %49 to %50 step %51 iter_args(%arg13 = %47, %arg14 = %35, %arg15 = %46) -> (tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>) : i32 {
      %83 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
      %c32_i32_3 = arith.constant 32 : i32
      %84 = arith.muli %arg12, %c32_i32_3 : i32
      %85 = arith.subi %arg5, %84 : i32
      %86 = tt.splat %85 : (i32) -> tensor<1x32xi32>
      %87 = arith.cmpi slt, %83, %86 : tensor<1x32xi32>
      %cst = arith.constant 0.000000e+00 : f32
      %88 = tt.broadcast %87 : (tensor<1x32xi1>) -> tensor<32x32xi1>
      %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
      %89 = arith.truncf %cst_4 : tensor<32x32xf32> to tensor<32x32xf16>
      %90 = tt.load %arg14, %88, %89 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16>
      %91 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
      %c32_i32_5 = arith.constant 32 : i32
      %92 = arith.muli %arg12, %c32_i32_5 : i32
      %93 = arith.subi %arg5, %92 : i32
      %94 = tt.splat %93 : (i32) -> tensor<32x1xi32>
      %95 = arith.cmpi slt, %91, %94 : tensor<32x1xi32>
      %cst_6 = arith.constant 0.000000e+00 : f32
      %96 = tt.broadcast %95 : (tensor<32x1xi1>) -> tensor<32x64xi1>
      %cst_7 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
      %97 = arith.truncf %cst_7 : tensor<32x64xf32> to tensor<32x64xf16>
      %98 = tt.load %arg15, %96, %97 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
      %cst_8 = arith.constant 0.000000e+00 : f32
      %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
      %99 = tt.dot %90, %98, %cst_9 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32>
      %100 = arith.addf %arg13, %99 : tensor<32x64xf32>
      %c32_i32_10 = arith.constant 32 : i32
      %101 = arith.muli %arg7, %c32_i32_10 : i32
      %102 = tt.splat %101 : (i32) -> tensor<32x32xi32>
      %103 = tt.addptr %arg14, %102 : tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x32xi32>
      %c32_i32_11 = arith.constant 32 : i32
      %104 = arith.muli %arg8, %c32_i32_11 : i32
      %105 = tt.splat %104 : (i32) -> tensor<32x64xi32>
      %106 = tt.addptr %arg15, %105 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
      scf.yield %100, %103, %106 : tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>
    }
    %54 = arith.truncf %53#0 : tensor<32x64xf32> to tensor<32x64xf16>
    %c32_i32_1 = arith.constant 32 : i32
    %55 = arith.muli %9, %c32_i32_1 : i32
    %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %57 = tt.splat %55 : (i32) -> tensor<32xi32>
    %58 = arith.addi %57, %56 : tensor<32xi32>
    %c64_i32_2 = arith.constant 64 : i32
    %59 = arith.muli %11, %c64_i32_2 : i32
    %60 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %61 = tt.splat %59 : (i32) -> tensor<64xi32>
    %62 = arith.addi %61, %60 : tensor<64xi32>
    %63 = tt.expand_dims %58 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
    %64 = tt.splat %arg10 : (i32) -> tensor<32x1xi32>
    %65 = arith.muli %64, %63 : tensor<32x1xi32>
    %66 = tt.splat %arg2 : (!tt.ptr<f16, 1>) -> tensor<32x1x!tt.ptr<f16, 1>>
    %67 = tt.addptr %66, %65 : tensor<32x1x!tt.ptr<f16, 1>>, tensor<32x1xi32>
    %68 = tt.expand_dims %62 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
    %69 = tt.splat %arg11 : (i32) -> tensor<1x64xi32>
    %70 = arith.muli %69, %68 : tensor<1x64xi32>
    %71 = tt.broadcast %67 : (tensor<32x1x!tt.ptr<f16, 1>>) -> tensor<32x64x!tt.ptr<f16, 1>>
    %72 = tt.broadcast %70 : (tensor<1x64xi32>) -> tensor<32x64xi32>
    %73 = tt.addptr %71, %72 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
    %74 = tt.expand_dims %58 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
    %75 = tt.splat %arg3 : (i32) -> tensor<32x1xi32>
    %76 = arith.cmpi slt, %74, %75 : tensor<32x1xi32>
    %77 = tt.expand_dims %62 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
    %78 = tt.splat %arg4 : (i32) -> tensor<1x64xi32>
    %79 = arith.cmpi slt, %77, %78 : tensor<1x64xi32>
    %80 = tt.broadcast %76 : (tensor<32x1xi1>) -> tensor<32x64xi1>
    %81 = tt.broadcast %79 : (tensor<1x64xi1>) -> tensor<32x64xi1>
    %82 = arith.andi %80, %81 : tensor<32x64xi1>
    tt.store %73, %54, %82 {cache = 1 : i32, evict = 1 : i32} : tensor<32x64xf16>
    tt.return
  }
  tt.func private @cdiv__i32__1cconstexpr_32_(%arg0: i32) -> i32 attributes {noinline = false} {
    %c32_i32 = arith.constant 32 : i32
    %0 = arith.addi %arg0, %c32_i32 : i32
    %c1_i32 = arith.constant 1 : i32
    %1 = arith.subi %0, %c1_i32 : i32
    %c32_i32_0 = arith.constant 32 : i32
    %2 = arith.divsi %1, %c32_i32_0 : i32
    tt.return %2 : i32
  }
  tt.func private @cdiv__i32__1cconstexpr_64_(%arg0: i32) -> i32 attributes {noinline = false} {
    %c64_i32 = arith.constant 64 : i32
    %0 = arith.addi %arg0, %c64_i32 : i32
    %c1_i32 = arith.constant 1 : i32
    %1 = arith.subi %0, %c1_i32 : i32
    %c64_i32_0 = arith.constant 64 : i32
    %2 = arith.divsi %1, %c64_i32_0 : i32
    tt.return %2 : i32
  }
  tt.func private @minimum__i32__1cconstexpr_8_(%arg0: i32) -> i32 attributes {noinline = false} {
    %c8_i32 = arith.constant 8 : i32
    %0 = arith.minsi %arg0, %c8_i32 : i32
    tt.return %0 : i32
  }
  tt.func private @"zeros____0cconstexpr_(constexpr_32_, constexpr_64_)__1cconstexpr_fp32_"() -> tensor<32x64xf32> attributes {noinline = false} {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
    tt.return %cst_0 : tensor<32x64xf32>
  }
}
*/