1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.triton;
25
26 import org.junit.jupiter.api.Test;
27 import org.junit.jupiter.api.extension.ExtendWith;
28
29 import jdk.incubator.code.TypeElement;
30 import jdk.incubator.code.dialect.java.JavaType;
31 import jdk.incubator.code.CodeReflection;
32 import java.util.List;
33
34 import static oracle.code.triton.Triton.*;
35 import static oracle.code.triton.TritonTest.consume;
36
/**
 * Tests translation of {@code Triton.zeros} into the Triton code model.
 *
 * <p>Runs under {@link TritonTestExtension}, which resolves the
 * {@link TritonTestExtension.TritonTestData} argument of the {@code @Test} method.
 */
@ExtendWith(TritonTestExtension.class)
public class TestZeros {

    /**
     * Kernel under test: creates a {@code float} tensor of zeros with shape
     * {@code M x N} and passes it to {@code consume}.
     *
     * <p>The {@link TritonCodeModel} annotation holds the expected model text for
     * {@code M = 32, N = 64}. In it, the {@code zeros} call appears as a single
     * {@code arith.constant @0.0f} of type {@code tensor<x32, x64, java.type:"float">},
     * followed by {@code tt.consume} and {@code tt.return}. The function name
     * {@code test1_32_64_void} appears to encode the method name, the constant
     * arguments and the return type.
     * NOTE(review): presumably the extension compares this text against the model it
     * generates from this method; confirm in TritonTestExtension.
     *
     * <p>This method is annotated {@link CodeReflection}, so its body is the input whose
     * code model gets translated. Any edit to the body, even renaming the local, may
     * change the generated model and therefore require updating the expected text.
     *
     * @param M first tensor dimension; a {@link Constant}, given as 32 by the test below
     *          (assumed to bind positionally)
     * @param N second tensor dimension; a {@link Constant}, given as 64 by the test below
     *          (assumed to bind positionally)
     */
    @TritonCodeModel("""
            module ()java.type:"void" -> {
            tt.func @"test1_32_64_void" ()java.type:"void" -> {
            %0 : tensor<x32, x64, java.type:"float"> = arith.constant @0.0f;
            tt.consume %0;
            tt.return;
            };
            unreachable;
            };
            """)
    @CodeReflection
    static void test1(@Constant int M, @Constant int N) {
        var t = zeros(float.class, M, N);
        consume(t);
    }

    /**
     * Runs the translation test for {@link #test1(int, int)} with
     * {@code M = 32} and {@code N = 64}.
     *
     * <p>Each argument type is a {@link ConstantType} wrapping {@code int}, matching the
     * {@link Constant} parameters of the kernel.
     * NOTE(review): the kernel appears to be located because it shares this method's
     * name ({@code test1}); confirm in TritonTestExtension.
     *
     * @param t per-test data resolved by {@link TritonTestExtension}
     */
    @Test
    public void test1(TritonTestExtension.TritonTestData t) {
        // Constant argument types for the kernel: 32 for M, then 64 for N.
        List<TypeElement> argTypes = List.of(
                new ConstantType(JavaType.INT, 32),
                new ConstantType(JavaType.INT, 64));

        t.test(argTypes);
    }

}