1 /*
2 * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.test;
26
27 import hat.Accelerator;
28 import hat.ComputeContext;
29 import hat.KernelContext;
30 import hat.NDRange.Tile2D;
31 import hat.backend.Backend;
32 import hat.buffer.F16Array;
33 import hat.buffer.F32Array;
34 import hat.buffer.F32ArrayPadded;
35 import hat.test.annotation.HatTest;
36 import hat.test.exceptions.HATAssertionError;
37 import hat.test.exceptions.HATAsserts;
38 import hat.test.exceptions.HATExpectedPrecisionError;
39 import hat.types.F16;
40 import hat.types.Tensor;
41 import jdk.incubator.code.Reflect;
42
43 import java.lang.invoke.MethodHandles;
44 import java.util.Random;
45
46 import static hat.NDRange.Global2D;
47 import static hat.NDRange.Local2D;
48 import static hat.NDRange.NDRange2D;
49 import static hat.NDRange.Warp2D;
50 import static optkl.ifacemapper.MappableIface.RO;
51 import static optkl.ifacemapper.MappableIface.WO;
52
/**
 * Checks tensor (matrix-multiply-accumulate) operations in HAT.
 *
 * <p>How to run:</p>
 * <pre>{@code
 * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-cuda hat.test.TestTensors
 * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-opencl hat.test.TestTensors
 * }</pre>
 */
64 public class TestTensors {
65
    /**
     * GPU kernel computing {@code C = A * B} with tensor (WMMA-style) operations,
     * where A and B are interpreted as column-major and C is stored column-major.
     * Each warp computes one 16x16 tile of C, walking the K dimension in 16-wide
     * slices. Verified against {@link #runSequentialColMajor} in testTensor01.
     */
    @Reflect
    public static void mxmTensorsColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) {
        // WMMA fragment dimensions (M x N output tile, K-slice width).
        final int WMMA_M = 16;
        final int WMMA_N = 16;
        final int WMMA_K = 16;
        // Warp-level tile coordinates. kc.wrs is presumably the warp size, so
        // gix / wrs collapses the x-threads of one warp onto a single tile row
        // index -- TODO confirm against KernelContext documentation.
        int warpM = kc.gix / kc.wrs;
        int warpN = kc.giy;

        // Leading dimensions of A, B and C. NOTE(review): hard-coded to 1024,
        // so this kernel assumes size == 1024 (see testTensor01); revisit if
        // other matrix sizes are ever needed.
        final int lda = 1024;
        final int ldb = 1024;
        final int ldc = 1024;

        // Declare the input fragments (column-major) and the fp32 accumulator.
        Tensor tensorA = Tensor.create(Tensor.FIRST, Tensor.shape(16, 16, 16), F16.class, Tensor.ofColumnMajor());
        Tensor tensorB = Tensor.create(Tensor.SECOND, Tensor.shape(16, 16, 16), F16.class, Tensor.ofColumnMajor());
        Tensor acc = Tensor.create(Tensor.ACC, Tensor.shape(16, 16, 16), float.class);

        // Zero the accumulator before summing partial products over K.
        Tensor.fill(acc, 0.0f);

        for (int i = 0; i < size; i += WMMA_K) {
            // Top-left corner of the A and B sub-tiles consumed this iteration.
            int aRow = warpM * WMMA_M;
            int aCol = i;

            int bRow = i;
            int bCol = warpN * WMMA_N;

            // Bounds guard. NOTE(review): aCol/bRow are compared against the
            // leading dimensions; with all dims equal to 1024 this matches the
            // usual (aCol < K, bRow < K) check -- revisit if dims ever differ.
            if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {

                tensorA = Tensor.load(matrixA, aRow, aCol, lda);
                tensorB = Tensor.load(matrixB, bRow, bCol, ldb);

                // acc = tensorA * tensorB + acc
                Tensor.mma(acc, tensorA, tensorB, acc);
            }
        }
        // Write this warp's finished 16x16 tile of C, column-major.
        int cRow = warpM * WMMA_M;
        int cCol = warpN * WMMA_N;
        Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor());
    }
104
    /**
     * Compute-graph entry point for the column-major tensor matmul: builds the
     * 2D NDRange and dispatches {@link #mxmTensorsColumnMajor(KernelContext,
     * F16Array, F16Array, F32Array, int)} over it.
     */
    @Reflect
    public static void mxmTensorsColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) {
        // The total number of threads is calculated as follows:
        // [ (size / tile), (size / tile) ]
        // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. This is architecture dependent, but the
        // HAT runtime and HAT JIT compiler handle this automatically.

        // Warp2D.of(true, false): only the x-dimension uses warp operations,
        // matching the kernel's warpM = gix / wrs mapping.
        var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
                Local2D.of(128, 4),
                Tile2D.of(16, 16),
                Warp2D.of(true, false));

        cc.dispatchKernel(ndRange, kc -> mxmTensorsColumnMajor(kc, matrixA, matrixB, matrixC, globalSize));
    }
119
    /**
     * GPU kernel computing {@code C = A * B} with tensor (WMMA-style) operations,
     * where A is interpreted as row-major, B as column-major, and C is stored
     * column-major. Structurally identical to mxmTensorsColumnMajor except for
     * the layout of tensorA. Verified against
     * {@link #runSequentialRowAndColMajor} in testTensor02.
     */
    @Reflect
    public static void mxmTensorsRowColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) {
        // WMMA fragment dimensions (M x N output tile, K-slice width).
        final int WMMA_M = 16;
        final int WMMA_N = 16;
        final int WMMA_K = 16;
        // Warp-level tile coordinates (kc.wrs presumably the warp size -- TODO confirm).
        int warpM = kc.gix / kc.wrs;
        int warpN = kc.giy;

        // Leading dimensions. NOTE(review): hard-coded to 1024; assumes size == 1024.
        final int lda = 1024;
        final int ldb = 1024;
        final int ldc = 1024;

        // A is row-major here; B stays column-major. Accumulator is fp32.
        Tensor tensorA = Tensor.create(Tensor.FIRST, Tensor.shape(16, 16, 16), F16.class, Tensor.ofRowMajor());
        Tensor tensorB = Tensor.create(Tensor.SECOND, Tensor.shape(16, 16, 16), F16.class, Tensor.ofColumnMajor());
        Tensor acc = Tensor.create(Tensor.ACC, Tensor.shape(16, 16, 16), float.class);

        // Zero the accumulator before summing partial products over K.
        Tensor.fill(acc, 0.0f);

        for (int i = 0; i < size; i += WMMA_K) {
            // Top-left corner of the A and B sub-tiles consumed this iteration.
            int aRow = warpM * WMMA_M;
            int aCol = i;

            int bRow = i;
            int bCol = warpN * WMMA_N;

            if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {

                tensorA = Tensor.load(matrixA, aRow, aCol, lda);
                tensorB = Tensor.load(matrixB, bRow, bCol, ldb);

                // acc = tensorA * tensorB + acc
                Tensor.mma(acc, tensorA, tensorB, acc);
            }
        }
        // Write this warp's finished 16x16 tile of C, column-major.
        int cRow = warpM * WMMA_M;
        int cCol = warpN * WMMA_N;
        Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor());
    }
158
    /**
     * Compute-graph entry point for the mixed row/column-major tensor matmul:
     * builds the 2D NDRange and dispatches
     * {@link #mxmTensorsRowColumnMajor(KernelContext, F16Array, F16Array,
     * F32Array, int)} over it.
     */
    @Reflect
    public static void mxmTensorsRowColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) {
        // The total number of threads is calculated as follows:
        // [ (size / tile), (size / tile) ]
        // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. This is architecture dependent, but the
        // HAT runtime and HAT JIT compiler handle this automatically.

        // Warp2D.of(true, false): only the x-dimension uses warp operations.
        var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
                Local2D.of(128, 4),
                Tile2D.of(16, 16),
                Warp2D.of(true, false));

        cc.dispatchKernel(ndRange, kc -> mxmTensorsRowColumnMajor(kc, matrixA, matrixB, matrixC, globalSize));
    }
173
    /**
     * GPU kernel computing {@code C = A * B} with tensor (WMMA-style) operations,
     * where A, B and C are all row-major. The output buffer is an
     * {@link F32ArrayPadded}: per the note in testTensor03, the row-major store
     * requires a padded result buffer. Verified against
     * {@link #runSequentialRowMajor} in testTensor03.
     */
    @Reflect
    public static void mxmTensorsRowMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int size) {
        // WMMA fragment dimensions (M x N output tile, K-slice width).
        final int WMMA_M = 16;
        final int WMMA_N = 16;
        final int WMMA_K = 16;
        // Warp-level tile coordinates (kc.wrs presumably the warp size -- TODO confirm).
        int warpM = kc.gix / kc.wrs;
        int warpN = kc.giy;

        // Leading dimensions. NOTE(review): hard-coded to 1024; assumes size == 1024.
        final int lda = 1024;
        final int ldb = 1024;
        final int ldc = 1024;

        // Both input fragments are row-major; accumulator is fp32.
        Tensor tensorA = Tensor.create(Tensor.FIRST, Tensor.shape(16, 16, 16), F16.class, Tensor.ofRowMajor());
        Tensor tensorB = Tensor.create(Tensor.SECOND, Tensor.shape(16, 16, 16), F16.class, Tensor.ofRowMajor());
        Tensor acc = Tensor.create(Tensor.ACC, Tensor.shape(16, 16, 16), float.class);

        // Zero the accumulator before summing partial products over K.
        Tensor.fill(acc, 0.0f);

        for (int i = 0; i < size; i += WMMA_K) {
            // Top-left corner of the A and B sub-tiles consumed this iteration.
            int aRow = warpM * WMMA_M;
            int aCol = i;

            int bRow = i;
            int bCol = warpN * WMMA_N;

            if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {

                tensorA = Tensor.load(matrixA, aRow, aCol, lda);
                tensorB = Tensor.load(matrixB, bRow, bCol, ldb);

                // acc = tensorA * tensorB + acc
                Tensor.mma(acc, tensorA, tensorB, acc);
            }
        }
        // Write this warp's finished 16x16 tile of C, row-major.
        int cRow = warpM * WMMA_M;
        int cCol = warpN * WMMA_N;
        Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofRowMajor());
    }
212
    /**
     * Compute-graph entry point for the row-major tensor matmul: builds the 2D
     * NDRange and dispatches {@link #mxmTensorsRowMajor(KernelContext, F16Array,
     * F16Array, F32ArrayPadded, int)} over it. Same NDRange configuration as the
     * other two variants; only the x-dimension uses warp operations.
     */
    @Reflect
    public static void mxmTensorsRowMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int globalSize) {
        var ndRange = NDRange2D.of(
                Global2D.of(globalSize, globalSize),
                Local2D.of(128, 4),
                Tile2D.of(16, 16),
                Warp2D.of(true, false));
        cc.dispatchKernel(ndRange, kc -> mxmTensorsRowMajor(kc, matrixA, matrixB, matrixC, globalSize));
    }
222
223 private static void runSequentialColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
224 for (int i = 0; i < size; i++) {
225 for (int j = 0; j < size; j++) {
226 float sum = 0.0f;
227 for (int k = 0; k < size; k++) {
228 F16 a = matrixA.array((long) k * size + i);
229 F16 b = matrixB.array((long) j * size + k);
230 F16 mul = F16.mul(a, b);
231 sum += F16.f16ToFloat(mul);
232 }
233 matrixC.array((long) j * size + i, sum);
234 }
235 }
236 }
237
238 private static void runSequentialRowAndColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
239 for (int i = 0; i < size; i++) {
240 for (int j = 0; j < size; j++) {
241 float sum = 0.0f;
242 for (int k = 0; k < size; k++) {
243 F16 a = matrixA.array((long) i * size + k);
244 F16 b = matrixB.array((long) j * size + k);
245 F16 mul = F16.mul(a, b);
246 sum += F16.f16ToFloat(mul);
247 }
248 matrixC.array((long) j * size + i, sum);
249 }
250 }
251 }
252
253 private static void runSequentialRowMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
254 for (int i = 0; i < size; i++) {
255 for (int j = 0; j < size; j++) {
256 float sum = 0.0f;
257 for (int k = 0; k < size; k++) {
258 F16 a = matrixA.array((long) i * size + k);
259 F16 b = matrixB.array((long) k * size + j);
260 F16 mul = F16.mul(a, b);
261 sum += F16.f16ToFloat(mul);
262 }
263 matrixC.array((long) i * size + j, sum);
264 }
265 }
266 }
267
    /**
     * End-to-end test of the column-major tensor matmul: runs the accelerated
     * kernel on 1024x1024 fp16 inputs and compares every element of the result
     * against the sequential CPU reference.
     */
    @HatTest
    @Reflect
    public void testTensor01() {
        var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
        // Must stay 1024: the kernel hard-codes lda/ldb/ldc = 1024.
        final int size = 1024;

        F16Array matrixAHalf = F16Array.create(accelerator, size * size);
        F16Array matrixBHalf = F16Array.create(accelerator, size * size);
        F32Array matrixC = F32Array.create(accelerator, size * size);
        F32Array resultSequential = F32Array.create(accelerator, size * size);

        // Fixed seed so failures are reproducible.
        Random r = new Random(19);
        for (int j = 0; j < matrixAHalf.length(); j++) {
            matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
            matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
        }

        // Run the accelerated compute several times; matrixC is write-only and
        // overwritten on each dispatch, so only the final result is checked.
        // Presumably repeated to exercise repeated dispatch -- TODO confirm.
        for (int i = 0; i < 10; i++) {
            accelerator.compute(cc -> mxmTensorsColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
        }

        runSequentialColMajor(matrixAHalf, matrixBHalf, resultSequential, size);

        // Element-wise comparison; index j * size + i is the column-major
        // linearization of element (i, j), the layout both buffers use.
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                final int index = j * size + i;
                float expectedValue = resultSequential.array(index);
                float gotValue = matrixC.array(index);
                try {
                    // Wide 0.1f tolerance: products are computed in fp16, so
                    // rounding differs between GPU and CPU paths.
                    HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
                } catch (HATAssertionError e) {
                    throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
                }
            }
        }
    }
304
    /**
     * End-to-end test of the mixed row/column-major tensor matmul (A row-major,
     * B column-major, C column-major): runs the accelerated kernel on 1024x1024
     * fp16 inputs and compares every element against the sequential CPU reference.
     */
    @HatTest
    @Reflect
    public void testTensor02() {
        var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
        // Must stay 1024: the kernel hard-codes lda/ldb/ldc = 1024.
        final int size = 1024;

        F16Array matrixAHalf = F16Array.create(accelerator, size * size);
        F16Array matrixBHalf = F16Array.create(accelerator, size * size);
        F32Array matrixC = F32Array.create(accelerator, size * size);
        F32Array resultSequential = F32Array.create(accelerator, size * size);

        // Fixed seed so failures are reproducible.
        Random r = new Random(19);
        for (int j = 0; j < matrixAHalf.length(); j++) {
            matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
            matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
        }

        // matrixC is write-only and overwritten on each dispatch; only the
        // final result is checked.
        for (int i = 0; i < 10; i++) {
            accelerator.compute(cc -> mxmTensorsRowColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
        }

        runSequentialRowAndColMajor(matrixAHalf, matrixBHalf, resultSequential, size);

        // Element-wise comparison; index j * size + i is the column-major
        // linearization of element (i, j), the layout both result buffers use.
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                final int index = j * size + i;
                float expectedValue = resultSequential.array(index);
                float gotValue = matrixC.array(index);
                try {
                    // Wide 0.1f tolerance: fp16 products round differently on
                    // the GPU and CPU paths.
                    HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
                } catch (HATAssertionError e) {
                    throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
                }
            }
        }
    }
341
    /**
     * End-to-end test of the all-row-major tensor matmul: runs the accelerated
     * kernel on 1024x1024 fp16 inputs and compares every element against the
     * sequential CPU reference.
     */
    @HatTest
    @Reflect
    public void testTensor03() {

        // To be able to run tensor-matmul in a row-major layout, we need to add padding.
        // Thus, the result matrix must be of type F32ArrayPadded.

        var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
        // Must stay 1024: the kernel hard-codes lda/ldb/ldc = 1024.
        final int size = 1024;

        F16Array matrixAHalf = F16Array.create(accelerator, size * size);
        F16Array matrixBHalf = F16Array.create(accelerator, size * size);
        F32ArrayPadded matrixC = F32ArrayPadded.create(accelerator, size * size);
        F32Array resultSequential = F32Array.create(accelerator, size * size);

        // Fixed seed so failures are reproducible.
        Random r = new Random(19);
        for (int j = 0; j < matrixAHalf.length(); j++) {
            matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
            matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
        }

        // matrixC is write-only and overwritten on each dispatch; only the
        // final result is checked.
        for (int i = 0; i < 10; i++) {
            accelerator.compute(cc -> mxmTensorsRowMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
        }

        runSequentialRowMajor(matrixAHalf, matrixBHalf, resultSequential, size);

        // Element-wise comparison over the full matrix. Both buffers are read
        // with the same linear index, so every element is visited once.
        // NOTE(review): this assumes F32ArrayPadded.array(int) resolves the
        // same logical layout as the unpadded reference buffer -- confirm
        // against the F32ArrayPadded accessor.
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                final int index = j * size + i;
                float expectedValue = resultSequential.array(index);
                float gotValue = matrixC.array(index);
                try {
                    // Wide 0.1f tolerance: fp16 products round differently on
                    // the GPU and CPU paths.
                    HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
                } catch (HATAssertionError e) {
                    throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
                }
            }
        }
    }
382 }