1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 
 */

package oracle.code.triton;

import oracle.code.triton.TritonTestExtension.TritonTestData;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import java.lang.reflect.code.TypeElement;
import java.lang.reflect.code.type.JavaType;
import java.lang.runtime.CodeReflection;
import java.util.List;

/**
 * Tests the Java port of the Triton "vector add" tutorial kernel.
 * <p>
 * Each {@code @CodeReflection} kernel method is paired with a {@code @TritonCodeModel}
 * text block holding the Triton code model the method is expected to lower to when
 * specialized to the argument types supplied by its {@code @Test} method.
 * {@code add_kernel} spells out every scalar-to-tensor broadcast explicitly, while
 * {@code add_kernel2} relies on mixed scalar/tensor overloads of the {@code Triton}
 * operations; both are expected to produce the same lowered form (modulo the function
 * name). The original Python kernel and the MLIR it produces are kept as reference
 * comments at the end of this file.
 */
@ExtendWith(TritonTestExtension.class)
public class TestAddKernel {

    @TritonCodeModel("""
            module ()void -> {
                tt.func @"add_kernel_ptr<float>_ptr<float>_ptr<float>_int_64_void" (%0 : ptr<float>, %1 : ptr<float>, %2 : ptr<float>, %3 : int)void -> {
                    %4 : int = arith.constant @"64";
                    %5 : int = tt.get_program_id @"0";
                    %6 : int = arith.muli %5 %4;
                    %7 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %8 : tensor<x64, int> = tt.splat %6;
                    %9 : tensor<x64, int> = arith.addi %8 %7;
                    %10 : tensor<x64, int> = tt.splat %3;
                    %11 : tensor<x64, int> = arith.cmpi %9 %10 @"slt";
                    %12 : tensor<x64, ptr<float>> = tt.splat %0;
                    %13 : tensor<x64, ptr<float>> = tt.addptr %12 %9;
                    %14 : tensor<x64, float> = tt.load %13 %11;
                    %15 : tensor<x64, ptr<float>> = tt.splat %1;
                    %16 : tensor<x64, ptr<float>> = tt.addptr %15 %9;
                    %17 : tensor<x64, float> = tt.load %16 %11;
                    %18 : tensor<x64, float> = arith.addf %14 %17;
                    %19 : tensor<x64, ptr<float>> = tt.splat %2;
                    %20 : tensor<x64, ptr<float>> = tt.addptr %19 %9;
                    tt.store %20 %18 %11;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void add_kernel(Ptr x_ptr,  // *Pointer* to first input vector.
                           Ptr y_ptr,  // *Pointer* to second input vector.
                           Ptr output_ptr,  // *Pointer* to output vector.
                           int n_elements,  // Size of the vector.
                           @Constant int BLOCK_SIZE)  // Number of elements each program should process.
    // NOTE: @Constant so it can be used as a shape value
    {
        // There are multiple 'programs' processing different data. We identify which program
        // we are here:
        var pid = Triton.programId(0); // We use a 1D launch grid so axis is 0.
        // This program will process inputs that are offset from the initial data.
        // For instance, if you had a vector of length 256 and block_size of 64, the programs
        // would each access the elements [0:64, 64:128, 128:192, 192:256].
        // Note that offsets is a list of pointers:
        var block_start = pid * BLOCK_SIZE;
        var range = Triton.arange(0, BLOCK_SIZE);
        // Unlike the Python original, this variant makes every broadcast explicit:
        // each scalar is splatted to the tensor type of the other operand before the
        // element-wise operation.
        var offsets = Triton.add(Triton.broadcast(block_start, range.type()), range);
        // Create a mask to guard memory operations against out-of-bounds accesses.
        var mask = Triton.compare(offsets, Triton.broadcast(n_elements, offsets.type()), Triton.CompareKind.LessThan);
        // Load x and y from DRAM, masking out any extra elements in case the input is not a
        // multiple of the block size.
        var x = Triton.load(Triton.add(Triton.broadcast(x_ptr, offsets.type()), offsets), mask);
        var y = Triton.load(Triton.add(Triton.broadcast(y_ptr, offsets.type()), offsets), mask);
        var output = Triton.add(x, y);
        // Write x + y back to DRAM.
        Triton.store(Triton.add(Triton.broadcast(output_ptr, offsets.type()), offsets), output, mask);
    }

    // Specializes add_kernel to float pointers, an int length, and a constant
    // BLOCK_SIZE of 64, matching the "..._64_void" name in the expected model.
    // NOTE(review): the comparison against the @TritonCodeModel text appears to be
    // performed by TritonTestExtension via the @Kernel name — confirm there.
    @TritonTestExtension.Kernel("add_kernel")
    @Test
    public void test(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                JavaType.INT,
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }


    @TritonCodeModel("""
            module ()void -> {
                tt.func @"add_kernel2_ptr<float>_ptr<float>_ptr<float>_int_64_void" (%0 : ptr<float>, %1 : ptr<float>, %2 : ptr<float>, %3 : int)void -> {
                    %4 : int = arith.constant @"64";
                    %5 : int = tt.get_program_id @"0";
                    %6 : int = arith.muli %5 %4;
                    %7 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %8 : tensor<x64, int> = tt.splat %6;
                    %9 : tensor<x64, int> = arith.addi %8 %7;
                    %10 : tensor<x64, int> = tt.splat %3;
                    %11 : tensor<x64, int> = arith.cmpi %9 %10 @"slt";
                    %12 : tensor<x64, ptr<float>> = tt.splat %0;
                    %13 : tensor<x64, ptr<float>> = tt.addptr %12 %9;
                    %14 : tensor<x64, float> = tt.load %13 %11;
                    %15 : tensor<x64, ptr<float>> = tt.splat %1;
                    %16 : tensor<x64, ptr<float>> = tt.addptr %15 %9;
                    %17 : tensor<x64, float> = tt.load %16 %11;
                    %18 : tensor<x64, float> = arith.addf %14 %17;
                    %19 : tensor<x64, ptr<float>> = tt.splat %2;
                    %20 : tensor<x64, ptr<float>> = tt.addptr %19 %9;
                    tt.store %20 %18 %11;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void add_kernel2(Ptr x_ptr,  // *Pointer* to first input vector.
                            Ptr y_ptr,  // *Pointer* to second input vector.
                            Ptr output_ptr,  // *Pointer* to output vector.
                            int n_elements,  // Size of the vector.
                            @Constant int BLOCK_SIZE)  // Number of elements each program should process.
    // NOTE: @Constant so it can be used as a shape value
    {
        // There are multiple 'programs' processing different data. We identify which program
        // we are here:
        var pid = Triton.programId(0); // We use a 1D launch grid so axis is 0.
        // This program will process inputs that are offset from the initial data.
        // For instance, if you had a vector of length 256 and block_size of 64, the programs
        // would each access the elements [0:64, 64:128, 128:192, 192:256].
        // Note that offsets is a list of pointers:
        var block_start = pid * BLOCK_SIZE;
        var range = Triton.arange(0, BLOCK_SIZE);
        // This variant leaves broadcasts implicit, passing scalars and tensors directly
        // to the mixed overloads; it is expected to lower to the same splat/addi/cmpi
        // sequence as add_kernel (see the expected model above).
        var offsets = Triton.add(block_start, range);
        // Create a mask to guard memory operations against out-of-bounds accesses.
        var mask = Triton.compare(offsets, n_elements, Triton.CompareKind.LessThan);
        // Load x and y from DRAM, masking out any extra elements in case the input is not a
        // multiple of the block size.
        var x = Triton.load(Triton.add(x_ptr, offsets), mask);
        var y = Triton.load(Triton.add(y_ptr, offsets), mask);
        var output = Triton.add(x, y);
        // Write x + y back to DRAM.
        Triton.store(Triton.add(output_ptr, offsets), output, mask);
    }

    // Same specialization as test(), applied to the implicit-broadcast kernel.
    @TritonTestExtension.Kernel("add_kernel2")
    @Test
    public void test2(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                new PtrType(JavaType.FLOAT),
                JavaType.INT,
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }
}

/*
Reference: the original Python/Triton tutorial kernel this test ports.

@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
*/

/*
Reference: the Triton/MLIR produced for the Python kernel above.

module {
  tt.func public @add_kernel_0123(%arg0: !tt.ptr<f32, 1> , %arg1: !tt.ptr<f32, 1> , %arg2: !tt.ptr<f32, 1> , %arg3: i32 ) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %c64_i32 = arith.constant 64 : i32
    %1 = arith.muli %0, %c64_i32 : i32
    %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %3 = tt.splat %1 : (i32) -> tensor<64xi32>
    %4 = arith.addi %3, %2 : tensor<64xi32>
    %5 = tt.splat %arg3 : (i32) -> tensor<64xi32>
    %6 = arith.cmpi slt, %4, %5 : tensor<64xi32>
    %7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %10 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %11 = tt.addptr %10, %4 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %13 = arith.addf %9, %12 : tensor<64xf32>
    %14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<64xf32>
    tt.return
  }
}
*/