/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package oracle.code.triton;

import oracle.code.triton.TritonTestExtension.Kernel;
import oracle.code.triton.TritonTestExtension.TritonTestData;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import java.lang.reflect.code.TypeElement;
import java.lang.reflect.code.type.JavaType;
import java.lang.runtime.CodeReflection;
import java.util.List;

/**
 * Tests the translation of code-reflected Java softmax kernels into Triton
 * code models. Two variants of the kernel from the Triton fused-softmax
 * tutorial are checked: {@code softmax_kernel} with explicit broadcasts, and
 * {@code softmax_kernel2} relying on implicit conversion/broadcast overloads.
 * Each {@link TritonCodeModel} annotation holds the expected model text that
 * {@link TritonTestData#test} compares against for the given argument types.
 */
@ExtendWith(TritonTestExtension.class)
public class TestSoftMax {

    @TritonCodeModel("""
            module ()void -> {
                tt.func @"max_float_float_float" (%0 : float, %1 : float)float -> {
                    %2 : float = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, float>)float -> {
                    %4 : float = tt.reduce %3 @axis="0" (%5 : float, %6 : float)float -> {
                        %7 : float = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : float, %9 : float)float -> {
                    %10 : float = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, float>)float -> {
                    %12 : float = tt.reduce %11 @axis="0" (%13 : float, %14 : float)float -> {
                        %15 : float = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @"softmax_kernel_ptr<float>_ptr<float>_1_1_10_64_void" (%16 : ptr<float>, %17 : ptr<float>)void -> {
                    %18 : int = arith.constant @"1";
                    %19 : int = arith.constant @"1";
                    %20 : int = arith.constant @"10";
                    %21 : int = tt.get_program_id @"0";
                    %22 : int = arith.muli %21 %18;
                    %23 : ptr<float> = tt.addptr %17 %22;
                    %24 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %25 : tensor<x64, ptr<float>> = tt.splat %23;
                    %26 : tensor<x64, ptr<float>> = tt.addptr %25 %24;
                    %27 : tensor<x64, int> = tt.splat %20;
                    %28 : tensor<x64, int> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, float> = tt.load %26 %28;
                    %30 : float = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, float> = tt.splat %30;
                    %32 : tensor<x64, float> = arith.subf %29 %31;
                    %33 : tensor<x64, float> = math.exp %32;
                    %34 : float = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, float> = tt.splat %34;
                    %36 : tensor<x64, float> = arith.divf %33 %35;
                    %37 : int = arith.muli %21 %19;
                    %38 : ptr<float> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<float>> = tt.splat %38;
                    %40 : tensor<x64, ptr<float>> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void softmax_kernel(Ptr output_ptr,
                               Ptr input_ptr,
                               int input_row_stride,
                               int output_row_stride,
                               int n_cols,
                               @Constant int BLOCK_SIZE) {
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(Triton.broadcast(row_start_ptr, col_offsets.type()), col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        var mask = Triton.compare(col_offsets,
                Triton.broadcast(n_cols, col_offsets.type()),
                Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.broadcast(Triton.max(row, 0), row.type()));
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, Triton.broadcast(denominator, numerator.type()));
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(Triton.broadcast(output_row_start_ptr, col_offsets.type()), col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }

    @Kernel("softmax_kernel")
    @Test
    public void test(TritonTestData t) {
        // Argument types mirror the specialization encoded in the expected
        // model name: strides 1/1, n_cols 10, BLOCK_SIZE 64.
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 10),
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }

    @TritonCodeModel("""
            module ()void -> {
                tt.func @"max_float_float_float" (%0 : float, %1 : float)float -> {
                    %2 : float = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, float>)float -> {
                    %4 : float = tt.reduce %3 @axis="0" (%5 : float, %6 : float)float -> {
                        %7 : float = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : float, %9 : float)float -> {
                    %10 : float = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, float>)float -> {
                    %12 : float = tt.reduce %11 @axis="0" (%13 : float, %14 : float)float -> {
                        %15 : float = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @"softmax_kernel2_ptr<float>_ptr<float>_1_1_10_64_void" (%16 : ptr<float>, %17 : ptr<float>)void -> {
                    %18 : int = arith.constant @"1";
                    %19 : int = arith.constant @"1";
                    %20 : int = arith.constant @"10";
                    %21 : int = tt.get_program_id @"0";
                    %22 : int = arith.muli %21 %18;
                    %23 : ptr<float> = tt.addptr %17 %22;
                    %24 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %25 : tensor<x64, ptr<float>> = tt.splat %23;
                    %26 : tensor<x64, ptr<float>> = tt.addptr %25 %24;
                    %27 : tensor<x64, int> = tt.splat %20;
                    %28 : tensor<x64, int> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, float> = tt.load %26 %28;
                    %30 : float = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, float> = tt.splat %30;
                    %32 : tensor<x64, float> = arith.subf %29 %31;
                    %33 : tensor<x64, float> = math.exp %32;
                    %34 : float = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, float> = tt.splat %34;
                    %36 : tensor<x64, float> = arith.divf %33 %35;
                    %37 : int = arith.muli %21 %19;
                    %38 : ptr<float> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<float>> = tt.splat %38;
                    %40 : tensor<x64, ptr<float>> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void softmax_kernel2(Ptr output_ptr,
                                Ptr input_ptr,
                                int input_row_stride,
                                int output_row_stride,
                                int n_cols,
                                @Constant int BLOCK_SIZE) {
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(row_start_ptr, col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        var mask = Triton.compare(col_offsets, n_cols, Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.max(row, 0));
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, denominator);
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(output_row_start_ptr, col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }

    @Kernel("softmax_kernel2")
    @Test
    public void test2(TritonTestData t) {
        // Same specialization as test(): strides 1/1, n_cols 10, BLOCK_SIZE 64.
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 10),
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }
}

/*
Reference: the original Triton (Python) tutorial kernel these tests mirror.

def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
*/

/*
Reference: MLIR produced by the Python Triton compiler for the specialization

input_row_stride = 1
output_row_stride = 1
n_cols = 10
BLOCK_SIZE = 64

module {
  tt.func public @softmax_kernel_01(%arg0: !tt.ptr<f32, 1> , %arg1: !tt.ptr<f32, 1> ) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %c1_i32 = arith.constant 1 : i32
    %1 = arith.muli %0, %c1_i32 : i32
    %2 = tt.addptr %arg1, %1 : !tt.ptr<f32, 1>, i32
    %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %4 = tt.splat %2 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %5 = tt.addptr %4, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %c10_i32 = arith.constant 10 : i32
    %cst = arith.constant dense<10> : tensor<64xi32>
    %6 = arith.cmpi slt, %3, %cst : tensor<64xi32>
    %cst_0 = arith.constant 0xFF800000 : f32
    %cst_1 = arith.constant dense<0xFF800000> : tensor<64xf32>
    %7 = tt.load %5, %6, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %8 = tt.call @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%7) : (tensor<64xf32>) -> f32
    %9 = tt.splat %8 : (f32) -> tensor<64xf32>
    %10 = arith.subf %7, %9 : tensor<64xf32>
    %11 = math.exp %10 : tensor<64xf32>
    %12 = tt.call @sum__fp32S64S__1cconstexpr_0_(%11) : (tensor<64xf32>) -> f32
    %13 = tt.splat %12 : (f32) -> tensor<64xf32>
    %14 = arith.divf %11, %13 : tensor<64xf32>
    %c1_i32_2 = arith.constant 1 : i32
    %15 = arith.muli %0, %c1_i32_2 : i32
    %16 = tt.addptr %arg0, %15 : !tt.ptr<f32, 1>, i32
    %17 = tt.splat %16 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %18 = tt.addptr %17, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %c10_i32_3 = arith.constant 10 : i32
    %cst_4 = arith.constant dense<10> : tensor<64xi32>
    %19 = arith.cmpi slt, %3, %cst_4 : tensor<64xi32>
    tt.store %18, %14, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<64xf32>
    tt.return
  }
  tt.func private @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32 , %arg2: f32 ):
      %1 = tt.call @maximum__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %1 : f32
    }) : (tensor<64xf32>) -> f32
    tt.return %0 : f32
  }
  tt.func private @maximum__fp32_fp32__(%arg0: f32 , %arg1: f32 ) -> f32 attributes {noinline = false} {
    %0 = arith.maximumf %arg0, %arg1 : f32
    tt.return %0 : f32
  }
  tt.func private @sum__fp32S64S__1cconstexpr_0_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32 , %arg2: f32 ):
      %1 = tt.call @_sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %1 : f32
    }) : (tensor<64xf32>) -> f32
    tt.return %0 : f32
  }
  tt.func private @_sum_combine__fp32_fp32__(%arg0: f32 , %arg1: f32 ) -> f32 attributes {noinline = false} {
    %0 = arith.addf %arg0, %arg1 : f32
    tt.return %0 : f32
  }
}
*/