1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.triton;
 25 
 26 import oracle.code.triton.TritonTestExtension.Kernel;
 27 import oracle.code.triton.TritonTestExtension.TritonTestData;
 28 import org.junit.jupiter.api.Test;
 29 import org.junit.jupiter.api.extension.ExtendWith;
 30 
 31 import java.lang.reflect.code.TypeElement;
 32 import java.lang.reflect.code.type.JavaType;
 33 import java.lang.runtime.CodeReflection;
 34 import java.util.List;
 35 
/**
 * Tests that a Java {@code @CodeReflection} port of the Triton tutorial
 * "fused softmax" kernel is lowered to the expected Triton code model.
 * <p>
 * Each kernel method carries a {@code @TritonCodeModel} annotation whose text
 * block is the expected lowered model; the {@link TritonTestExtension} is
 * presumed to build the actual model from the reflected method (specialized
 * with the argument types supplied by the test) and compare it against that
 * text — TODO confirm against the extension's implementation.
 * <p>
 * Two source-level variants are covered, which must lower to the same shape
 * of model: {@code softmax_kernel} uses explicit {@code Triton.broadcast}
 * calls, while {@code softmax_kernel2} relies on implicitly broadcasting
 * overloads of {@code add}/{@code compare}/{@code sub}/{@code div}.
 * The original Python kernel and a reference MLIR dump appear in comments at
 * the end of this file.
 */
@ExtendWith(TritonTestExtension.class)
public class TestSoftMax {

    @TritonCodeModel("""
            module ()void -> {
                tt.func @"max_float_float_float" (%0 : float, %1 : float)float -> {
                    %2 : float = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, float>)float -> {
                    %4 : float = tt.reduce %3 @axis="0" (%5 : float, %6 : float)float -> {
                        %7 : float = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : float, %9 : float)float -> {
                    %10 : float = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, float>)float -> {
                    %12 : float = tt.reduce %11 @axis="0" (%13 : float, %14 : float)float -> {
                        %15 : float = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @"softmax_kernel_ptr<float>_ptr<float>_1_1_10_64_void" (%16 : ptr<float>, %17 : ptr<float>)void -> {
                    %18 : int = arith.constant @"1";
                    %19 : int = arith.constant @"1";
                    %20 : int = arith.constant @"10";
                    %21 : int = tt.get_program_id @"0";
                    %22 : int = arith.muli %21 %18;
                    %23 : ptr<float> = tt.addptr %17 %22;
                    %24 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %25 : tensor<x64, ptr<float>> = tt.splat %23;
                    %26 : tensor<x64, ptr<float>> = tt.addptr %25 %24;
                    %27 : tensor<x64, int> = tt.splat %20;
                    %28 : tensor<x64, int> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, float> = tt.load %26 %28;
                    %30 : float = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, float> = tt.splat %30;
                    %32 : tensor<x64, float> = arith.subf %29 %31;
                    %33 : tensor<x64, float> = math.exp %32;
                    %34 : float = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, float> = tt.splat %34;
                    %36 : tensor<x64, float> = arith.divf %33 %35;
                    %37 : int = arith.muli %21 %19;
                    %38 : ptr<float> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<float>> = tt.splat %38;
                    %40 : tensor<x64, ptr<float>> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    // Row-wise softmax over a matrix: one program instance handles one row.
    // Broadcasts (splat) and the mask comparison are written out explicitly;
    // compare with softmax_kernel2 below, which uses the implicit overloads.
    // NOTE: the kernel bodies must not be restructured — the reflected code
    // model is compared verbatim against the @TritonCodeModel text above.
    @CodeReflection
    static void softmax_kernel(Ptr output_ptr,
                               Ptr input_ptr,
                               int input_row_stride,
                               int output_row_stride,
                               int n_cols,
                               @Constant int BLOCK_SIZE) {
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
        // The stride is how far the pointer must advance to reach the next row
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(Triton.broadcast(row_start_ptr, col_offsets.type()), col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        var mask = Triton.compare(col_offsets,
                Triton.broadcast(n_cols, col_offsets.type()),
                Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.broadcast(Triton.max(row, 0), row.type()));
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, Triton.broadcast(denominator, numerator.type()));
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(Triton.broadcast(output_row_start_ptr, col_offsets.type()), col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }

    /**
     * Specializes {@code softmax_kernel} with float pointers and the constant
     * values baked into the expected model's mangled name
     * ({@code _1_1_10_64}: strides = 1, n_cols = 10, BLOCK_SIZE = 64) and
     * checks the resulting code model.
     */
    @Kernel("softmax_kernel")
    @Test
    public void test(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 10),
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }

    @TritonCodeModel("""
            module ()void -> {
                tt.func @"max_float_float_float" (%0 : float, %1 : float)float -> {
                    %2 : float = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, float>)float -> {
                    %4 : float = tt.reduce %3 @axis="0" (%5 : float, %6 : float)float -> {
                        %7 : float = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : float, %9 : float)float -> {
                    %10 : float = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, float>)float -> {
                    %12 : float = tt.reduce %11 @axis="0" (%13 : float, %14 : float)float -> {
                        %15 : float = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @"softmax_kernel2_ptr<float>_ptr<float>_1_1_10_64_void" (%16 : ptr<float>, %17 : ptr<float>)void -> {
                    %18 : int = arith.constant @"1";
                    %19 : int = arith.constant @"1";
                    %20 : int = arith.constant @"10";
                    %21 : int = tt.get_program_id @"0";
                    %22 : int = arith.muli %21 %18;
                    %23 : ptr<float> = tt.addptr %17 %22;
                    %24 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %25 : tensor<x64, ptr<float>> = tt.splat %23;
                    %26 : tensor<x64, ptr<float>> = tt.addptr %25 %24;
                    %27 : tensor<x64, int> = tt.splat %20;
                    %28 : tensor<x64, int> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, float> = tt.load %26 %28;
                    %30 : float = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, float> = tt.splat %30;
                    %32 : tensor<x64, float> = arith.subf %29 %31;
                    %33 : tensor<x64, float> = math.exp %32;
                    %34 : float = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, float> = tt.splat %34;
                    %36 : tensor<x64, float> = arith.divf %33 %35;
                    %37 : int = arith.muli %21 %19;
                    %38 : ptr<float> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<float>> = tt.splat %38;
                    %40 : tensor<x64, ptr<float>> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    // Same softmax kernel as above, but written with the implicitly
    // broadcasting overloads (no explicit Triton.broadcast calls). The
    // expected model differs from softmax_kernel's only in the mangled
    // function name, i.e. both source forms lower to identical splat ops.
    @CodeReflection
    static void softmax_kernel2(Ptr output_ptr,
                                Ptr input_ptr,
                                int input_row_stride,
                                int output_row_stride,
                                int n_cols,
                                @Constant int BLOCK_SIZE) {
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
        // The stride is how far the pointer must advance to reach the next row
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(row_start_ptr, col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        var mask = Triton.compare(col_offsets, n_cols, Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.max(row, 0));
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, denominator);
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(output_row_start_ptr, col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }

    /**
     * Same specialization as {@link #test} (strides = 1, n_cols = 10,
     * BLOCK_SIZE = 64), applied to the implicit-broadcast variant
     * {@code softmax_kernel2}.
     */
    @Kernel("softmax_kernel2")
    @Test
    public void test2(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 10),
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }
}
235 
236 /*
237 def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
238     # The rows of the softmax are independent, so we parallelize across those
239     row_idx = tl.program_id(0)
240     # The stride represents how much we need to increase the pointer to advance 1 row
241     row_start_ptr = input_ptr + row_idx * input_row_stride
242     # The block size is the next power of two greater than n_cols, so we can fit each
243     # row in a single block
244     col_offsets = tl.arange(0, BLOCK_SIZE)
245     input_ptrs = row_start_ptr + col_offsets
246     # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
247     row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
248     # Subtract maximum for numerical stability
249     row_minus_max = row - tl.max(row, axis=0)
250     # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
251     numerator = tl.exp(row_minus_max)
252     denominator = tl.sum(numerator, axis=0)
253     softmax_output = numerator / denominator
254     # Write back output to DRAM
255     output_row_start_ptr = output_ptr + row_idx * output_row_stride
256     output_ptrs = output_row_start_ptr + col_offsets
257     tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
258 */
259 
260 /*
261 input_row_stride = 1
262 output_row_stride = 1
263 n_cols=10
264 BLOCK_SIZE=64
265 
266 module {
267   tt.func public @softmax_kernel_01(%arg0: !tt.ptr<f32, 1> , %arg1: !tt.ptr<f32, 1> ) attributes {noinline = false} {
268     %0 = tt.get_program_id x : i32
269     %c1_i32 = arith.constant 1 : i32
270     %1 = arith.muli %0, %c1_i32 : i32
271     %2 = tt.addptr %arg1, %1 : !tt.ptr<f32, 1>, i32
272     %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
273     %4 = tt.splat %2 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
274     %5 = tt.addptr %4, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
275     %c10_i32 = arith.constant 10 : i32
276     %cst = arith.constant dense<10> : tensor<64xi32>
277     %6 = arith.cmpi slt, %3, %cst : tensor<64xi32>
278     %cst_0 = arith.constant 0xFF800000 : f32
279     %cst_1 = arith.constant dense<0xFF800000> : tensor<64xf32>
280     %7 = tt.load %5, %6, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
281     %8 = tt.call @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%7) : (tensor<64xf32>) -> f32
282     %9 = tt.splat %8 : (f32) -> tensor<64xf32>
283     %10 = arith.subf %7, %9 : tensor<64xf32>
284     %11 = math.exp %10 : tensor<64xf32>
285     %12 = tt.call @sum__fp32S64S__1cconstexpr_0_(%11) : (tensor<64xf32>) -> f32
286     %13 = tt.splat %12 : (f32) -> tensor<64xf32>
287     %14 = arith.divf %11, %13 : tensor<64xf32>
288     %c1_i32_2 = arith.constant 1 : i32
289     %15 = arith.muli %0, %c1_i32_2 : i32
290     %16 = tt.addptr %arg0, %15 : !tt.ptr<f32, 1>, i32
291     %17 = tt.splat %16 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
292     %18 = tt.addptr %17, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
293     %c10_i32_3 = arith.constant 10 : i32
294     %cst_4 = arith.constant dense<10> : tensor<64xi32>
295     %19 = arith.cmpi slt, %3, %cst_4 : tensor<64xi32>
296     tt.store %18, %14, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<64xf32>
297     tt.return
298   }
299   tt.func private @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
300     %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
301     ^bb0(%arg1: f32 , %arg2: f32 ):
302       %1 = tt.call @maximum__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
303       tt.reduce.return %1 : f32
304     }) : (tensor<64xf32>) -> f32
305     tt.return %0 : f32
306   }
307   tt.func private @maximum__fp32_fp32__(%arg0: f32 , %arg1: f32 ) -> f32 attributes {noinline = false} {
308     %0 = arith.maximumf %arg0, %arg1 : f32
309     tt.return %0 : f32
310   }
311   tt.func private @sum__fp32S64S__1cconstexpr_0_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
312     %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
313     ^bb0(%arg1: f32 , %arg2: f32 ):
314       %1 = tt.call @_sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
315       tt.reduce.return %1 : f32
316     }) : (tensor<64xf32>) -> f32
317     tt.return %0 : f32
318   }
319   tt.func private @_sum_combine__fp32_fp32__(%arg0: f32 , %arg1: f32 ) -> f32 attributes {noinline = false} {
320     %0 = arith.addf %arg0, %arg1 : f32
321     tt.return %0 : f32
322   }
323 }
324 */