1 /*
 2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 4  *
 5  * This code is free software; you can redistribute it and/or modify it
 6  * under the terms of the GNU General Public License version 2 only, as
 7  * published by the Free Software Foundation.
 8  *
 9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 package oracle.code.triton;
25 
26 import oracle.code.triton.TritonTestExtension.TritonTestData;
27 import org.junit.jupiter.api.Test;
28 import org.junit.jupiter.api.extension.ExtendWith;
29 
30 import java.lang.reflect.code.TypeElement;
31 import java.lang.reflect.code.type.JavaType;
32 import java.lang.runtime.CodeReflection;
33 import java.util.List;
34 
35 import static oracle.code.triton.Triton.*;
36 import static oracle.code.triton.TritonTest.consume;
37 
38 @ExtendWith(TritonTestExtension.class)
39 public class TestCountedLoop {
40 
41     @TritonCodeModel(value = """
42             module ()void -> {
43                 tt.func @"test1_int_64_void" (%0 : int)void -> {
44                     %1 : int = arith.constant @"64";
45                     %2 : tensor<x64, int> = tt.make_range @start="0" @end="64";
46                     %3 : tensor<x64, int> = tt.make_range @start="0" @end="64";
47                     %4 : int = arith.constant @"0";
48                     %5 : int = arith.constant @"1";
49                     %6 : Tuple<tensor<x64, int>, tensor<x64,int>> = scf.for %4 %0 %5 %2 %3 (%7 : int, %8 : tensor<x64, int>, %9 : tensor<x64, int>)Tuple<tensor<x64, int>, tensor<x64, int>> -> {
50                         %10 : tensor<x64, int> = tt.splat %7;
51                         %11 : tensor<x64, int> = arith.addi %8 %10;
52                         %12 : tensor<x64, int> = tt.splat %1;
53                         %13 : tensor<x64, int> = arith.addi %9 %12;
54                         scf.yield %11 %13;
55                     };
56                     %14 : tensor<x64, int> = tuple.load %6 @"0";
57                     %15 : tensor<x64, int> = tuple.load %6 @"1";
58                     tt.consume %14;
59                     tt.consume %15;
60                     tt.return;
61                 };
62                 unreachable;
63             };
64             """)
65     @CodeReflection
66     static void test1(int n, @Constant int s) {
67         var a = arange(0, s);
68         var b = arange(0, s);
69         for (int i = 0; i < n; i++) {
70             a = Triton.add(a, i);
71             b = Triton.add(b, s);
72         }
73         consume(a);
74         consume(b);
75     }
76 
77     @Test
78     public void test1(TritonTestData t) {
79         List<TypeElement> argTypes = List.of(
80                 JavaType.INT,
81                 new ConstantType(JavaType.INT, 64));
82 
83         t.test(argTypes);
84     }
85 }