1 /*
2 * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.test;
26
27 import hat.Accelerator;
28 import hat.ComputeContext;
29 import hat.KernelContext;
30 import hat.NDRange.Tile2D;
31 import hat.backend.Backend;
32 import hat.buffer.F16Array;
33 import hat.buffer.F32Array;
34 import hat.buffer.F32ArrayPadded;
35 import hat.test.annotation.HatTest;
36 import hat.test.exceptions.HATAssertionError;
37 import hat.test.exceptions.HATAsserts;
38 import hat.test.exceptions.HATExpectedPrecisionError;
39 import hat.types.F16;
40 import hat.types.Tensor;
41 import jdk.incubator.code.Reflect;
42
43 import java.lang.invoke.MethodHandles;
44 import java.util.Random;
45
46 import static hat.NDRange.Global2D;
47 import static hat.NDRange.Local2D;
48 import static hat.NDRange.NDRange2D;
49 import static hat.NDRange.Warp2D;
50 import static optkl.ifacemapper.MappableIface.RO;
51 import static optkl.ifacemapper.MappableIface.WO;
52
53 /**
54 * Test tensor operations in HAT.
55 *
56 * <p>How to run?</p>
57 * <p>For the CUDA backend:
58 * <code>
59 * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-cuda hat.test.TestTensors
60 * </code>
61 * </p>
62 *
63 * <p>For the OpenCL backend:
64 * <code>
65 * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-opencl hat.test.TestTensors
66 * </code>
67 * </p>
68 *
69 */
70 public class TestTensors {
71
72 @Reflect
73 public static void mxmTensorsColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) {
74 final int SHAPE = 16;
75 final int WMMA_M = SHAPE;
76 final int WMMA_N = SHAPE;
77 final int WMMA_K = SHAPE;
78 int warpM = kc.gix / kc.wrs;
79 int warpN = kc.giy;
80
81 final int lda = 1024;
82 final int ldb = 1024;
83 final int ldc = 1024;
84
85 var shape = Tensor.shape(WMMA_M, WMMA_N, WMMA_K);
86
87 // Initialize a tensor accumulator with zeros
88 Tensor acc = Tensor.zeros(shape, float.class);
89
90 for (int i = 0; i < size; i += WMMA_K) {
91 int aRow = warpM * WMMA_M;
92 int aCol = i;
93 int bRow = i;
94 int bCol = warpN * WMMA_N;
95
96 if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {
97 // Load data from matrix A with the specified shape using column-major into a tensor of FP16
98 Tensor tensorA = Tensor.loadF16(matrixA, aRow, aCol, lda, shape, Tensor.ofColumnMajor());
99
100 // Load data from matrix B with the specified shape using column-major into a tensor of FP16
101 Tensor tensorB = Tensor.loadF16(matrixB, bRow, bCol, ldb, shape, Tensor.ofColumnMajor());
102
103 // Perform the MMA operation:
104 // acc = tensorA * tensorB + acc
105 acc = Tensor.mma(tensorA, tensorB, acc);
106 }
107 }
108 int cRow = warpM * WMMA_M;
109 int cCol = warpN * WMMA_N;
110
111 // Store the resulting tensor into main memory using column-major layout.
112 if (cRow < size && cCol < size) {
113 // We operate with square matrices
114 Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor());
115 }
116 }
117
118 @Reflect
119 public static void mxmTensorsColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) {
120 // The total number of threads is calculated as follows:
121 // [ (size / tile), (size / tile) ]
122 // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. This is architecture dependent, but the
123 // HAT runtime and HAT JIT compiler handle this automatically.
124
125 var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
126 Local2D.of(128, 4),
127 Tile2D.of(16, 16),
128 Warp2D.of(true, false));
129
130 cc.dispatchKernel(ndRange, kc -> mxmTensorsColumnMajor(kc, matrixA, matrixB, matrixC, globalSize));
131 }
132
133 @Reflect
134 public static void mxmTensorsRowColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) {
135
136 final int WMMA_M = 16;
137 final int WMMA_N = 16;
138 final int WMMA_K = 16;
139 int warpM = kc.gix / kc.wrs;
140 int warpN = kc.giy;
141
142 final int lda = 1024;
143 final int ldb = 1024;
144 final int ldc = 1024;
145
146 // We keep explicit constant in this version to check shape with ConstantOp
147 Tensor acc = Tensor.create(Tensor.shape(16, 16, 16), float.class);
148
149 Tensor.fill(acc, 0.0f);
150
151 for (int i = 0; i < size; i += WMMA_K) {
152 int aRow = warpM * WMMA_M;
153 int aCol = i;
154 int bRow = i;
155 int bCol = warpN * WMMA_N;
156 if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {
157 Tensor tensorA = Tensor.loadF16(matrixA, aRow, aCol, lda, Tensor.shape(16, 16, 16), Tensor.ofRowMajor());
158 Tensor tensorB = Tensor.loadF16(matrixB, bRow, bCol, ldb, Tensor.shape(16, 16, 16),Tensor.ofColumnMajor());
159 acc = Tensor.mma(tensorA, tensorB, acc);
160 }
161 }
162 int cRow = warpM * WMMA_M;
163 int cCol = warpN * WMMA_N;
164 if (cRow < size && cCol < size) {
165 // We operate with square matrices
166 Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor());
167 }
168 }
169
170 @Reflect
171 public static void mxmTensorsRowColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) {
172 // The total number of threads is calculated as follows:
173 // [ (size / tile), (size / tile) ]
174 // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. This is architecture dependent, but the
175 // HAT runtime and HAT JIT compiler handle this automatically.
176
177 var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
178 Local2D.of(128, 4),
179 Tile2D.of(16, 16),
180 Warp2D.of(true, false));
181
182 cc.dispatchKernel(ndRange, kc -> mxmTensorsRowColumnMajor(kc, matrixA, matrixB, matrixC, globalSize));
183 }
184
185 @Reflect
186 public static void mxmTensorsRowMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int size) {
187 final int WMMA_M = 16;
188 final int WMMA_N = 16;
189 final int WMMA_K = 16;
190 int warpM = kc.gix / kc.wrs;
191 int warpN = kc.giy;
192
193 final int lda = 1024;
194 final int ldb = 1024;
195 final int ldc = 1024;
196
197 Tensor acc = Tensor.create(Tensor.shape(16, 16, 16), float.class);
198
199 Tensor.fill(acc, 0.0f);
200
201 for (int i = 0; i < size; i += WMMA_K) {
202 int aRow = warpM * WMMA_M;
203 int aCol = i;
204
205 int bRow = i;
206 int bCol = warpN * WMMA_N;
207
208 if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {
209
210 Tensor tensorA = Tensor.loadF16(matrixA, aRow, aCol, lda, Tensor.shape(16, 16, 16), Tensor.ofRowMajor());
211 Tensor tensorB = Tensor.loadF16(matrixB, bRow, bCol, ldb, Tensor.shape(16, 16, 16), Tensor.ofRowMajor());
212
213 // acc = tensorA * tensorB + acc
214 acc = Tensor.mma(tensorA, tensorB, acc);
215 }
216 }
217 int cRow = warpM * WMMA_M;
218 int cCol = warpN * WMMA_N;
219 if (cRow < size && cCol < size) {
220 Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofRowMajor());
221 }
222 }
223
224 @Reflect
225 public static void mxmTensorsRowMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int globalSize) {
226 var ndRange = NDRange2D.of(
227 Global2D.of(globalSize, globalSize),
228 Local2D.of(128, 4),
229 Tile2D.of(16, 16),
230 Warp2D.of(true, false));
231 cc.dispatchKernel(ndRange, kc -> mxmTensorsRowMajor(kc, matrixA, matrixB, matrixC, globalSize));
232 }
233
234 @Reflect
235 public static void mxmTensorsDefaultAccess(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int size) {
236 final int sizeShape = 16;
237 final int WMMA_M = sizeShape;
238 final int WMMA_N = sizeShape;
239 final int WMMA_K = sizeShape;
240 int warpM = kc.gix / kc.wrs;
241 int warpN = kc.giy;
242
243 final int lda = 1024;
244 final int ldb = 1024;
245 final int ldc = 1024;
246
247 var shape = Tensor.shape(sizeShape, sizeShape, sizeShape);
248 Tensor acc = Tensor.zeros(shape, float.class);
249 for (int i = 0; i < size; i += WMMA_K) {
250 int aRow = warpM * WMMA_M;
251 int bCol = warpN * WMMA_N;
252 if (aRow < lda && i < lda && i < ldb && bCol < ldb) {
253 Tensor tensorA = Tensor.loadF16(matrixA, aRow, i, lda, shape);
254 Tensor tensorB = Tensor.loadF16(matrixB, i, bCol, ldb, shape);
255 acc = Tensor.mma(tensorA, tensorB, acc);
256 }
257 }
258 int cRow = warpM * WMMA_M;
259 int cCol = warpN * WMMA_N;
260 if (cRow < size && cCol < size) {
261 Tensor.store(matrixC, cRow, cCol, acc, ldc);
262 }
263 }
264
265 @Reflect
266 public static void mxmTensorsDefaultAccess(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int globalSize) {
267 var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
268 Local2D.of(128, 4),
269 Tile2D.of(16, 16),
270 Warp2D.of(true, false));
271 cc.dispatchKernel(ndRange, kc -> mxmTensorsDefaultAccess(kc, matrixA, matrixB, matrixC, globalSize));
272 }
273
274 private static void runSequentialColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
275 for (int i = 0; i < size; i++) {
276 for (int j = 0; j < size; j++) {
277 float sum = 0.0f;
278 for (int k = 0; k < size; k++) {
279 F16 a = matrixA.array((long) k * size + i);
280 F16 b = matrixB.array((long) j * size + k);
281 F16 mul = F16.mul(a, b);
282 sum += F16.f16ToFloat(mul);
283 }
284 matrixC.array((long) j * size + i, sum);
285 }
286 }
287 }
288
289 private static void runSequentialRowAndColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
290 for (int i = 0; i < size; i++) {
291 for (int j = 0; j < size; j++) {
292 float sum = 0.0f;
293 for (int k = 0; k < size; k++) {
294 F16 a = matrixA.array((long) i * size + k);
295 F16 b = matrixB.array((long) j * size + k);
296 F16 mul = F16.mul(a, b);
297 sum += F16.f16ToFloat(mul);
298 }
299 matrixC.array((long) j * size + i, sum);
300 }
301 }
302 }
303
304 private static void runSequentialRowMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
305 for (int i = 0; i < size; i++) {
306 for (int j = 0; j < size; j++) {
307 float sum = 0.0f;
308 for (int k = 0; k < size; k++) {
309 F16 a = matrixA.array((long) i * size + k);
310 F16 b = matrixB.array((long) k * size + j);
311 F16 mul = F16.mul(a, b);
312 sum += F16.f16ToFloat(mul);
313 }
314 matrixC.array((long) i * size + j, sum);
315 }
316 }
317 }
318
319 @HatTest
320 @Reflect
321 public void testTensor01() {
322 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
323 final int size = 1024;
324
325 F16Array matrixAHalf = F16Array.create(accelerator, size * size);
326 F16Array matrixBHalf = F16Array.create(accelerator, size * size);
327 F32Array matrixC = F32Array.create(accelerator, size * size);
328 F32Array resultSequential = F32Array.create(accelerator, size * size);
329
330 Random r = new Random(19);
331 for (int j = 0; j < matrixAHalf.length(); j++) {
332 matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
333 matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
334 }
335
336 // Run multiple time
337 for (int i = 0; i < 10; i++) {
338 accelerator.compute(cc -> mxmTensorsColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
339 }
340
341 runSequentialColMajor(matrixAHalf, matrixBHalf, resultSequential, size);
342
343 for (int i = 0; i < size; i++) {
344 for (int j = 0; j < size; j++) {
345 final int index = j * size + i;
346 float expectedValue = resultSequential.array(index);
347 float gotValue = matrixC.array(index);
348 try {
349 HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
350 } catch (HATAssertionError e) {
351 throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
352 }
353 }
354 }
355 }
356
357 @HatTest
358 @Reflect
359 public void testTensor02() {
360 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
361 final int size = 1024;
362
363 F16Array matrixAHalf = F16Array.create(accelerator, size * size);
364 F16Array matrixBHalf = F16Array.create(accelerator, size * size);
365 F32Array matrixC = F32Array.create(accelerator, size * size);
366 F32Array resultSequential = F32Array.create(accelerator, size * size);
367
368 Random r = new Random(19);
369 for (int j = 0; j < matrixAHalf.length(); j++) {
370 matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
371 matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
372 }
373
374 accelerator.compute(cc -> mxmTensorsRowColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
375
376 runSequentialRowAndColMajor(matrixAHalf, matrixBHalf, resultSequential, size);
377
378 for (int i = 0; i < size; i++) {
379 for (int j = 0; j < size; j++) {
380 final int index = j * size + i;
381 float expectedValue = resultSequential.array(index);
382 float gotValue = matrixC.array(index);
383 try {
384 HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
385 } catch (HATAssertionError e) {
386 throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
387 }
388 }
389 }
390 }
391
392 @HatTest
393 @Reflect
394 public void testTensor03() {
395
396 // To be able to run tensor-matmul in a row-major layout, we need to add padding.
397 // Thus, the result matrix must be of type F32ArrayPadded.
398
399 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
400 final int size = 1024;
401
402 F16Array matrixAHalf = F16Array.create(accelerator, size * size);
403 F16Array matrixBHalf = F16Array.create(accelerator, size * size);
404 F32ArrayPadded matrixC = F32ArrayPadded.create(accelerator, size * size);
405 F32Array resultSequential = F32Array.create(accelerator, size * size);
406
407 Random r = new Random(19);
408 for (int j = 0; j < matrixAHalf.length(); j++) {
409 matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
410 matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
411 }
412 accelerator.compute(cc -> mxmTensorsRowMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
413 runSequentialRowMajor(matrixAHalf, matrixBHalf, resultSequential, size);
414
415 for (int i = 0; i < size; i++) {
416 for (int j = 0; j < size; j++) {
417 final int index = j * size + i;
418 float expectedValue = resultSequential.array(index);
419 float gotValue = matrixC.array(index);
420 try {
421 HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
422 } catch (HATAssertionError e) {
423 throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
424 }
425 }
426 }
427 }
428
429 @HatTest
430 @Reflect
431 public void testTensor04() {
432
433 // To be able to run tensor-matmul in a row-major layout, we need to add padding.
434 // Thus, the result matrix must be of type F32ArrayPadded.
435
436 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
437 final int size = 1024;
438
439 F16Array matrixAHalf = F16Array.create(accelerator, size * size);
440 F16Array matrixBHalf = F16Array.create(accelerator, size * size);
441 F32ArrayPadded matrixC = F32ArrayPadded.create(accelerator, size * size);
442 F32Array resultSequential = F32Array.create(accelerator, size * size);
443
444 Random r = new Random(19);
445 for (int j = 0; j < matrixAHalf.length(); j++) {
446 matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
447 matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
448 }
449
450 accelerator.compute(cc -> mxmTensorsDefaultAccess(cc, matrixAHalf, matrixBHalf, matrixC, size));
451
452 runSequentialRowMajor(matrixAHalf, matrixBHalf, resultSequential, size);
453
454 for (int i = 0; i < size; i++) {
455 for (int j = 0; j < size; j++) {
456 final int index = j * size + i;
457 float expectedValue = resultSequential.array(index);
458 float gotValue = matrixC.array(index);
459 try {
460 HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
461 } catch (HATAssertionError e) {
462 throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
463 }
464 }
465 }
466 }
467 }