1 /*
  2  * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.test;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.KernelContext;
 30 import hat.NDRange.Tile2D;
 31 import hat.backend.Backend;
 32 import hat.buffer.F16Array;
 33 import hat.buffer.F32Array;
 34 import hat.buffer.F32ArrayPadded;
 35 import hat.test.annotation.HatTest;
 36 import hat.test.exceptions.HATAssertionError;
 37 import hat.test.exceptions.HATAsserts;
 38 import hat.test.exceptions.HATExpectedPrecisionError;
 39 import hat.types.F16;
 40 import hat.types.Tensor;
 41 import jdk.incubator.code.Reflect;
 42 
 43 import java.lang.invoke.MethodHandles;
 44 import java.util.Random;
 45 
 46 import static hat.NDRange.Global2D;
 47 import static hat.NDRange.Local2D;
 48 import static hat.NDRange.NDRange2D;
 49 import static hat.NDRange.Warp2D;
 50 import static optkl.ifacemapper.MappableIface.RO;
 51 import static optkl.ifacemapper.MappableIface.WO;
 52 
 53 /**
 54  * Test tensor operations in HAT.
 55  *
 56  * <p>How to run?</p>
 57  * <p>For the CUDA backend:
 58  * <code>
 59  * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-cuda hat.test.TestTensors
 60  * </code>
 61  * </p>
 62  *
 63  * <p>For the OpenCL backend:
 64  * <code>
 65  * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-opencl hat.test.TestTensors
 66  * </code>
 67  * </p>
 68  *
 69  */
 70 public class TestTensors {
 71 
 72     @Reflect
 73     public static void mxmTensorsColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) {
 74         final int SHAPE = 16;
 75         final int WMMA_M = SHAPE;
 76         final int WMMA_N = SHAPE;
 77         final int WMMA_K = SHAPE;
 78         int warpM = kc.gix / kc.wrs;
 79         int warpN = kc.giy;
 80 
 81         final int lda = 1024;
 82         final int ldb = 1024;
 83         final int ldc = 1024;
 84 
 85         var shape = Tensor.shape(WMMA_M, WMMA_N, WMMA_K);
 86 
 87         // Initialize a tensor accumulator with zeros
 88         Tensor acc = Tensor.zeros(shape, float.class);
 89 
 90         for (int i = 0; i < size; i += WMMA_K) {
 91             int aRow = warpM * WMMA_M;
 92             int aCol = i;
 93             int bRow = i;
 94             int bCol = warpN * WMMA_N;
 95 
 96             if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {
 97                 // Load data from matrix A with the specified shape using column-major into a tensor of FP16
 98                 Tensor tensorA = Tensor.loadF16(matrixA, aRow, aCol, lda, shape, Tensor.ofColumnMajor());
 99 
100                 // Load data from matrix B with the specified shape using column-major into a tensor of FP16
101                 Tensor tensorB = Tensor.loadF16(matrixB, bRow, bCol, ldb, shape, Tensor.ofColumnMajor());
102 
103                 // Perform the MMA operation:
104                 // acc = tensorA * tensorB + acc
105                 acc = Tensor.mma(tensorA, tensorB, acc);
106             }
107         }
108         int cRow = warpM * WMMA_M;
109         int cCol = warpN * WMMA_N;
110 
111         // Store the resulting tensor into main memory using column-major layout.
112         if (cRow < size && cCol < size) {
113             // We operate with square matrices
114             Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor());
115         }
116     }
117 
118     @Reflect
119     public static void mxmTensorsColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) {
120         // The total number of threads is calculated as follows:
121         // [ (size / tile), (size / tile) ]
122         // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. This is architecture dependent, but the
123         // HAT runtime and HAT JIT compiler handle this automatically.
124 
125         var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
126                 Local2D.of(128, 4),
127                 Tile2D.of(16, 16),
128                 Warp2D.of(true, false));
129 
130         cc.dispatchKernel(ndRange, kc -> mxmTensorsColumnMajor(kc, matrixA, matrixB, matrixC, globalSize));
131     }
132 
133     @Reflect
134     public static void mxmTensorsRowColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) {
135 
136         final int WMMA_M = 16;
137         final int WMMA_N = 16;
138         final int WMMA_K = 16;
139         int warpM = kc.gix / kc.wrs;
140         int warpN = kc.giy;
141 
142         final int lda = 1024;
143         final int ldb = 1024;
144         final int ldc = 1024;
145 
146         // We keep explicit constant in this version to check shape with ConstantOp
147         Tensor acc = Tensor.create(Tensor.shape(16, 16, 16), float.class);
148 
149         Tensor.fill(acc, 0.0f);
150 
151         for (int i = 0; i < size; i += WMMA_K) {
152             int aRow = warpM * WMMA_M;
153             int aCol = i;
154             int bRow = i;
155             int bCol = warpN * WMMA_N;
156             if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {
157                 Tensor tensorA = Tensor.loadF16(matrixA, aRow, aCol, lda, Tensor.shape(16, 16, 16), Tensor.ofRowMajor());
158                 Tensor tensorB = Tensor.loadF16(matrixB, bRow, bCol, ldb, Tensor.shape(16, 16, 16),Tensor.ofColumnMajor());
159                 acc = Tensor.mma(tensorA, tensorB, acc);
160             }
161         }
162         int cRow = warpM * WMMA_M;
163         int cCol = warpN * WMMA_N;
164         if (cRow < size && cCol < size) {
165             // We operate with square matrices
166             Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor());
167         }
168     }
169 
170     @Reflect
171     public static void mxmTensorsRowColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) {
172         // The total number of threads is calculated as follows:
173         // [ (size / tile), (size / tile) ]
174         // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. This is architecture dependent, but the
175         // HAT runtime and HAT JIT compiler handle this automatically.
176 
177         var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
178                 Local2D.of(128, 4),
179                 Tile2D.of(16, 16),
180                 Warp2D.of(true, false));
181 
182         cc.dispatchKernel(ndRange, kc -> mxmTensorsRowColumnMajor(kc, matrixA, matrixB, matrixC, globalSize));
183     }
184 
185     @Reflect
186     public static void mxmTensorsRowMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int size) {
187         final int WMMA_M = 16;
188         final int WMMA_N = 16;
189         final int WMMA_K = 16;
190         int warpM = kc.gix / kc.wrs;
191         int warpN = kc.giy;
192 
193         final int lda = 1024;
194         final int ldb = 1024;
195         final int ldc = 1024;
196 
197         Tensor acc = Tensor.create(Tensor.shape(16, 16, 16), float.class);
198 
199         Tensor.fill(acc, 0.0f);
200 
201         for (int i = 0; i < size; i += WMMA_K) {
202             int aRow = warpM * WMMA_M;
203             int aCol = i;
204 
205             int bRow = i;
206             int bCol = warpN * WMMA_N;
207 
208             if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {
209 
210                 Tensor tensorA = Tensor.loadF16(matrixA, aRow, aCol, lda, Tensor.shape(16, 16, 16), Tensor.ofRowMajor());
211                 Tensor tensorB = Tensor.loadF16(matrixB, bRow, bCol, ldb, Tensor.shape(16, 16, 16), Tensor.ofRowMajor());
212 
213                 // acc = tensorA * tensorB + acc
214                 acc = Tensor.mma(tensorA, tensorB, acc);
215             }
216         }
217         int cRow = warpM * WMMA_M;
218         int cCol = warpN * WMMA_N;
219         if (cRow < size && cCol < size) {
220             Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofRowMajor());
221         }
222     }
223 
224     @Reflect
225     public static void mxmTensorsRowMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int globalSize) {
226         var ndRange = NDRange2D.of(
227                 Global2D.of(globalSize, globalSize),
228                 Local2D.of(128, 4),
229                 Tile2D.of(16, 16),
230                 Warp2D.of(true, false));
231         cc.dispatchKernel(ndRange, kc -> mxmTensorsRowMajor(kc, matrixA, matrixB, matrixC, globalSize));
232     }
233 
234     @Reflect
235     public static void mxmTensorsDefaultAccess(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int size) {
236         final int sizeShape = 16;
237         final int WMMA_M = sizeShape;
238         final int WMMA_N = sizeShape;
239         final int WMMA_K = sizeShape;
240         int warpM = kc.gix / kc.wrs;
241         int warpN = kc.giy;
242 
243         final int lda = 1024;
244         final int ldb = 1024;
245         final int ldc = 1024;
246 
247         var shape = Tensor.shape(sizeShape, sizeShape, sizeShape);
248         Tensor acc = Tensor.zeros(shape, float.class);
249         for (int i = 0; i < size; i += WMMA_K) {
250             int aRow = warpM * WMMA_M;
251             int bCol = warpN * WMMA_N;
252             if (aRow < lda && i < lda && i < ldb && bCol < ldb) {
253                 Tensor tensorA = Tensor.loadF16(matrixA, aRow, i, lda, shape);
254                 Tensor tensorB = Tensor.loadF16(matrixB, i, bCol, ldb, shape);
255                 acc = Tensor.mma(tensorA, tensorB, acc);
256             }
257         }
258         int cRow = warpM * WMMA_M;
259         int cCol = warpN * WMMA_N;
260         if (cRow < size && cCol < size) {
261             Tensor.store(matrixC, cRow, cCol, acc, ldc);
262         }
263     }
264 
265     @Reflect
266     public static void mxmTensorsDefaultAccess(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int globalSize) {
267         var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
268                 Local2D.of(128, 4),
269                 Tile2D.of(16, 16),
270                 Warp2D.of(true, false));
271         cc.dispatchKernel(ndRange, kc -> mxmTensorsDefaultAccess(kc, matrixA, matrixB, matrixC, globalSize));
272     }
273 
274     private static void runSequentialColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
275         for (int i = 0; i < size; i++) {
276             for (int j = 0; j < size; j++) {
277                 float sum = 0.0f;
278                 for (int k = 0; k < size; k++) {
279                     F16 a = matrixA.array((long) k * size + i);
280                     F16 b = matrixB.array((long) j * size + k);
281                     F16 mul = F16.mul(a, b);
282                     sum += F16.f16ToFloat(mul);
283                 }
284                 matrixC.array((long) j * size + i, sum);
285             }
286         }
287     }
288 
289     private static void runSequentialRowAndColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
290         for (int i = 0; i < size; i++) {
291             for (int j = 0; j < size; j++) {
292                 float sum = 0.0f;
293                 for (int k = 0; k < size; k++) {
294                     F16 a = matrixA.array((long) i * size + k);
295                     F16 b = matrixB.array((long) j * size + k);
296                     F16 mul = F16.mul(a, b);
297                     sum += F16.f16ToFloat(mul);
298                 }
299                 matrixC.array((long) j * size + i, sum);
300             }
301         }
302     }
303 
304     private static void runSequentialRowMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
305         for (int i = 0; i < size; i++) {
306             for (int j = 0; j < size; j++) {
307                 float sum = 0.0f;
308                 for (int k = 0; k < size; k++) {
309                     F16 a = matrixA.array((long) i * size + k);
310                     F16 b = matrixB.array((long) k * size + j);
311                     F16 mul = F16.mul(a, b);
312                     sum += F16.f16ToFloat(mul);
313                 }
314                 matrixC.array((long) i * size + j, sum);
315             }
316         }
317     }
318 
319     @HatTest
320     @Reflect
321     public void testTensor01() {
322         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
323         final int size = 1024;
324 
325         F16Array matrixAHalf = F16Array.create(accelerator, size * size);
326         F16Array matrixBHalf = F16Array.create(accelerator, size * size);
327         F32Array matrixC = F32Array.create(accelerator, size * size);
328         F32Array resultSequential = F32Array.create(accelerator, size * size);
329 
330         Random r = new Random(19);
331         for (int j = 0; j < matrixAHalf.length(); j++) {
332             matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
333             matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
334         }
335 
336         // Run multiple time
337         for (int i = 0; i < 10; i++) {
338             accelerator.compute(cc -> mxmTensorsColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
339         }
340 
341         runSequentialColMajor(matrixAHalf, matrixBHalf, resultSequential, size);
342 
343         for (int i = 0; i < size; i++) {
344             for (int j = 0; j < size; j++) {
345                 final int index = j * size + i;
346                 float expectedValue = resultSequential.array(index);
347                 float gotValue = matrixC.array(index);
348                 try {
349                     HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
350                 } catch (HATAssertionError e) {
351                     throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
352                 }
353             }
354         }
355     }
356 
357     @HatTest
358     @Reflect
359     public void testTensor02() {
360         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
361         final int size = 1024;
362 
363         F16Array matrixAHalf = F16Array.create(accelerator, size * size);
364         F16Array matrixBHalf = F16Array.create(accelerator, size * size);
365         F32Array matrixC = F32Array.create(accelerator, size * size);
366         F32Array resultSequential = F32Array.create(accelerator, size * size);
367 
368         Random r = new Random(19);
369         for (int j = 0; j < matrixAHalf.length(); j++) {
370             matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
371             matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
372         }
373 
374         accelerator.compute(cc -> mxmTensorsRowColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
375 
376         runSequentialRowAndColMajor(matrixAHalf, matrixBHalf, resultSequential, size);
377 
378         for (int i = 0; i < size; i++) {
379             for (int j = 0; j < size; j++) {
380                 final int index = j * size + i;
381                 float expectedValue = resultSequential.array(index);
382                 float gotValue = matrixC.array(index);
383                 try {
384                     HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
385                 } catch (HATAssertionError e) {
386                     throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
387                 }
388             }
389         }
390     }
391 
392     @HatTest
393     @Reflect
394     public void testTensor03() {
395 
396         // To be able to run tensor-matmul in a row-major layout, we need to add padding.
397         // Thus, the result matrix must be of type F32ArrayPadded.
398 
399         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
400         final int size = 1024;
401 
402         F16Array matrixAHalf = F16Array.create(accelerator, size * size);
403         F16Array matrixBHalf = F16Array.create(accelerator, size * size);
404         F32ArrayPadded matrixC = F32ArrayPadded.create(accelerator, size * size);
405         F32Array resultSequential = F32Array.create(accelerator, size * size);
406 
407         Random r = new Random(19);
408         for (int j = 0; j < matrixAHalf.length(); j++) {
409             matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
410             matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
411         }
412         accelerator.compute(cc -> mxmTensorsRowMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
413         runSequentialRowMajor(matrixAHalf, matrixBHalf, resultSequential, size);
414 
415         for (int i = 0; i < size; i++) {
416             for (int j = 0; j < size; j++) {
417                 final int index = j * size + i;
418                 float expectedValue = resultSequential.array(index);
419                 float gotValue = matrixC.array(index);
420                 try {
421                     HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
422                 } catch (HATAssertionError e) {
423                     throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
424                 }
425             }
426         }
427     }
428 
429     @HatTest
430     @Reflect
431     public void testTensor04() {
432 
433         // To be able to run tensor-matmul in a row-major layout, we need to add padding.
434         // Thus, the result matrix must be of type F32ArrayPadded.
435 
436         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
437         final int size = 1024;
438 
439         F16Array matrixAHalf = F16Array.create(accelerator, size * size);
440         F16Array matrixBHalf = F16Array.create(accelerator, size * size);
441         F32ArrayPadded matrixC = F32ArrayPadded.create(accelerator, size * size);
442         F32Array resultSequential = F32Array.create(accelerator, size * size);
443 
444         Random r = new Random(19);
445         for (int j = 0; j < matrixAHalf.length(); j++) {
446             matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
447             matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
448         }
449 
450         accelerator.compute(cc -> mxmTensorsDefaultAccess(cc, matrixAHalf, matrixBHalf, matrixC, size));
451 
452         runSequentialRowMajor(matrixAHalf, matrixBHalf, resultSequential, size);
453 
454         for (int i = 0; i < size; i++) {
455             for (int j = 0; j < size; j++) {
456                 final int index = j * size + i;
457                 float expectedValue = resultSequential.array(index);
458                 float gotValue = matrixC.array(index);
459                 try {
460                     HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
461                 } catch (HATAssertionError e) {
462                     throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
463                 }
464             }
465         }
466     }
467 }