/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package oracle.code.triton;

import oracle.code.triton.TritonTestExtension.Kernel;
import oracle.code.triton.TritonTestExtension.TritonTestData;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import jdk.incubator.code.TypeElement;
import jdk.incubator.code.dialect.java.JavaType;
import jdk.incubator.code.CodeReflection;
import java.util.List;

@ExtendWith(TritonTestExtension.class)
public class TestSoftMax {

    @TritonCodeModel("""
            module ()java.type:"void" -> {
                tt.func @"max_float_float_float" (%0 : java.type:"float", %1 : java.type:"float")java.type:"float" -> {
                    %2 : java.type:"float" = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %4 : java.type:"float" = tt.reduce %3 @axis=0 (%5 : java.type:"float", %6 : java.type:"float")java.type:"float" -> {
                        %7 : java.type:"float" = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : java.type:"float", %9 : java.type:"float")java.type:"float" -> {
                    %10 : java.type:"float" = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %12 : java.type:"float" = tt.reduce %11 @axis=0 (%13 : java.type:"float", %14 : java.type:"float")java.type:"float" -> {
                        %15 : java.type:"float" = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @sym_name="softmax_kernel_ptr<java.type.primitive<float>>_ptr<java.type.primitive<float>>_int_int_int_64_void" (%16 : ptr<java.type:"float">, %17 : ptr<java.type:"float">, %18 : java.type:"int", %19 : java.type:"int", %20 : java.type:"int")java.type:"void" -> {
                    %21 : java.type:"int" = tt.get_program_id @0;
                    %22 : java.type:"int" = arith.muli %21 %18;
                    %23 : ptr<java.type:"float"> = tt.addptr %17 %22;
                    %24 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
                    %25 : tensor<x64, ptr<java.type:"float">> = tt.splat %23;
                    %26 : tensor<x64, ptr<java.type:"float">> = tt.addptr %25 %24;
                    %27 : tensor<x64, java.type:"int"> = tt.splat %20;
                    %28 : tensor<x64, java.type:"boolean"> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, java.type:"float"> = tt.load %26 %28;
                    %30 : java.type:"float" = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, java.type:"float"> = tt.splat %30;
                    %32 : tensor<x64, java.type:"float"> = arith.subf %29 %31;
                    %33 : tensor<x64, java.type:"float"> = math.exp %32;
                    %34 : java.type:"float" = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, java.type:"float"> = tt.splat %34;
                    %36 : tensor<x64, java.type:"float"> = arith.divf %33 %35;
                    %37 : java.type:"int" = arith.muli %21 %19;
                    %38 : ptr<java.type:"float"> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<java.type:"float">> = tt.splat %38;
                    %40 : tensor<x64, ptr<java.type:"float">> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void softmax_kernel(Ptr output_ptr,
                               Ptr input_ptr,
                               int input_row_stride,
                               int output_row_stride,
                               int n_cols,
                               @Constant int BLOCK_SIZE) {
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
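        // Each program instance handles one row; in the Triton tutorial this kernel is
        // launched over a one-dimensional grid of n_rows programs, so the program id is
        // the index of the row processed here.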
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(Triton.broadcast(row_start_ptr, col_offsets.type()), col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        var mask = Triton.compare(col_offsets,
                Triton.broadcast(n_cols, col_offsets.type()),
                Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.broadcast(Triton.max(row, 0), row.type()));
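        // This is safe because softmax is shift-invariant:
        //   exp(x_i - m) / sum_j exp(x_j - m) == exp(x_i) / sum_j exp(x_j)  for m = max_j x_j,
        // while keeping the exponents <= 0 so exp() cannot overflow.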
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, Triton.broadcast(denominator, numerator.type()));
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(Triton.broadcast(output_row_start_ptr, col_offsets.type()), col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }
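
    // In the Triton tutorial the host chooses BLOCK_SIZE as the next power of two that is at
    // least n_cols (triton.next_power_of_2(n_cols)), so a whole row fits in one block; the
    // tests in this class simply fix BLOCK_SIZE at 64. A minimal host-side sketch of that
    // choice; the helper is illustrative only and is not used by the tests:
    static int nextPowerOfTwo(int n) {
        // For n = 10 this yields 16; for n = 64 it yields 64.
        return n <= 1 ? 1 : Integer.highestOneBit(n - 1) << 1;
    }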

    @Kernel("softmax_kernel")
    @Test
    public void test(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                JavaType.INT,
                JavaType.INT,
                JavaType.INT,
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }

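    // softmax_kernel2 expresses the same computation as softmax_kernel above, but it passes
    // scalar values directly where the first kernel broadcasts them explicitly; the tt.splat
    // operations in the expected model below are introduced for it during translation. The
    // test also binds the stride and n_cols arguments to constants, which surface as
    // arith.constant values in the specialized kernel function.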
    @TritonCodeModel("""
            module ()java.type:"void" -> {
                tt.func @"max_float_float_float" (%0 : java.type:"float", %1 : java.type:"float")java.type:"float" -> {
                    %2 : java.type:"float" = arith.maximumf %0 %1;
                    tt.return %2;
                };
                tt.func @"reduce_max_float_float_float_0" (%3 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %4 : java.type:"float" = tt.reduce %3 @axis=0 (%5 : java.type:"float", %6 : java.type:"float")java.type:"float" -> {
                        %7 : java.type:"float" = tt.call %5 %6 @"max_float_float_float";
                        tt.reduce.return %7;
                    };
                    tt.return %4;
                };
                tt.func @"sum_float_float_float" (%8 : java.type:"float", %9 : java.type:"float")java.type:"float" -> {
                    %10 : java.type:"float" = arith.addf %8 %9;
                    tt.return %10;
                };
                tt.func @"reduce_sum_float_float_float_0" (%11 : tensor<x64, java.type:"float">)java.type:"float" -> {
                    %12 : java.type:"float" = tt.reduce %11 @axis=0 (%13 : java.type:"float", %14 : java.type:"float")java.type:"float" -> {
                        %15 : java.type:"float" = tt.call %13 %14 @"sum_float_float_float";
                        tt.reduce.return %15;
                    };
                    tt.return %12;
                };
                tt.func @sym_name="softmax_kernel2_ptr<java.type.primitive<float>>_ptr<java.type.primitive<float>>_1_1_10_64_void" (%16 : ptr<java.type:"float">, %17 : ptr<java.type:"float">)java.type:"void" -> {
                    %18 : java.type:"int" = arith.constant @1;
                    %19 : java.type:"int" = arith.constant @1;
                    %20 : java.type:"int" = arith.constant @10;
                    %21 : java.type:"int" = tt.get_program_id @0;
                    %22 : java.type:"int" = arith.muli %21 %18;
                    %23 : ptr<java.type:"float"> = tt.addptr %17 %22;
                    %24 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
                    %25 : tensor<x64, ptr<java.type:"float">> = tt.splat %23;
                    %26 : tensor<x64, ptr<java.type:"float">> = tt.addptr %25 %24;
                    %27 : tensor<x64, java.type:"int"> = tt.splat %20;
                    %28 : tensor<x64, java.type:"boolean"> = arith.cmpi %24 %27 @"slt";
                    %29 : tensor<x64, java.type:"float"> = tt.load %26 %28;
                    %30 : java.type:"float" = tt.call %29 @"reduce_max_float_float_float_0";
                    %31 : tensor<x64, java.type:"float"> = tt.splat %30;
                    %32 : tensor<x64, java.type:"float"> = arith.subf %29 %31;
                    %33 : tensor<x64, java.type:"float"> = math.exp %32;
                    %34 : java.type:"float" = tt.call %33 @"reduce_sum_float_float_float_0";
                    %35 : tensor<x64, java.type:"float"> = tt.splat %34;
                    %36 : tensor<x64, java.type:"float"> = arith.divf %33 %35;
                    %37 : java.type:"int" = arith.muli %21 %19;
                    %38 : ptr<java.type:"float"> = tt.addptr %16 %37;
                    %39 : tensor<x64, ptr<java.type:"float">> = tt.splat %38;
                    %40 : tensor<x64, ptr<java.type:"float">> = tt.addptr %39 %24;
                    tt.store %40 %36 %28;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void softmax_kernel2(Ptr output_ptr,
                                Ptr input_ptr,
                                int input_row_stride,
                                int output_row_stride,
                                int n_cols,
                                @Constant int BLOCK_SIZE) {
        // The rows of the softmax are independent, so we parallelize across those
        var row_idx = Triton.programId(0);
        var row_start_ptr = Triton.add(input_ptr, row_idx * input_row_stride);
        // The block size is the next power of two greater than n_cols, so we can fit each
        // row in a single block
        var col_offsets = Triton.arange(0, BLOCK_SIZE);
        var input_ptrs = Triton.add(row_start_ptr, col_offsets);
        // Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        var mask = Triton.compare(col_offsets, n_cols, Triton.CompareKind.LessThan);
        var row = Triton.load(input_ptrs, mask);
        // Subtract maximum for numerical stability
        var row_minus_max = Triton.sub(row, Triton.max(row, 0));
        // Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        var numerator = Triton.exp(row_minus_max);
        var denominator = Triton.sum(numerator, 0);
        var softmax_output = Triton.div(numerator, denominator);
        // Write back output to DRAM
        var output_row_start_ptr = Triton.add(output_ptr, row_idx * output_row_stride);
        var output_ptrs = Triton.add(output_row_start_ptr, col_offsets);
        Triton.store(output_ptrs, softmax_output, mask);
    }

    @Kernel("softmax_kernel2")
    @Test
    public void test2(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 1),
                new ConstantType(JavaType.INT, 10),
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }
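
    // For reference, a minimal host-side sketch of the per-row computation the kernels above
    // express (subtract the row maximum, exponentiate, normalize). The helper is illustrative
    // only: it is not used by the tests and is not part of the Triton or code-reflection APIs.
    static float[] referenceSoftmaxRow(float[] row) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : row) {
            max = Math.max(max, v);
        }
        float[] out = new float[row.length];
        float sum = 0f;
        for (int i = 0; i < row.length; i++) {
            // Shift by the row maximum for numerical stability, as in the kernels
            out[i] = (float) Math.exp(row[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < row.length; i++) {
            out[i] /= sum;
        }
        return out;
    }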
}

/*
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
*/

/*
input_row_stride = 1
output_row_stride = 1
n_cols = 10
BLOCK_SIZE = 64

module {
  tt.func public @softmax_kernel_01(%arg0: !tt.ptr<f32, 1> , %arg1: !tt.ptr<f32, 1> ) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %c1_i32 = arith.constant 1 : i32
    %1 = arith.muli %0, %c1_i32 : i32
    %2 = tt.addptr %arg1, %1 : !tt.ptr<f32, 1>, i32
    %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %4 = tt.splat %2 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %5 = tt.addptr %4, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %c10_i32 = arith.constant 10 : i32
    %cst = arith.constant dense<10> : tensor<64xi32>
    %6 = arith.cmpi slt, %3, %cst : tensor<64xi32>
    %cst_0 = arith.constant 0xFF800000 : f32
    %cst_1 = arith.constant dense<0xFF800000> : tensor<64xf32>
    %7 = tt.load %5, %6, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %8 = tt.call @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%7) : (tensor<64xf32>) -> f32
    %9 = tt.splat %8 : (f32) -> tensor<64xf32>
    %10 = arith.subf %7, %9 : tensor<64xf32>
    %11 = math.exp %10 : tensor<64xf32>
    %12 = tt.call @sum__fp32S64S__1cconstexpr_0_(%11) : (tensor<64xf32>) -> f32
    %13 = tt.splat %12 : (f32) -> tensor<64xf32>
    %14 = arith.divf %11, %13 : tensor<64xf32>
    %c1_i32_2 = arith.constant 1 : i32
    %15 = arith.muli %0, %c1_i32_2 : i32
    %16 = tt.addptr %arg0, %15 : !tt.ptr<f32, 1>, i32
    %17 = tt.splat %16 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %18 = tt.addptr %17, %3 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %c10_i32_3 = arith.constant 10 : i32
    %cst_4 = arith.constant dense<10> : tensor<64xi32>
    %19 = arith.cmpi slt, %3, %cst_4 : tensor<64xi32>
    tt.store %18, %14, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<64xf32>
    tt.return
  }
  tt.func private @max__fp32S64S__1cconstexpr_0__2cconstexpr_False__3cconstexpr_True_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32 , %arg2: f32 ):
      %1 = tt.call @maximum__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %1 : f32
    }) : (tensor<64xf32>) -> f32
    tt.return %0 : f32
  }
  tt.func private @maximum__fp32_fp32__(%arg0: f32 , %arg1: f32 ) -> f32 attributes {noinline = false} {
    %0 = arith.maximumf %arg0, %arg1 : f32
    tt.return %0 : f32
  }
  tt.func private @sum__fp32S64S__1cconstexpr_0_(%arg0: tensor<64xf32> ) -> f32 attributes {noinline = false} {
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32 , %arg2: f32 ):
      %1 = tt.call @_sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %1 : f32
    }) : (tensor<64xf32>) -> f32
    tt.return %0 : f32
  }
}
*/