1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.triton;
 25 
 26 import oracle.code.triton.TritonTestExtension.TritonTestData;
 27 import org.junit.jupiter.api.Test;
 28 import org.junit.jupiter.api.extension.ExtendWith;
 29 
 30 import java.lang.reflect.code.TypeElement;
 31 import java.lang.reflect.code.type.JavaType;
 32 import java.lang.runtime.CodeReflection;
 33 import java.util.List;
 34 
 35 import static oracle.code.triton.Triton.*;
 36 import static oracle.code.triton.TritonTest.consume;
 37 
 38 @ExtendWith(TritonTestExtension.class)
 39 public class TestBroadcast {
 40 
 41     @TritonCodeModel("""
 42             module ()void -> {
 43                 tt.func @"test1_ptr<int>_int_64_void" (%0 : ptr<int>, %1 : int)void -> {
 44                     %2 : tensor<x64, int> = tt.make_range @start="0" @end="64";
 45                     %3 : tensor<x64, ptr<int>> = tt.splat %0;
 46                     %4 : tensor<x64, ptr<int>> = tt.addptr %3 %2;
 47                     tt.consume %4;
 48                     %5 : tensor<x64, int> = tt.splat %1;
 49                     %6 : tensor<x64, int> = arith.addi %5 %2;
 50                     tt.consume %6;
 51                     %7 : tensor<x64, int> = tt.splat %1;
 52                     %8 : tensor<x64, int> = arith.addi %2 %7;
 53                     tt.consume %8;
 54                     %9 : tensor<x64, int> = tt.splat %1;
 55                     %10 : tensor<x64, int> = arith.addi %9 %2;
 56                     tt.consume %10;
 57                     %11 : tensor<x64, int> = tt.splat %1;
 58                     %12 : tensor<x64, int> = arith.addi %2 %11;
 59                     tt.consume %12;
 60                     tt.return;
 61                 };
 62                 unreachable;
 63             };
 64             """)
 65     @CodeReflection
 66     static void test1(Ptr ptr, int a, @Constant int s) {
 67         var t = arange(0, s);
 68         consume(add(ptr, t));
 69         consume(add(a, t));
 70         consume(add(t, a));
 71         consume(add(broadcast(a, t.type()), t));
 72         consume(add(t, broadcast(a, t.type())));
 73     }
 74 
 75     @Test
 76     public void test1(TritonTestData t) {
 77         List<TypeElement> argTypes = List.of(
 78                 new PtrType(JavaType.INT),
 79                 JavaType.INT,
 80                 new ConstantType(JavaType.INT, 64));
 81 
 82         t.test(argTypes);
 83     }
 84 
 85     @TritonCodeModel("""
 86             module ()void -> {
 87                 tt.func @"test2_int_64_32_void" (%1 : int)void -> {
 88                     %2 : tensor<x64, int> = tt.make_range @start="0" @end="64";
 89                     %3 : tensor<x1, x64, int> = tt.expand_dims %2 @"0";
 90                     %4 : tensor<x32, int> = tt.make_range @start="0" @end="32";
 91                     %5 : tensor<x32, x1, int> = tt.expand_dims %4 @"1";
 92                     %6 : tensor<x1, x64, int> = tt.splat %1;
 93                     %7 : tensor<x1, x64, int> = arith.addi %3 %6;
 94                     tt.consume %7;
 95                     %8 : tensor<x32, x64, int> = tt.broadcast %3;
 96                     %9 : tensor<x32, x64, int> = tt.broadcast %5;
 97                     %10 : tensor<x32, x64, int> = arith.addi %8 %9;
 98                     tt.consume %10;
 99                     tt.return;
100                 };
101                 unreachable;
102             };
103             """)
104     @CodeReflection
105     static void test2(int a, @Constant int M, @Constant int N) {
106         var m = arange(0, M);
107         var me = expand(m, 0);
108 
109         var n = arange(0, N);
110         var ne = expand(n, 1);
111 
112         var t4 = add(me, a);
113         consume(t4);
114 
115         var t3 = add(me, ne);
116         consume(t3);
117     }
118 
119     @Test
120     public void test2(TritonTestData t) {
121         List<TypeElement> argTypes = List.of(
122                 JavaType.INT,
123                 new ConstantType(JavaType.INT, 64),
124                 new ConstantType(JavaType.INT, 32)
125         );
126 
127         t.test(argTypes);
128     }
129 }