/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package oracle.code.triton;

import oracle.code.triton.TritonTestExtension.TritonTestData;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import java.lang.reflect.code.TypeElement;
import java.lang.reflect.code.type.JavaType;
import java.lang.runtime.CodeReflection;
import java.util.List;

@ExtendWith(TritonTestExtension.class)
public class TestAddKernel {

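    // The text in @TritonCodeModel is the expected serialized Triton code model for the
    // annotated kernel. The TritonTestExtension presumably builds the model from the Java
    // method below, using the argument types supplied by the test, and compares the result
    // against this text.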
    @TritonCodeModel("""
            module ()void -> {
                tt.func @"add_kernel_ptr<float>_ptr<float>_ptr<float>_int_64_void" (%0 : ptr<float>, %1 : ptr<float>, %2 : ptr<float>, %3 : int)void -> {
                    %4 : int = arith.constant @"64";
                    %5 : int = tt.get_program_id @"0";
                    %6 : int = arith.muli %5 %4;
                    %7 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %8 : tensor<x64, int> = tt.splat %6;
                    %9 : tensor<x64, int> = arith.addi %8 %7;
                    %10 : tensor<x64, int> = tt.splat %3;
                    %11 : tensor<x64, int> = arith.cmpi %9 %10 @"slt";
                    %12 : tensor<x64, ptr<float>> = tt.splat %0;
                    %13 : tensor<x64, ptr<float>> = tt.addptr %12 %9;
                    %14 : tensor<x64, float> = tt.load %13 %11;
                    %15 : tensor<x64, ptr<float>> = tt.splat %1;
                    %16 : tensor<x64, ptr<float>> = tt.addptr %15 %9;
                    %17 : tensor<x64, float> = tt.load %16 %11;
                    %18 : tensor<x64, float> = arith.addf %14 %17;
                    %19 : tensor<x64, ptr<float>> = tt.splat %2;
                    %20 : tensor<x64, ptr<float>> = tt.addptr %19 %9;
                    tt.store %20 %18 %11;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void add_kernel(Ptr x_ptr,  // *Pointer* to first input vector.
                           Ptr y_ptr,  // *Pointer* to second input vector.
                           Ptr output_ptr,  // *Pointer* to output vector.
                           int n_elements,  // Size of the vector.
                           @Constant int BLOCK_SIZE)  // Number of elements each program should process.
    // NOTE: @Constant so it can be used as a shape value
    {
        // There are multiple 'programs' processing different data. We identify which program
        // we are here:
        var pid = Triton.programId(0); // We use a 1D launch grid so axis is 0.
        // This program will process inputs that are offset from the initial data.
        // For instance, if you had a vector of length 256 and block_size of 64, the programs
        // would each access the elements [0:64, 64:128, 128:192, 192:256].
        // Note that offsets is a list of pointers:
        var block_start = pid * BLOCK_SIZE;
        var range = Triton.arange(0, BLOCK_SIZE);
        var offsets = Triton.add(Triton.broadcast(block_start, range.type()), range);
        // Create a mask to guard memory operations against out-of-bounds accesses.
        var mask = Triton.compare(offsets, Triton.broadcast(n_elements, offsets.type()), Triton.CompareKind.LessThan);
        // Load x and y from DRAM, masking out any extra elements in case the input is not a
        // multiple of the block size.
        var x = Triton.load(Triton.add(Triton.broadcast(x_ptr, offsets.type()), offsets), mask);
        var y = Triton.load(Triton.add(Triton.broadcast(y_ptr, offsets.type()), offsets), mask);
        var output = Triton.add(x, y);
        // Write x + y back to DRAM.
        Triton.store(Triton.add(Triton.broadcast(output_ptr, offsets.type()), offsets), output, mask);
    }

    @TritonTestExtension.Kernel("add_kernel")
    @Test
    public void test(TritonTestData t) {
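        // The argument types mirror the kernel parameters: three float pointers, the element
        // count, and BLOCK_SIZE bound to the compile-time constant 64 (hence the "_64_" in the
        // expected function name and the arith.constant @"64" in the model above).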
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                JavaType.INT,
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }

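    // add_kernel2 is the same vector-add kernel written without the explicit Triton.broadcast
    // calls. Its expected code model is identical apart from the function name, so the tt.splat
    // operations appear to be inserted implicitly when operands of different shapes are combined.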
    @TritonCodeModel("""
            module ()void -> {
                tt.func @"add_kernel2_ptr<float>_ptr<float>_ptr<float>_int_64_void" (%0 : ptr<float>, %1 : ptr<float>, %2 : ptr<float>, %3 : int)void -> {
                    %4 : int = arith.constant @"64";
                    %5 : int = tt.get_program_id @"0";
                    %6 : int = arith.muli %5 %4;
                    %7 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %8 : tensor<x64, int> = tt.splat %6;
                    %9 : tensor<x64, int> = arith.addi %8 %7;
                    %10 : tensor<x64, int> = tt.splat %3;
                    %11 : tensor<x64, int> = arith.cmpi %9 %10 @"slt";
                    %12 : tensor<x64, ptr<float>> = tt.splat %0;
                    %13 : tensor<x64, ptr<float>> = tt.addptr %12 %9;
                    %14 : tensor<x64, float> = tt.load %13 %11;
                    %15 : tensor<x64, ptr<float>> = tt.splat %1;
                    %16 : tensor<x64, ptr<float>> = tt.addptr %15 %9;
                    %17 : tensor<x64, float> = tt.load %16 %11;
                    %18 : tensor<x64, float> = arith.addf %14 %17;
                    %19 : tensor<x64, ptr<float>> = tt.splat %2;
                    %20 : tensor<x64, ptr<float>> = tt.addptr %19 %9;
                    tt.store %20 %18 %11;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void add_kernel2(Ptr x_ptr,  // *Pointer* to first input vector.
                            Ptr y_ptr,  // *Pointer* to second input vector.
                            Ptr output_ptr,  // *Pointer* to output vector.
                            int n_elements,  // Size of the vector.
                            @Constant int BLOCK_SIZE)  // Number of elements each program should process.
    // NOTE: @Constant so it can be used as a shape value
    {
        // There are multiple 'programs' processing different data. We identify which program
        // we are here:
        var pid = Triton.programId(0); // We use a 1D launch grid so axis is 0.
        // This program will process inputs that are offset from the initial data.
        // For instance, if you had a vector of length 256 and block_size of 64, the programs
        // would each access the elements [0:64, 64:128, 128:192, 192:256].
        // Note that offsets is a list of pointers:
        var block_start = pid * BLOCK_SIZE;
        var range = Triton.arange(0, BLOCK_SIZE);
        var offsets = Triton.add(block_start, range);
        // Create a mask to guard memory operations against out-of-bounds accesses.
        var mask = Triton.compare(offsets, n_elements, Triton.CompareKind.LessThan);
        // Load x and y from DRAM, masking out any extra elements in case the input is not a
        // multiple of the block size.
        var x = Triton.load(Triton.add(x_ptr, offsets), mask);
        var y = Triton.load(Triton.add(y_ptr, offsets), mask);
        var output = Triton.add(x, y);
        // Write x + y back to DRAM.
        Triton.store(Triton.add(output_ptr, offsets), output, mask);
    }

    @TritonTestExtension.Kernel("add_kernel2")
    @Test
    public void test2(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                JavaType.INT,
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }
}

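// For reference: the Triton Python tutorial kernel (vector addition) that the Java kernels
// above mirror.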
/*
@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
*/

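// For reference: the Triton IR (MLIR, tt dialect) that the Python kernel above compiles to.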
/*
module {
  tt.func public @add_kernel_0123(%arg0: !tt.ptr<f32, 1> , %arg1: !tt.ptr<f32, 1> , %arg2: !tt.ptr<f32, 1> , %arg3: i32 ) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %c64_i32 = arith.constant 64 : i32
    %1 = arith.muli %0, %c64_i32 : i32
    %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %3 = tt.splat %1 : (i32) -> tensor<64xi32>
    %4 = arith.addi %3, %2 : tensor<64xi32>
    %5 = tt.splat %arg3 : (i32) -> tensor<64xi32>
    %6 = arith.cmpi slt, %4, %5 : tensor<64xi32>
    %7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %10 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %11 = tt.addptr %10, %4 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %13 = arith.addf %9, %12 : tensor<64xf32>
    %14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<64xf32>
    tt.return
  }
}
*/