1 /*
  2  * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.test;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.KernelContext;
 30 import hat.NDRange.Tile2D;
 31 import hat.backend.Backend;
 32 import hat.buffer.F16Array;
 33 import hat.buffer.F32Array;
 34 import hat.buffer.F32ArrayPadded;
 35 import hat.test.annotation.HatTest;
 36 import hat.test.exceptions.HATAssertionError;
 37 import hat.test.exceptions.HATAsserts;
 38 import hat.test.exceptions.HATExpectedPrecisionError;
 39 import hat.types.F16;
 40 import hat.types.Tensor;
 41 import jdk.incubator.code.Reflect;
 42 
 43 import java.lang.invoke.MethodHandles;
 44 import java.util.Random;
 45 
 46 import static hat.NDRange.Global2D;
 47 import static hat.NDRange.Local2D;
 48 import static hat.NDRange.NDRange2D;
 49 import static hat.NDRange.Warp2D;
 50 import static optkl.ifacemapper.MappableIface.RO;
 51 import static optkl.ifacemapper.MappableIface.WO;
 52 
 53 /**
 54  * Check tensor operations in HAT. How to run?
 55  *
 56  * <p>
 57  * <code>
 58  * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-cuda hat.test.TestTensors
 59  * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-opencl hat.test.TestTensors
 60  * </code>
 61  * </p>
 62  *
 63  */
public class TestTensors {

    /**
     * Kernel: each warp computes one 16x16 tile of {@code C = A * B} using
     * tensor (MMA) operations. Both input fragments are declared column-major
     * and the accumulated tile is stored column-major.
     *
     * @param kc      kernel context supplying thread/warp indices
     * @param matrixA input matrix A in FP16
     * @param matrixB input matrix B in FP16
     * @param matrixC output matrix C in FP32
     * @param size    extent of the K dimension, iterated in steps of WMMA_K
     */
    @Reflect
    public static void mxmTensorsColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) {
        // MMA fragment dimensions: 16x16 output tiles accumulated over 16-wide K slices.
        final int WMMA_M = 16;
        final int WMMA_N = 16;
        final int WMMA_K = 16;
        // NOTE(review): assumes kc.gix is the global x index and kc.wrs the warp size,
        // so (warpM, warpN) identifies the 16x16 output tile owned by this warp — confirm.
        int warpM = kc.gix / kc.wrs;
        int warpN = kc.giy;

        // Leading dimensions of A, B and C.
        // NOTE(review): hard-coded to 1024, which happens to equal `size` in every test
        // below; the bounds checks and loads are only correct when size == 1024.
        // Consider deriving these from `size` so the kernel generalizes.
        final int lda = 1024;
        final int ldb = 1024;
        final int ldc = 1024;

        // Per-warp tensor fragments: A and B in FP16 (both column-major here),
        // accumulator in FP32.
        Tensor tensorA = Tensor.create(Tensor.FIRST, Tensor.shape(16, 16, 16), F16.class, Tensor.ofColumnMajor());
        Tensor tensorB = Tensor.create(Tensor.SECOND, Tensor.shape(16, 16, 16), F16.class, Tensor.ofColumnMajor());
        Tensor acc = Tensor.create(Tensor.ACC, Tensor.shape(16, 16, 16), float.class);

        // Zero the accumulator before sweeping the K dimension.
        Tensor.fill(acc, 0.0f);

        // Walk the shared K dimension one WMMA_K-wide slice at a time.
        for (int i = 0; i < size; i += WMMA_K) {
            int aRow = warpM * WMMA_M;
            int aCol = i;

            int bRow = i;
            int bCol = warpN * WMMA_N;

            // Skip tiles that would fall outside the matrices.
            if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {

                tensorA = Tensor.load(matrixA, aRow, aCol, lda);
                tensorB = Tensor.load(matrixB, bRow, bCol, ldb);

                // acc = tensorA * tensorB + acc
                Tensor.mma(acc, tensorA, tensorB, acc);
            }
        }
        // Write this warp's 16x16 result tile back to C in column-major order.
        int cRow = warpM * WMMA_M;
        int cCol = warpN * WMMA_N;
        Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor());
    }

    /**
     * Compute entry point: configures the 2D ND-range and dispatches the
     * column-major tensor kernel over a {@code globalSize x globalSize} grid.
     */
    @Reflect
    public static void mxmTensorsColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) {
        // The total number of threads is calculated as follows:
        // [ (size / tile), (size / tile) ]
        // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. This is architecture dependent, but the
        // HAT runtime and HAT JIT compiler handle this automatically.

        // Warp2D.of(true, false): only the x dimension is warp-scaled.
        var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
                Local2D.of(128, 4),
                Tile2D.of(16, 16),
                Warp2D.of(true, false));

        cc.dispatchKernel(ndRange, kc -> mxmTensorsColumnMajor(kc, matrixA, matrixB, matrixC, globalSize));
    }

    /**
     * Kernel: same tiled {@code C = A * B} as {@link #mxmTensorsColumnMajor},
     * except fragment A is declared row-major while B and the stored result
     * remain column-major.
     *
     * @param kc      kernel context supplying thread/warp indices
     * @param matrixA input matrix A in FP16 (row-major layout)
     * @param matrixB input matrix B in FP16 (column-major layout)
     * @param matrixC output matrix C in FP32 (stored column-major)
     * @param size    extent of the K dimension, iterated in steps of WMMA_K
     */
    @Reflect
    public static void mxmTensorsRowColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) {
        // MMA fragment dimensions: 16x16 output tiles accumulated over 16-wide K slices.
        final int WMMA_M = 16;
        final int WMMA_N = 16;
        final int WMMA_K = 16;
        int warpM = kc.gix / kc.wrs;
        int warpN = kc.giy;

        // NOTE(review): hard-coded leading dimensions — must equal `size` (see the
        // column-major kernel above for details).
        final int lda = 1024;
        final int ldb = 1024;
        final int ldc = 1024;

        // A is row-major, B is column-major; accumulator in FP32.
        Tensor tensorA = Tensor.create(Tensor.FIRST, Tensor.shape(16, 16, 16), F16.class, Tensor.ofRowMajor());
        Tensor tensorB = Tensor.create(Tensor.SECOND, Tensor.shape(16, 16, 16), F16.class, Tensor.ofColumnMajor());
        Tensor acc = Tensor.create(Tensor.ACC, Tensor.shape(16, 16, 16), float.class);

        Tensor.fill(acc, 0.0f);

        // Accumulate over the shared K dimension.
        for (int i = 0; i < size; i += WMMA_K) {
            int aRow = warpM * WMMA_M;
            int aCol = i;

            int bRow = i;
            int bCol = warpN * WMMA_N;

            if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {

                tensorA = Tensor.load(matrixA, aRow, aCol, lda);
                tensorB = Tensor.load(matrixB, bRow, bCol, ldb);

                // acc = tensorA * tensorB + acc
                Tensor.mma(acc, tensorA, tensorB, acc);
            }
        }
        // Store the warp's result tile column-major.
        int cRow = warpM * WMMA_M;
        int cCol = warpN * WMMA_N;
        Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor());
    }

    /**
     * Compute entry point: configures the 2D ND-range and dispatches the
     * row/column-major tensor kernel over a {@code globalSize x globalSize} grid.
     */
    @Reflect
    public static void mxmTensorsRowColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) {
        // The total number of threads is calculated as follows:
        // [ (size / tile), (size / tile) ]
        // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. This is architecture dependent, but the
        // HAT runtime and HAT JIT compiler handle this automatically.

        var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize),
                Local2D.of(128, 4),
                Tile2D.of(16, 16),
                Warp2D.of(true, false));

        cc.dispatchKernel(ndRange, kc -> mxmTensorsRowColumnMajor(kc, matrixA, matrixB, matrixC, globalSize));
    }

    /**
     * Kernel: tiled {@code C = A * B} with both input fragments declared
     * row-major and the result stored row-major. The destination must be an
     * {@link F32ArrayPadded} (see {@code testTensor03} for why padding is needed).
     *
     * @param kc      kernel context supplying thread/warp indices
     * @param matrixA input matrix A in FP16 (row-major layout)
     * @param matrixB input matrix B in FP16 (row-major layout)
     * @param matrixC padded output matrix C in FP32 (stored row-major)
     * @param size    extent of the K dimension, iterated in steps of WMMA_K
     */
    @Reflect
    public static void mxmTensorsRowMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int size) {
        // MMA fragment dimensions: 16x16 output tiles accumulated over 16-wide K slices.
        final int WMMA_M = 16;
        final int WMMA_N = 16;
        final int WMMA_K = 16;
        int warpM = kc.gix / kc.wrs;
        int warpN = kc.giy;

        // NOTE(review): hard-coded leading dimensions — must equal `size` (see the
        // column-major kernel above for details).
        final int lda = 1024;
        final int ldb = 1024;
        final int ldc = 1024;

        // Both A and B are row-major here; accumulator in FP32.
        Tensor tensorA = Tensor.create(Tensor.FIRST, Tensor.shape(16, 16, 16), F16.class, Tensor.ofRowMajor());
        Tensor tensorB = Tensor.create(Tensor.SECOND, Tensor.shape(16, 16, 16), F16.class, Tensor.ofRowMajor());
        Tensor acc = Tensor.create(Tensor.ACC, Tensor.shape(16, 16, 16), float.class);

        Tensor.fill(acc, 0.0f);

        // Accumulate over the shared K dimension.
        for (int i = 0; i < size; i += WMMA_K) {
            int aRow = warpM * WMMA_M;
            int aCol = i;

            int bRow = i;
            int bCol = warpN * WMMA_N;

            if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) {

                tensorA = Tensor.load(matrixA, aRow, aCol, lda);
                tensorB = Tensor.load(matrixB, bRow, bCol, ldb);

                // acc = tensorA * tensorB + acc
                Tensor.mma(acc, tensorA, tensorB, acc);
            }
        }
        // Store the warp's result tile row-major.
        int cRow = warpM * WMMA_M;
        int cCol = warpN * WMMA_N;
        Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofRowMajor());
    }

    /**
     * Compute entry point: configures the 2D ND-range and dispatches the
     * row-major tensor kernel over a {@code globalSize x globalSize} grid.
     */
    @Reflect
    public static void mxmTensorsRowMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int globalSize) {
        var ndRange = NDRange2D.of(
                Global2D.of(globalSize, globalSize),
                Local2D.of(128, 4),
                Tile2D.of(16, 16),
                Warp2D.of(16, 16) == null ? null : Warp2D.of(true, false));
        cc.dispatchKernel(ndRange, kc -> mxmTensorsRowMajor(kc, matrixA, matrixB, matrixC, globalSize));
    }

    /**
     * CPU reference: column-major matrix multiply in FP16 with FP32 accumulation.
     * A(k,i) is read at {@code k*size + i}, B(k,j) at {@code j*size + k}, and
     * C(i,j) is written at {@code j*size + i} — all column-major linearizations.
     */
    private static void runSequentialColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                float sum = 0.0f;
                for (int k = 0; k < size; k++) {
                    F16 a = matrixA.array((long) k * size + i);
                    F16 b = matrixB.array((long) j * size + k);
                    // Products are formed in FP16 and accumulated in FP32,
                    // mirroring the MMA kernel's precision behavior.
                    F16 mul = F16.mul(a, b);
                    sum += F16.f16ToFloat(mul);
                }
                matrixC.array((long) j * size + i, sum);
            }
        }
    }

    /**
     * CPU reference: A read row-major ({@code i*size + k}), B read column-major
     * ({@code j*size + k}), result written column-major ({@code j*size + i}) —
     * matching the mixed-layout kernel.
     */
    private static void runSequentialRowAndColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                float sum = 0.0f;
                for (int k = 0; k < size; k++) {
                    F16 a = matrixA.array((long) i * size + k);
                    F16 b = matrixB.array((long) j * size + k);
                    // FP16 multiply, FP32 accumulate (see runSequentialColMajor).
                    F16 mul = F16.mul(a, b);
                    sum += F16.f16ToFloat(mul);
                }
                matrixC.array((long) j * size + i, sum);
            }
        }
    }

    /**
     * CPU reference: fully row-major matrix multiply — A(i,k) at
     * {@code i*size + k}, B(k,j) at {@code k*size + j}, C(i,j) at
     * {@code i*size + j}.
     */
    private static void runSequentialRowMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) {
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                float sum = 0.0f;
                for (int k = 0; k < size; k++) {
                    F16 a = matrixA.array((long) i * size + k);
                    F16 b = matrixB.array((long) k * size + j);
                    // FP16 multiply, FP32 accumulate (see runSequentialColMajor).
                    F16 mul = F16.mul(a, b);
                    sum += F16.f16ToFloat(mul);
                }
                matrixC.array((long) i * size + j, sum);
            }
        }
    }

    /**
     * End-to-end test of the column-major tensor kernel: fill two 1024x1024
     * FP16 matrices with seeded random data, run the accelerated kernel ten
     * times, then compare every element against the sequential column-major
     * reference with an absolute tolerance of 0.1.
     */
    @HatTest
    @Reflect
    public void testTensor01() {
        var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
        final int size = 1024;

        F16Array matrixAHalf = F16Array.create(accelerator, size * size);
        F16Array matrixBHalf = F16Array.create(accelerator, size * size);
        F32Array matrixC = F32Array.create(accelerator, size * size);
        F32Array resultSequential = F32Array.create(accelerator, size * size);

        // Fixed seed so the run is reproducible.
        Random r = new Random(19);
        for (int j = 0; j < matrixAHalf.length(); j++) {
            matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
            matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
        }

        // Repeated dispatch: also exercises the JIT/runtime caching paths.
        // The result is recomputed (not accumulated) each iteration, so the
        // final contents of matrixC are those of the last run.
        for (int i = 0; i < 10; i++) {
            accelerator.compute(cc -> mxmTensorsColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
        }

        runSequentialColMajor(matrixAHalf, matrixBHalf, resultSequential, size);

        // Element-wise comparison; (i, j) sweep covers every linear index once.
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                final int index = j * size + i;
                float expectedValue = resultSequential.array(index);
                float gotValue = matrixC.array(index);
                try {
                    // 0.1 absolute tolerance accounts for FP16 rounding differences
                    // between the tensor-core and scalar reference computations.
                    HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
                } catch (HATAssertionError e) {
                    // NOTE(review): `e` is discarded; attach it as the cause if
                    // HATExpectedPrecisionError supports one.
                    throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
                }
            }
        }
    }

    /**
     * End-to-end test of the mixed row/column-major tensor kernel; same setup
     * and tolerance as {@link #testTensor01}, validated against the
     * row-and-column-major sequential reference.
     */
    @HatTest
    @Reflect
    public void testTensor02() {
        var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
        final int size = 1024;

        F16Array matrixAHalf = F16Array.create(accelerator, size * size);
        F16Array matrixBHalf = F16Array.create(accelerator, size * size);
        F32Array matrixC = F32Array.create(accelerator, size * size);
        F32Array resultSequential = F32Array.create(accelerator, size * size);

        // Fixed seed so the run is reproducible.
        Random r = new Random(19);
        for (int j = 0; j < matrixAHalf.length(); j++) {
            matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
            matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
        }

        // Repeated dispatch; the last run's output is what gets verified.
        for (int i = 0; i < 10; i++) {
            accelerator.compute(cc -> mxmTensorsRowColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
        }

        runSequentialRowAndColMajor(matrixAHalf, matrixBHalf, resultSequential, size);

        // Element-wise comparison; (i, j) sweep covers every linear index once.
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                final int index = j * size + i;
                float expectedValue = resultSequential.array(index);
                float gotValue = matrixC.array(index);
                try {
                    HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
                } catch (HATAssertionError e) {
                    // NOTE(review): `e` is discarded; attach it as the cause if
                    // HATExpectedPrecisionError supports one.
                    throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
                }
            }
        }
    }

    /**
     * End-to-end test of the fully row-major tensor kernel; same setup and
     * tolerance as {@link #testTensor01}, validated against the row-major
     * sequential reference.
     */
    @HatTest
    @Reflect
    public void testTensor03() {

        // To be able to run tensor-matmul in a row-major layout, we need to add padding.
        // Thus, the result matrix must be of type F32ArrayPadded.

        var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
        final int size = 1024;

        F16Array matrixAHalf = F16Array.create(accelerator, size * size);
        F16Array matrixBHalf = F16Array.create(accelerator, size * size);
        F32ArrayPadded matrixC = F32ArrayPadded.create(accelerator, size * size);
        F32Array resultSequential = F32Array.create(accelerator, size * size);

        // Fixed seed so the run is reproducible.
        Random r = new Random(19);
        for (int j = 0; j < matrixAHalf.length(); j++) {
            matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
            matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value());
        }

        // Repeated dispatch; the last run's output is what gets verified.
        for (int i = 0; i < 10; i++) {
            accelerator.compute(cc -> mxmTensorsRowMajor(cc, matrixAHalf, matrixBHalf, matrixC, size));
        }

        runSequentialRowMajor(matrixAHalf, matrixBHalf, resultSequential, size);

        // Element-wise comparison. The reference stores row-major (i*size + j),
        // but since (i, j) both sweep [0, size) the j*size+i traversal still
        // visits every linear index exactly once, so the comparison is complete.
        // NOTE(review): assumes F32ArrayPadded.array(index) addresses the logical
        // (unpadded) element so it lines up with the plain F32Array — confirm.
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                final int index = j * size + i;
                float expectedValue = resultSequential.array(index);
                float gotValue = matrixC.array(index);
                try {
                    HATAsserts.assertEquals(expectedValue, gotValue, 0.1f);
                } catch (HATAssertionError e) {
                    // NOTE(review): `e` is discarded; attach it as the cause if
                    // HATExpectedPrecisionError supports one.
                    throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue);
                }
            }
        }
    }
}