1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package experiments;
26
27 import hat.Accelerator;
28 import hat.Accelerator.Compute;
29 import hat.ComputeContext;
30 import hat.NDRange;
31 import hat.KernelContext;
32 import hat.backend.Backend;
33 import hat.buffer.S32Array;
34 import optkl.ifacemapper.MappableIface.RO;
35 import optkl.ifacemapper.MappableIface.RW;
36 import jdk.incubator.code.Reflect;
37
38 import java.lang.invoke.MethodHandles;
39 import java.util.stream.IntStream;
40
41 /**
42 * How to test?
43 * <code>
44 * HAT=SHOW_CODE java -cp job.jar hat.java exp ffi-opencl LocalIds
45 * </code>
46 */
47 public class LocalIds {
48
49 private static boolean PRINT_RESULTS = false;
50
51 @Reflect
52 private static void assign(@RO KernelContext context, @RW S32Array arrayA, @RW S32Array arrayB, @RW S32Array arrayC) {
53 int gx = context.gix;
54 int lx = context.lix;
55 int lsx = context.lsx;
56 int bix = context.bix;
57 arrayA.array(gx, lx);
58 arrayB.array(gx, lsx);
59 arrayC.array(gx, bix);
60 }
61
62 private static final int BLOCK_SIZE = 16;
63
64 @Reflect
65 private static void mySimpleCompute(@RO ComputeContext cc, @RW S32Array arrayA, @RW S32Array arrayB, @RW S32Array arrayC) {
66 cc.dispatchKernel(NDRange.of1D(32,BLOCK_SIZE), kc -> assign(kc, arrayA, arrayB, arrayC));
67 }
68
69 public static void main(String[] args) {
70 System.out.println("Experiment: local IDs and local groups");
71
72 Accelerator accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
73 final int size = 32;
74 S32Array arrayA = S32Array.create(accelerator, size);
75 S32Array arrayB = S32Array.create(accelerator, size);
76 S32Array arrayC = S32Array.create(accelerator, size);
77
78 // Set initial value to 0
79 arrayA.fill(i -> 0);
80 arrayB.fill(i -> 0);
81 arrayC.fill(i -> 0);
82
83 // Compute on the accelerator
84 accelerator.compute((@Reflect Compute)
85 cc -> LocalIds.mySimpleCompute(cc, arrayA, arrayB, arrayC));
86
87 int[] expectedIds = new int[size];
88 int j = 0;
89 for (int i = 0; i < size; i++) {
90 expectedIds[i] = j++;
91 if (j == BLOCK_SIZE) {
92 j = 0;
93 }
94 }
95
96 System.out.println("Execution finished");
97
98 if (PRINT_RESULTS) {
99 System.out.println("Result Locals: ");
100 for (int i = 0; i < arrayA.length(); i++) {
101 System.out.println(arrayA.array(i));
102 }
103 System.out.println("Result Blocks: ");
104 for (int i = 0; i < arrayB.length(); i++) {
105 System.out.println(arrayB.array(i));
106 }
107 System.out.println("Result Block ID: ");
108 for (int i = 0; i < arrayC.length(); i++) {
109 System.out.println(arrayC.array(i));
110 }
111 }
112
113 boolean correct = true;
114 for (int i = 0; i < arrayA.length(); i++) {
115 if (expectedIds[i] != arrayA.array(i)) {
116 System.out.println("Mismatch local ids");
117 correct = false;
118 }
119 }
120 if (correct) {
121 System.out.println("Local IDs are correct");
122 }
123
124
125 correct = true;
126 for (int i = 0; i < arrayB.length(); i++) {
127 if (BLOCK_SIZE != arrayB.array(i)) {
128 System.out.println("Mismatch group Sizes");
129 correct = false;
130 }
131 }
132 if (correct) {
133 System.out.println("Group Size are correct");
134 }
135
136 IntStream.range(0, size).forEach(i -> {
137 int v = i < BLOCK_SIZE ? 0 : 1;
138 expectedIds[i] = v;
139 });
140 for (int i = 0; i < arrayC.length(); i++) {
141 if (expectedIds[i] != arrayC.array(i)) {
142 System.out.println("Mismatch group IDs");
143 correct = false;
144 }
145 }
146 if (correct) {
147 System.out.println("Group IDs are correct");
148 }
149 }
150
151 }