/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package oracle.code.triton;

import oracle.code.triton.TritonTestExtension.Kernel;
import oracle.code.triton.TritonTestExtension.TritonTestData;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import jdk.incubator.code.TypeElement;
import jdk.incubator.code.dialect.java.JavaType;
import jdk.incubator.code.CodeReflection;
import java.util.List;

@ExtendWith(TritonTestExtension.class)
public class TestSoftMax {

    @TritonCodeModel("""
            module ()java.type:"void" -> {
                tt.func @"max_float_float_float" (%0 : java.type:"float", %1 : java.type:"float")java.type:"float" -> {
                    %2 : java.type:"float" = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %4 : java.type:"float" = tt.reduce %3 @axis=0 (%5 : java.type:"float", %6 : java.type:"float")java.type:"float" -> {
                        %7 : java.type:"float" = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : java.type:"float", %9 : java.type:"float")java.type:"float" -> {
                    %10 : java.type:"float" = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %12 : java.type:"float" = tt.reduce %11 @axis=0 (%13 : java.type:"float", %14 : java.type:"float")java.type:"float" -> {
                        %15 : java.type:"float" = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @sym_name="softmax_kernel_ptr<java.type.primitive<float>>_ptr<java.type.primitive<float>>_int_int_int_64_void" (%16 : ptr<java.type:"float">, %17 : ptr<java.type:"float">, %18 : java.type:"int", %19 : java.type:"int", %20 : java.type:"int")java.type:"void" -> {
                    %21 : java.type:"int" = tt.get_program_id @0;
                    %22 : java.type:"int" = arith.muli %21 %18;
                    %23 : ptr<java.type:"float"> = tt.addptr %17 %22;
                    %24 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
                    %25 : tensor<x64, ptr<java.type:"float">> = tt.splat %23;
                    %26 : tensor<x64, ptr<java.type:"float">> = tt.addptr %25 %24;
                    %27 : tensor<x64, java.type:"int"> = tt.splat %20;
                    %28 : tensor<x64, java.type:"boolean"> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, java.type:"float"> = tt.load %26 %28;
                    %30 : java.type:"float" = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, java.type:"float"> = tt.splat %30;
                    %32 : tensor<x64, java.type:"float"> = arith.subf %29 %31;
                    %33 : tensor<x64, java.type:"float"> = math.exp %32;
                    %34 : java.type:"float" = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, java.type:"float"> = tt.splat %34;
                    %36 : tensor<x64, java.type:"float"> = arith.divf %33 %35;
                    %37 : java.type:"int" = arith.muli %21 %19;
                    %38 : ptr<java.type:"float"> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<java.type:"float">> = tt.splat %38;
                    %40 : tensor<x64, ptr<java.type:"float">> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void softmax_kernel(Ptr output_ptr,
                               Ptr input_ptr,
                               int input_row_stride,
                               int output_row_stride,
                               int n_cols,
                               @Constant int BLOCK_SIZE) {
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(Triton.broadcast(row_start_ptr, col_offsets.type()), col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        var mask = Triton.compare(col_offsets,
                Triton.broadcast(n_cols, col_offsets.type()),
                Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.broadcast(Triton.max(row, 0), row.type()));
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, Triton.broadcast(denominator, numerator.type()));
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(Triton.broadcast(output_row_start_ptr, col_offsets.type()), col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }
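
    // For reference only: a minimal plain-Java sketch of the numerically stable softmax that the
    // kernel above computes per row, i.e. softmax(x)[i] = exp(x[i] - max(x)) / sum_j exp(x[j] - max(x)).
    // Subtracting max(x) cancels out in the ratio but keeps Math.exp from overflowing. This helper
    // (softmaxReference) is purely illustrative and is not used by the Triton test harness.
    static float[] softmaxReference(float[] row) {
        // Find the row maximum
        float max = Float.NEGATIVE_INFINITY;
        for (float v : row) {
            max = Math.max(max, v);
        }
        // Exponentiate the shifted values and accumulate the denominator
        float[] result = new float[row.length];
        float sum = 0f;
        for (int i = 0; i < row.length; i++) {
            result[i] = (float) Math.exp(row[i] - max);
            sum += result[i];
        }
        // Normalize
        for (int i = 0; i < row.length; i++) {
            result[i] /= sum;
        }
        return result;
    }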

    @Kernel("softmax_kernel")
    @Test
    public void test(TritonTestData t) {
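        // The argument types mirror softmax_kernel's parameter list; the trailing ConstantType
        // models the @Constant BLOCK_SIZE parameter, specialized here to 64.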
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                JavaType.INT,
                JavaType.INT,
                JavaType.INT,
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }

    @TritonCodeModel("""
            module ()java.type:"void" -> {
                tt.func @"max_float_float_float" (%0 : java.type:"float", %1 : java.type:"float")java.type:"float" -> {
                    %2 : java.type:"float" = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %4 : java.type:"float" = tt.reduce %3 @axis=0 (%5 : java.type:"float", %6 : java.type:"float")java.type:"float" -> {
                        %7 : java.type:"float" = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : java.type:"float", %9 : java.type:"float")java.type:"float" -> {
                    %10 : java.type:"float" = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %12 : java.type:"float" = tt.reduce %11 @axis=0 (%13 : java.type:"float", %14 : java.type:"float")java.type:"float" -> {
                        %15 : java.type:"float" = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @sym_name="softmax_kernel2_ptr<java.type.primitive<float>>_ptr<java.type.primitive<float>>_1_1_10_64_void" (%16 : ptr<java.type:"float">, %17 : ptr<java.type:"float">)java.type:"void" -> {
                    %18 : java.type:"int" = arith.constant @1;
                    %19 : java.type:"int" = arith.constant @1;
                    %20 : java.type:"int" = arith.constant @10;
                    %21 : java.type:"int" = tt.get_program_id @0;
                    %22 : java.type:"int" = arith.muli %21 %18;
                    %23 : ptr<java.type:"float"> = tt.addptr %17 %22;
                    %24 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
                    %25 : tensor<x64, ptr<java.type:"float">> = tt.splat %23;
                    %26 : tensor<x64, ptr<java.type:"float">> = tt.addptr %25 %24;
                    %27 : tensor<x64, java.type:"int"> = tt.splat %20;
                    %28 : tensor<x64, java.type:"boolean"> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, java.type:"float"> = tt.load %26 %28;
                    %30 : java.type:"float" = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, java.type:"float"> = tt.splat %30;
                    %32 : tensor<x64, java.type:"float"> = arith.subf %29 %31;
                    %33 : tensor<x64, java.type:"float"> = math.exp %32;
                    %34 : java.type:"float" = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, java.type:"float"> = tt.splat %34;
                    %36 : tensor<x64, java.type:"float"> = arith.divf %33 %35;
                    %37 : java.type:"int" = arith.muli %21 %19;
                    %38 : ptr<java.type:"float"> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<java.type:"float">> = tt.splat %38;
                    %40 : tensor<x64, ptr<java.type:"float">> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void softmax_kernel2(Ptr output_ptr,
                                Ptr input_ptr,
                                int input_row_stride,
                                int output_row_stride,
                                int n_cols,
                                @Constant int BLOCK_SIZE) {
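        // Same algorithm as softmax_kernel above, but without the explicit Triton.broadcast calls;
        // the broadcasts happen implicitly here (note the same tt.splat ops in the code model above).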
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(row_start_ptr, col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        var mask = Triton.compare(col_offsets, n_cols, Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.max(row, 0));
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, denominator);
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(output_row_start_ptr, col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }

    @Kernel("softmax_kernel2")
    @Test
    public void test2(TritonTestData t) {
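        // Every scalar argument is modeled as a ConstantType, so the strides (1, 1), n_cols (10),
        // and BLOCK_SIZE (64) are folded into the specialized kernel; see the arith.constant ops
        // and the mangled name in the code model above.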
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 10),
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }
}

/*
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
*/

/*
input_row_stride = 1
output_row_stride = 1
n_cols=10
BLOCK_SIZE=64

module {
  tt.func public @softmax_kernel_01(%arg0: !tt.ptr<f32, 1> , %arg1: !tt.ptr<f32, 1> ) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %c1_i32 = arith.constant 1 : i32
    %1 = arith.muli %0, %c1_i32 : i32
    %2 = tt.addptr %arg1, %1 : !tt.ptr<f32, 1>, i32
    %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %4 = tt.splat %2 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %5 = tt.addptr %4, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %c10_i32 = arith.constant 10 : i32
    %cst = arith.constant dense<10> : tensor<64xi32>
    %6 = arith.cmpi slt, %3, %cst : tensor<64xi32>
    %cst_0 = arith.constant 0xFF800000 : f32
    %cst_1 = arith.constant dense<0xFF800000> : tensor<64xf32>
    %7 = tt.load %5, %6, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %8 = tt.call @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%7) : (tensor<64xf32>) -> f32
    %9 = tt.splat %8 : (f32) -> tensor<64xf32>
    %10 = arith.subf %7, %9 : tensor<64xf32>
    %11 = math.exp %10 : tensor<64xf32>
    %12 = tt.call @sum__fp32S64S__1cconstexpr_0_(%11) : (tensor<64xf32>) -> f32
    %13 = tt.splat %12 : (f32) -> tensor<64xf32>
    %14 = arith.divf %11, %13 : tensor<64xf32>
    %c1_i32_2 = arith.constant 1 : i32
    %15 = arith.muli %0, %c1_i32_2 : i32
    %16 = tt.addptr %arg0, %15 : !tt.ptr<f32, 1>, i32
    %17 = tt.splat %16 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %18 = tt.addptr %17, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %c10_i32_3 = arith.constant 10 : i32
    %cst_4 = arith.constant dense<10> : tensor<64xi32>
    %19 = arith.cmpi slt, %3, %cst_4 : tensor<64xi32>
    tt.store %18, %14, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<64xf32>
    tt.return
  }
  tt.func private @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32 , %arg2: f32 ):
      %1 = tt.call @maximum__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %1 : f32
    }) : (tensor<64xf32>) -> f32
    tt.return %0 : f32
  }
  tt.func private @maximum__fp32_fp32__(%arg0: f32 , %arg1: f32 ) -> f32 attributes {noinline = false} {
    %0 = arith.maximumf %arg0, %arg1 : f32
    tt.return %0 : f32
  }
  tt.func private @sum__fp32S64S__1cconstexpr_0_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32 , %arg2: f32 ):
      %1 = tt.call @_sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %1 : f32
    }) : (tensor<64xf32>) -> f32
    tt.return %0 : f32
  }
  tt.func private @_sum_combine__fp32_fp32__(%arg0: f32 , %arg1: f32 ) -> f32 attributes {noinline = false} {
    %0 = arith.addf %arg0, %arg1 : f32
    tt.return %0 : f32
  }
}
*/