1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.test;
26
27 import hat.Accelerator;
28 import hat.ComputeContext;
29 import hat.NDRange;
30 import hat.KernelContext;
31 import hat.backend.Backend;
32 import hat.buffer.F32Array;
33 import hat.device.DeviceSchema;
34 import hat.device.NonMappableIface;
35 import optkl.ifacemapper.MappableIface;
36 import jdk.incubator.code.Reflect;
37 import hat.test.annotation.HatTest;
38 import hat.test.exceptions.HATAsserts;
39
40 import java.lang.invoke.MethodHandles;
41
42 public class TestLocal {
43
44 private interface MySharedArray extends NonMappableIface {
45 void array(long index, float value);
46 float array(long index);
47
48 DeviceSchema<MySharedArray> schema = DeviceSchema.of(MySharedArray.class,
49 builder -> builder.withArray("array", 16));
50
51 static MySharedArray create(Accelerator accelerator) {
52 return null;
53 }
54
55 static MySharedArray createLocal() {
56 return create(new Accelerator(MethodHandles.lookup(), Backend.FIRST));
57 }
58 }
59
60 @Reflect
61 private static void compute(@MappableIface.RO KernelContext kernelContext, @MappableIface.RW F32Array data) {
62 MySharedArray mySharedArray = MySharedArray.createLocal();
63 int lix = kernelContext.lix;
64 int blockId = kernelContext.bix;
65 int blockSize = kernelContext.lsx;
66 mySharedArray.array(lix, lix);
67 kernelContext.barrier();
68 data.array(lix + (long) blockId * blockSize, mySharedArray.array(lix));
69 }
70
71 @Reflect
72 private static void myCompute(@MappableIface.RO ComputeContext computeContext, @MappableIface.RW F32Array data) {
73 computeContext.dispatchKernel(NDRange.of1D(32,16),
74 kernelContext -> compute(kernelContext, data)
75 );
76 }
77
78 @HatTest
79 @Reflect
80 public void testLocal() {
81 Accelerator accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
82 F32Array data = F32Array.create(accelerator, 32);
83 accelerator.compute(computeContext -> {
84 TestLocal.myCompute(computeContext, data);
85 });
86
87 // Check result
88 boolean isCorrect = true;
89 int jIndex = 0;
90 for (int i = 0; i < data.length(); i++) {
91 if (data.array(i) != jIndex) {
92 isCorrect = false;
93 break;
94 }
95 jIndex++;
96 if (jIndex == 16) {
97 jIndex = 0;
98 }
99 }
100 HATAsserts.assertTrue(isCorrect);
101 }
102
103 }