/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package oracle.code.triton;

import oracle.code.triton.TritonTestExtension.Kernel;
import oracle.code.triton.TritonTestExtension.TritonTestData;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import jdk.incubator.code.TypeElement;
import jdk.incubator.code.dialect.java.JavaType;
import jdk.incubator.code.CodeReflection;
import java.util.List;

@ExtendWith(TritonTestExtension.class)
public class TestSoftMax {

    @TritonCodeModel("""
            module ()java.type:"void" -> {
                tt.func @"max_float_float_float" (%0 : java.type:"float", %1 : java.type:"float")java.type:"float" -> {
                    %2 : java.type:"float" = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %4 : java.type:"float" = tt.reduce %3 @axis=0 (%5 : java.type:"float", %6 : java.type:"float")java.type:"float" -> {
                        %7 : java.type:"float" = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : java.type:"float", %9 : java.type:"float")java.type:"float" -> {
                    %10 : java.type:"float" = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %12 : java.type:"float" = tt.reduce %11 @axis=0 (%13 : java.type:"float", %14 : java.type:"float")java.type:"float" -> {
                        %15 : java.type:"float" = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @sym_name="softmax_kernel_ptr<java.type.primitive<float>>_ptr<java.type.primitive<float>>_int_int_int_64_void" (%16 : ptr<java.type:"float">, %17 : ptr<java.type:"float">, %18 : java.type:"int", %19 : java.type:"int", %20 : java.type:"int")java.type:"void" -> {
                    %21 : java.type:"int" = tt.get_program_id @0;
                    %22 : java.type:"int" = arith.muli %21 %18;
                    %23 : ptr<java.type:"float"> = tt.addptr %17 %22;
                    %24 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
                    %25 : tensor<x64, ptr<java.type:"float">> = tt.splat %23;
                    %26 : tensor<x64, ptr<java.type:"float">> = tt.addptr %25 %24;
                    %27 : tensor<x64, java.type:"int"> = tt.splat %20;
                    %28 : tensor<x64, java.type:"boolean"> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, java.type:"float"> = tt.load %26 %28;
                    %30 : java.type:"float" = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, java.type:"float"> = tt.splat %30;
                    %32 : tensor<x64, java.type:"float"> = arith.subf %29 %31;
                    %33 : tensor<x64, java.type:"float"> = math.exp %32;
                    %34 : java.type:"float" = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, java.type:"float"> = tt.splat %34;
                    %36 : tensor<x64, java.type:"float"> = arith.divf %33 %35;
                    %37 : java.type:"int" = arith.muli %21 %19;
                    %38 : ptr<java.type:"float"> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<java.type:"float">> = tt.splat %38;
                    %40 : tensor<x64, ptr<java.type:"float">> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void softmax_kernel(Ptr output_ptr,
                               Ptr input_ptr,
                               int input_row_stride,
                               int output_row_stride,
                               int n_cols,
                               @Constant int BLOCK_SIZE) {
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(Triton.broadcast(row_start_ptr, col_offsets.type()), col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        var mask = Triton.compare(col_offsets,
                Triton.broadcast(n_cols, col_offsets.type()),
                Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.broadcast(Triton.max(row, 0), row.type()));
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, Triton.broadcast(denominator, numerator.type()));
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(Triton.broadcast(output_row_start_ptr, col_offsets.type()), col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }

    @Kernel("softmax_kernel")
    @Test
    public void test(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                JavaType.INT,
                JavaType.INT,
                JavaType.INT,
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }

    @TritonCodeModel("""
            module ()java.type:"void" -> {
                tt.func @"max_float_float_float" (%0 : java.type:"float", %1 : java.type:"float")java.type:"float" -> {
                    %2 : java.type:"float" = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %4 : java.type:"float" = tt.reduce %3 @axis=0 (%5 : java.type:"float", %6 : java.type:"float")java.type:"float" -> {
                        %7 : java.type:"float" = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : java.type:"float", %9 : java.type:"float")java.type:"float" -> {
                    %10 : java.type:"float" = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %12 : java.type:"float" = tt.reduce %11 @axis=0 (%13 : java.type:"float", %14 : java.type:"float")java.type:"float" -> {
                        %15 : java.type:"float" = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @sym_name="softmax_kernel2_ptr<java.type.primitive<float>>_ptr<java.type.primitive<float>>_1_1_10_64_void" (%16 : ptr<java.type:"float">, %17 : ptr<java.type:"float">)java.type:"void" -> {
                    %18 : java.type:"int" = arith.constant @1;
                    %19 : java.type:"int" = arith.constant @1;
                    %20 : java.type:"int" = arith.constant @10;
                    %21 : java.type:"int" = tt.get_program_id @0;
                    %22 : java.type:"int" = arith.muli %21 %18;
                    %23 : ptr<java.type:"float"> = tt.addptr %17 %22;
                    %24 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
                    %25 : tensor<x64, ptr<java.type:"float">> = tt.splat %23;
                    %26 : tensor<x64, ptr<java.type:"float">> = tt.addptr %25 %24;
                    %27 : tensor<x64, java.type:"int"> = tt.splat %20;
                    %28 : tensor<x64, java.type:"boolean"> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, java.type:"float"> = tt.load %26 %28;
                    %30 : java.type:"float" = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, java.type:"float"> = tt.splat %30;
                    %32 : tensor<x64, java.type:"float"> = arith.subf %29 %31;
                    %33 : tensor<x64, java.type:"float"> = math.exp %32;
                    %34 : java.type:"float" = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, java.type:"float"> = tt.splat %34;
                    %36 : tensor<x64, java.type:"float"> = arith.divf %33 %35;
                    %37 : java.type:"int" = arith.muli %21 %19;
                    %38 : ptr<java.type:"float"> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<java.type:"float">> = tt.splat %38;
                    %40 : tensor<x64, ptr<java.type:"float">> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void softmax_kernel2(Ptr output_ptr,
                                Ptr input_ptr,
                                int input_row_stride,
                                int output_row_stride,
                                int n_cols,
                                @Constant int BLOCK_SIZE) {
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(row_start_ptr, col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        var mask = Triton.compare(col_offsets, n_cols, Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.max(row, 0));
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, denominator);
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(output_row_start_ptr, col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }

    @Kernel("softmax_kernel2")
    @Test
    public void test2(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 10),
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }
}

/*
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
*/

/*
input_row_stride = 1
output_row_stride = 1
n_cols=10
BLOCK_SIZE=64

module {
  tt.func public @softmax_kernel_01(%arg0: !tt.ptr<f32, 1> , %arg1: !tt.ptr<f32, 1> ) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %c1_i32 = arith.constant 1 : i32
    %1 = arith.muli %0, %c1_i32 : i32
    %2 = tt.addptr %arg1, %1 : !tt.ptr<f32, 1>, i32
    %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %4 = tt.splat %2 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %5 = tt.addptr %4, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %c10_i32 = arith.constant 10 : i32
    %cst = arith.constant dense<10> : tensor<64xi32>
    %6 = arith.cmpi slt, %3, %cst : tensor<64xi32>
    %cst_0 = arith.constant 0xFF800000 : f32
    %cst_1 = arith.constant dense<0xFF800000> : tensor<64xf32>
    %7 = tt.load %5, %6, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %8 = tt.call @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%7) : (tensor<64xf32>) -> f32
    %9 = tt.splat %8 : (f32) -> tensor<64xf32>
    %10 = arith.subf %7, %9 : tensor<64xf32>
    %11 = math.exp %10 : tensor<64xf32>
    %12 = tt.call @sum__fp32S64S__1cconstexpr_0_(%11) : (tensor<64xf32>) -> f32
    %13 = tt.splat %12 : (f32) -> tensor<64xf32>
    %14 = arith.divf %11, %13 : tensor<64xf32>
    %c1_i32_2 = arith.constant 1 : i32
    %15 = arith.muli %0, %c1_i32_2 : i32
    %16 = tt.addptr %arg0, %15 : !tt.ptr<f32, 1>, i32
    %17 = tt.splat %16 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %18 = tt.addptr %17, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %c10_i32_3 = arith.constant 10 : i32
    %cst_4 = arith.constant dense<10> : tensor<64xi32>
    %19 = arith.cmpi slt, %3, %cst_4 : tensor<64xi32>
    tt.store %18, %14, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<64xf32>
    tt.return
  }
  tt.func private @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32 , %arg2: f32 ):
      %1 = tt.call @maximum__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %1 : f32
    }) : (tensor<64xf32>) -> f32
    tt.return %0 : f32
  }
  tt.func private @maximum__fp32_fp32__(%arg0: f32 , %arg1: f32 ) -> f32 attributes {noinline = false} {
    %0 = arith.maximumf %arg0, %arg1 : f32
    tt.return %0 : f32
  }
  tt.func private @sum__fp32S64S__1cconstexpr_0_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32 , %arg2: f32 ):
      %1 = tt.call @_sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %1 : f32
    }) : (tensor<64xf32>) -> f32
    tt.return %0 : f32
  }
  tt.func private @_sum_combine__fp32_fp32__(%arg0: f32 , %arg1: f32 ) -> f32 attributes {noinline = false} {
    %0 = arith.addf %arg0, %arg1 : f32
    tt.return %0 : f32
  }
}
*/
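
/*
For reference only (not part of the test, and not an API of this project): a minimal
plain-Java sketch of the numerically stable, row-wise softmax that the kernels above
compute per row, i.e. subtract the row maximum, exponentiate, then normalize by the sum.
The method name softmaxRow is hypothetical.

    static float[] softmaxRow(float[] row) {
        // Subtract the row maximum before exponentiating for numerical stability
        float max = Float.NEGATIVE_INFINITY;
        for (float v : row) {
            max = Math.max(max, v);
        }
        float[] out = new float[row.length];
        float sum = 0f;
        for (int i = 0; i < row.length; i++) {
            out[i] = (float) Math.exp(row[i] - max);
            sum += out[i];
        }
        // Normalize so the row sums to 1
        for (int i = 0; i < row.length; i++) {
            out[i] /= sum;
        }
        return out;
    }
*/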