1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 
 */

package oracle.code.triton;

import oracle.code.triton.TritonTestExtension.TritonTestData;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import java.lang.reflect.code.TypeElement;
import java.lang.reflect.code.type.JavaType;
import java.lang.runtime.CodeReflection;
import java.util.List;

import static oracle.code.triton.Triton.*;
import static oracle.code.triton.TritonTest.consume;

/**
 * Tests how scalar and tensor operands are broadcast when Triton code models
 * are built from {@code @CodeReflection} methods.
 * <p>
 * Each {@code @CodeReflection} method is paired with a {@code @Test} method of
 * the same name. The expected model text is in its {@code @TritonCodeModel}
 * annotation. The expected function names (e.g. {@code test1_ptr<int>_int_64_void})
 * mirror the argument types passed by the {@code @Test} method, including
 * constant values.
 * <p>
 * Presumably {@link TritonTestExtension} uses the shared method name to find
 * the reflected method. It would then specialize that method with the given
 * argument types and compare the result to the annotation text. TODO confirm
 * against the extension.
 */
@ExtendWith(TritonTestExtension.class)
public class TestBroadcast {

    /**
     * Adds a pointer and a scalar {@code int} to a 1-D range tensor.
     * <p>
     * The scalar is added in both operand orders. It is broadcast implicitly
     * ({@code add(a, t)}) and explicitly via {@code broadcast(a, t.type())}.
     * <p>
     * In the expected model:
     * <ul>
     *   <li>Every scalar operand is lowered to a {@code tt.splat} with the
     *       range's shape.</li>
     *   <li>The resulting {@code tt.addptr} / {@code arith.addi} keeps the
     *       source operand order.</li>
     *   <li>The implicit and explicit broadcast forms yield the same ops.</li>
     * </ul>
     * <p>
     * The expected ops appear in the same order as the statements below.
     * Reordering the statements would presumably require updating the
     * expected text.
     *
     * @param ptr pointer operand (specialized as {@code ptr<int>} by the test)
     * @param a   scalar operand
     * @param s   constant range length (specialized as 64 by the test)
     */
    @TritonCodeModel("""
            module ()void -> {
                tt.func @"test1_ptr<int>_int_64_void" (%0 : ptr<int>, %1 : int)void -> {
                    %2 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %3 : tensor<x64, ptr<int>> = tt.splat %0;
                    %4 : tensor<x64, ptr<int>> = tt.addptr %3 %2;
                    tt.consume %4;
                    %5 : tensor<x64, int> = tt.splat %1;
                    %6 : tensor<x64, int> = arith.addi %5 %2;
                    tt.consume %6;
                    %7 : tensor<x64, int> = tt.splat %1;
                    %8 : tensor<x64, int> = arith.addi %2 %7;
                    tt.consume %8;
                    %9 : tensor<x64, int> = tt.splat %1;
                    %10 : tensor<x64, int> = arith.addi %9 %2;
                    tt.consume %10;
                    %11 : tensor<x64, int> = tt.splat %1;
                    %12 : tensor<x64, int> = arith.addi %2 %11;
                    tt.consume %12;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void test1(Ptr ptr, int a, @Constant int s) {
        var t = arange(0, s);
        // pointer + tensor: splat the pointer, then tt.addptr
        consume(add(ptr, t));
        // implicit scalar broadcast, scalar on the left, then on the right
        consume(add(a, t));
        consume(add(t, a));
        // explicit broadcast to the tensor's type; same ops as the implicit form
        consume(add(broadcast(a, t.type()), t));
        consume(add(t, broadcast(a, t.type())));
    }

    /**
     * Specializes {@code test1} with a {@code ptr<int>}, an {@code int}, and the
     * constant 64 for {@code s}. It is then run through
     * {@code TritonTestData.test}, which presumably compares it against the
     * expected model.
     */
    @Test
    public void test1(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                new PtrType(JavaType.INT),
                JavaType.INT,
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }

    /**
     * Broadcasts a scalar, and two 1-D ranges expanded along different axes.
     * <p>
     * With {@code M = 64} and {@code N = 32}, as specialized by the test, the
     * expected model shows these shapes:
     * <ul>
     *   <li>{@code expand(m, 0)} gives {@code [1, 64]}.</li>
     *   <li>{@code expand(n, 1)} gives {@code [32, 1]}.</li>
     *   <li>Adding the scalar {@code a} to {@code me} splats it to
     *       {@code [1, 64]}.</li>
     *   <li>Adding {@code me} and {@code ne} applies {@code tt.broadcast} to
     *       both, giving {@code [32, 64]}, before {@code arith.addi}.</li>
     * </ul>
     * <p>
     * The expected ops appear in the same order as the statements below.
     * <p>
     * NOTE(review): here the function parameter is named %1, whereas test1's
     * first parameter is %0. This is presumably harmless if value names are
     * normalized when models are compared. Confirm.
     *
     * @param a scalar operand
     * @param M constant length of the range expanded along axis 0
     * @param N constant length of the range expanded along axis 1
     */
    @TritonCodeModel("""
            module ()void -> {
                tt.func @"test2_int_64_32_void" (%1 : int)void -> {
                    %2 : tensor<x64, int> = tt.make_range @start="0" @end="64";
                    %3 : tensor<x1, x64, int> = tt.expand_dims %2 @"0";
                    %4 : tensor<x32, int> = tt.make_range @start="0" @end="32";
                    %5 : tensor<x32, x1, int> = tt.expand_dims %4 @"1";
                    %6 : tensor<x1, x64, int> = tt.splat %1;
                    %7 : tensor<x1, x64, int> = arith.addi %3 %6;
                    tt.consume %7;
                    %8 : tensor<x32, x64, int> = tt.broadcast %3;
                    %9 : tensor<x32, x64, int> = tt.broadcast %5;
                    %10 : tensor<x32, x64, int> = arith.addi %8 %9;
                    tt.consume %10;
                    tt.return;
                };
                unreachable;
            };
            """)
    @CodeReflection
    static void test2(int a, @Constant int M, @Constant int N) {
        var m = arange(0, M);
        var me = expand(m, 0);

        var n = arange(0, N);
        var ne = expand(n, 1);

        // scalar + [1, M] tensor: the scalar is splatted to [1, M]
        var t4 = add(me, a);
        consume(t4);

        // [1, M] + [N, 1]: both operands are broadcast to [N, M]
        var t3 = add(me, ne);
        consume(t3);
    }

    /**
     * Specializes {@code test2} with an {@code int} and the constants 64 for
     * {@code M} and 32 for {@code N}. It is then run through
     * {@code TritonTestData.test}, which presumably compares it against the
     * expected model.
     */
    @Test
    public void test2(TritonTestData t) {
        List<TypeElement> argTypes = List.of(
                JavaType.INT,
                new ConstantType(JavaType.INT, 64),
                new ConstantType(JavaType.INT, 32)
        );

        t.test(argTypes);
    }
}