1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.triton;
25
26 import oracle.code.triton.TritonTestExtension.TritonTestData;
27 import org.junit.jupiter.api.Test;
28 import org.junit.jupiter.api.extension.ExtendWith;
29
30 import jdk.incubator.code.TypeElement;
31 import jdk.incubator.code.dialect.java.JavaType;
32 import jdk.incubator.code.CodeReflection;
33 import java.util.List;
34
35 import static oracle.code.triton.Triton.*;
36 import static oracle.code.triton.TritonTest.consume;
37
/**
 * JUnit test for a Triton kernel that contains a counted {@code for} loop.
 *
 * <p>The class holds one reflected kernel, {@code test1(int, int)}, annotated with a
 * textual code model, and one JUnit method, {@code test1(TritonTestData)}, that
 * supplies the argument types for that kernel.
 *
 * <p>NOTE(review): presumably {@code TritonTestExtension} compares the model built
 * from the kernel with the text held in the annotation; confirm in the extension.
 */
38 @ExtendWith(TritonTestExtension.class)
39 public class TestCountedLoop {
40
/**
 * Kernel under test. Builds two ranges with {@code arange(0, s)}, then loops with
 * {@code i} starting at 0 while {@code i < n}: each iteration adds the loop index
 * to {@code a} and the constant {@code s} to {@code b}. Both tensors are consumed
 * after the loop.
 *
 * <p>What the model text in the annotation shows, with {@code s} fixed to 64:
 * <ul>
 * <li>the function is named test1_int_64_void and takes only the int parameter;</li>
 * <li>both ranges are tt.make_range from 0 to 64, typed as tensors of 64 ints;</li>
 * <li>the loop is a single scf.for whose operands are the constant 0, the int
 *     parameter, the constant 1 and the two ranges, which line up with the loop
 *     start, bound and increment in the Java source plus the two carried tensors;</li>
 * <li>the loop index and the constant 64 are each turned into a tensor with
 *     tt.splat before arith.addi;</li>
 * <li>the loop yields both updated tensors, the result is a Tuple, and the two
 *     tensors are read back with tuple.load at index 0 and index 1.</li>
 * </ul>
 *
 * @param n loop bound; stays a runtime int parameter in the model text
 * @param s range end and per-iteration addend; marked as Constant and shown as
 *          the constant 64 in the model text
 */
41 @TritonCodeModel(value = """
42 module ()java.type:"void" -> {
43 tt.func @"test1_int_64_void" (%0 : java.type:"int")java.type:"void" -> {
44 %1 : java.type:"int" = arith.constant @64;
45 %2 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
46 %3 : tensor<x64, java.type:"int"> = tt.make_range @start=0 @end=64;
47 %4 : java.type:"int" = arith.constant @0;
48 %5 : java.type:"int" = arith.constant @1;
49 %6 : Tuple<tensor<x64, java.type:"int">, tensor<x64, java.type:"int">> = scf.for %4 %0 %5 %2 %3 (%7 : java.type:"int", %8 : tensor<x64, java.type:"int">, %9 : tensor<x64, java.type:"int">)Tuple<tensor<x64, java.type:"int">, tensor<x64, java.type:"int">> -> {
50 %10 : tensor<x64, java.type:"int"> = tt.splat %7;
51 %11 : tensor<x64, java.type:"int"> = arith.addi %8 %10;
52 %12 : tensor<x64, java.type:"int"> = tt.splat %1;
53 %13 : tensor<x64, java.type:"int"> = arith.addi %9 %12;
54 scf.yield %11 %13;
55 };
56 %14 : tensor<x64, java.type:"int"> = tuple.load %6 @0;
57 %15 : tensor<x64, java.type:"int"> = tuple.load %6 @1;
58 tt.consume %14;
59 tt.consume %15;
60 tt.return;
61 };
62 unreachable;
63 };
64 """)
65 @CodeReflection
66 static void test1(int n, @Constant int s) {
// Two separately created ranges, so the loop below carries two tensor values.
67 var a = arange(0, s);
68 var b = arange(0, s);
// Both a and b are reassigned on every iteration; in the model text above they
// show up as the two tensor block parameters of scf.for.
69 for (int i = 0; i < n; i++) {
// Tensor plus the scalar loop index.
70 a = Triton.add(a, i);
// Tensor plus the constant s.
71 b = Triton.add(b, s);
72 }
// Consume both results after the loop (tt.consume in the model text).
73 consume(a);
74 consume(b);
75 }
76
/**
 * Runs the check for the {@code test1} kernel with an int first argument and the
 * second argument fixed to the constant 64, matching the 64 seen throughout the
 * annotation text.
 *
 * <p>NOTE(review): what {@code t.test} actually verifies is defined outside this
 * file; confirm against {@code TritonTestData}.
 *
 * @param t test data handed to this method; presumably supplied by
 *          {@code TritonTestExtension}, to be confirmed
 */
77 @Test
78 public void test1(TritonTestData t) {
// One entry per kernel parameter, in declaration order: n, then s.
79 List<TypeElement> argTypes = List.of(
80 JavaType.INT,
81 new ConstantType(JavaType.INT, 64));
82
83 t.test(argTypes);
84 }
85 }