1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.test;
26
27 import hat.Accelerator;
28 import hat.ComputeContext;
29 import hat.NDRange;
30 import hat.KernelContext;
31 import hat.backend.Backend;
32 import hat.types.BF16;
33 import hat.buffer.BF16Array;
34 import hat.types.F16;
35 import hat.buffer.F16Array;
36 import hat.buffer.F32Array;
37 import hat.buffer.F32ArrayPadded;
38 import hat.types.Float4;
39 import hat.device.DeviceSchema;
40 import hat.device.NonMappableIface;
41 import hat.test.annotation.HatTest;
42 import hat.test.exceptions.HATAssertionError;
43 import hat.test.exceptions.HATAsserts;
44 import hat.test.exceptions.HATExpectedPrecisionError;
45 import jdk.incubator.code.Reflect;
46
47 import java.lang.invoke.MethodHandles;
48 import java.util.Random;
49
50 import static optkl.ifacemapper.MappableIface.RO;
51 import static optkl.ifacemapper.MappableIface.RW;
52
53 public class TestMatMul {
54
    // Default matrix dimension (matrices are SIZE x SIZE) for most tests in this class.
    private static final int SIZE = 256;

    /**
     * 2D matrix-multiply kernel: each work-item computes one element of C = A * B.
     * The row index comes from the global X id and the column index from the global Y id.
     *
     * @param kc      kernel context supplying global ids/sizes (gix/gsx, giy/gsy)
     * @param matrixA row-major size x size input matrix
     * @param matrixB row-major size x size input matrix
     * @param matrixC row-major size x size output matrix (element overwritten)
     * @param size    matrix dimension (square matrices assumed)
     */
    @Reflect
    public static void matrixMultiplyKernel2D(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        // Guard against out-of-range work-items if the dispatch is padded.
        if (kc.gix < kc.gsx) {
            if (kc.giy < kc.gsy) {
                float acc = 0.0f;
                // Dot product of row kc.gix of A with column kc.giy of B.
                for (int k = 0; k < size; k++) {
                    acc += (matrixA.array(kc.gix * size + k) * matrixB.array(k * size + kc.giy));
                }
                matrixC.array(kc.gix * size + kc.giy, acc);
            }
        }
    }
69
    /**
     * 2D matrix-multiply kernel with the index roles swapped relative to
     * {@code matrixMultiplyKernel2D}: here the global Y id selects the row and the
     * global X id selects the column ("LI" presumably = loop/index interchange —
     * TODO confirm naming). Result is still C = A * B for row-major matrices.
     *
     * @param kc      kernel context supplying global ids/sizes
     * @param matrixA row-major size x size input matrix
     * @param matrixB row-major size x size input matrix
     * @param matrixC row-major size x size output matrix (element overwritten)
     * @param size    matrix dimension (square matrices assumed)
     */
    @Reflect
    public static void matrixMultiplyKernel2DLI(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        if (kc.gix < kc.gsx) {
            if (kc.giy < kc.gsy) {
                float acc = 0.0f;
                // Row kc.giy of A dotted with column kc.gix of B.
                for (int k = 0; k < size; k++) {
                    acc += (matrixA.array(kc.giy * size + k) * matrixB.array(k * size + kc.gix));
                }
                matrixC.array(kc.giy * size + kc.gix, acc);
            }
        }
    }
82
    /**
     * Half-precision (FP16) variant of {@code matrixMultiplyKernel2DLI}: one work-item
     * per output element, row selected by the global Y id, column by the global X id.
     * All arithmetic is done through the F16 helper ops, so the accumulation itself
     * happens in half precision.
     *
     * @param kc      kernel context supplying global ids/sizes
     * @param matrixA row-major size x size FP16 input matrix
     * @param matrixB row-major size x size FP16 input matrix
     * @param matrixC row-major size x size FP16 output matrix (element overwritten)
     * @param size    matrix dimension (square matrices assumed)
     */
    @Reflect
    public static void matrixMultiplyKernel2DLIF16(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @RW F16Array matrixC, int size) {
        if (kc.gix < kc.gsx) {
            if (kc.giy < kc.gsy) {
                F16 acc = F16.of(0.0f);
                for (int k = 0; k < size; k++) {
                    F16 valA = matrixA.array(kc.giy * size + k);
                    F16 valB = matrixB.array(k * size + kc.gix);
                    F16 valc = F16.mul(valA, valB);
                    acc = F16.add(acc, valc);
                }
                // Write through the F16 view returned by the output array.
                F16 resultC = matrixC.array(kc.giy * size + kc.gix);
                resultC.value(acc.value());
            }
        }
    }
99
    /**
     * Device-schema-backed float array of fixed length 256 used as tile storage by
     * {@code matrixMultiplyKernel2DTiling} (16x16 tile = 256 floats).
     * The {@code create*} factories return null on the host; presumably the backend
     * materializes real storage when the kernel is compiled — TODO confirm.
     */
    private interface MyLocalArrayFixedSize extends NonMappableIface {
        void array(long index, float value);
        float array(long index);

        DeviceSchema<MyLocalArrayFixedSize> schema = DeviceSchema.of(MyLocalArrayFixedSize.class,
                myPrivateArray -> myPrivateArray
                        .withArray("array", 256));

        static MyLocalArrayFixedSize create(Accelerator accelerator) {
            return null;
        }

        static MyLocalArrayFixedSize createLocal() {
            return null;
        }
    }
116
    /**
     * Shared-memory tiled matrix multiply: C = A * B. Each 16x16 workgroup stages a
     * 16x16 tile of A and of B into local memory, reduces over the tile, and repeats
     * for all size/tileSize tiles along the K dimension.
     * NOTE(review): there is no bounds guard, so this assumes size is a multiple of
     * tileSize (16) and that the dispatch uses a 16x16 local size — confirm at call sites.
     *
     * @param kc      kernel context supplying block/local ids and the barrier
     * @param matrixA row-major size x size input matrix
     * @param matrixB row-major size x size input matrix
     * @param matrixC row-major size x size output matrix (element overwritten)
     * @param size    matrix dimension (square, multiple of 16 assumed)
     */
    @Reflect
    public static void matrixMultiplyKernel2DTiling(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {

        // Must match the local work size used at dispatch and the 256-entry schema of MyLocalArrayFixedSize.
        final int tileSize = 16;
        MyLocalArrayFixedSize tileA = MyLocalArrayFixedSize.createLocal();
        MyLocalArrayFixedSize tileB = MyLocalArrayFixedSize.createLocal();

        int groupIndexX = kc.bix;
        int groupIndexY = kc.biy;
        int localIdx = kc.lix;
        int localIdy = kc.liy;

        // we identify the row and column
        int row = groupIndexY * tileSize + localIdy;
        int col = groupIndexX * tileSize + localIdx;

        // Compute matrix-vector and accumulate the result over the tiles
        float sum = 0.0f;
        for (int tile = 0; tile < (size/tileSize); tile++) {
            // Copy from global to shared memory
            tileA.array((long) localIdy * tileSize + localIdx, matrixA.array((long) row * size + tile * tileSize + localIdx));
            tileB.array((long) localIdy * tileSize + localIdx, matrixB.array((tile * tileSize + localIdy) * size + col));

            // Apply a barrier for the local group: we need to guarantee that all threads that belong
            // to the same group reach this point before doing the partial reduction
            kc.barrier();

            // compute partial reductions over the tile
            for (int k = 0; k < tileSize; k++) {
                sum += (tileA.array((long) localIdy * tileSize + k) * tileB.array(k * tileSize + localIdx));
            }

            // A new local barrier for all threads that belong to the same group before loading a new tile into
            // share memory. With the following barrier, we can ensure that all threads within the same workgroup
            // finished the compute for the partial reduction
            kc.barrier();
        }

        // copy result from shared memory to global memory
        matrixC.array((long) row * size + col, sum);
    }
158
    /**
     * Reflected helper: dot product of row {@code kc.gix} of A with column {@code j}
     * of B. Called from {@code matrixMultiplyKernel1DWithFunctionCalls} to exercise
     * kernel-to-function calls.
     *
     * @return the accumulated float dot product
     */
    @Reflect
    public static float compute(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, int size, int j) {
        float acc = 0.0f;
        for (int k = 0; k < size; k++) {
            acc += (matrixA.array(kc.gix * size + k) * matrixB.array(k * size + j));
        }
        return acc;
    }
167
    /**
     * 1D matrix-multiply kernel: each work-item (global X id) computes one full row
     * of C = A * B, iterating over all columns j.
     *
     * @param kc      kernel context supplying the global id/size in X
     * @param matrixA row-major size x size input matrix
     * @param matrixB row-major size x size input matrix
     * @param matrixC row-major size x size output matrix (row overwritten)
     * @param size    matrix dimension (square matrices assumed)
     */
    @Reflect
    public static void matrixMultiplyKernel1D(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        if (kc.gix < kc.gsx) {
            for (int j = 0; j < size; j++) {
                float acc = 0.0f;
                for (int k = 0; k < size; k++) {
                    acc += (matrixA.array(kc.gix * size + k) * matrixB.array(k * size + j));
                }
                matrixC.array(kc.gix * size + j, acc);
            }
        }
    }
180
    /**
     * Same computation as {@code matrixMultiplyKernel1D} (one row of C per work-item),
     * but the inner dot product is delegated to {@link #compute} to test function
     * calls from within a kernel.
     */
    @Reflect
    public static void matrixMultiplyKernel1DWithFunctionCalls(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        if (kc.gix < kc.gsx) {
            for (int j = 0; j < size; j++) {
                float acc = compute(kc, matrixA, matrixB, size, j);
                matrixC.array(kc.gix * size + j, acc);
            }
        }
    }
190
    /**
     * Compute entry point: dispatches {@code matrixMultiplyKernel1D} over a 1D range
     * of {@code globalSize} work-items with a local size of 16.
     */
    @Reflect
    public static void matrixMultiply1D(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int globalSize) {
        cc.dispatchKernel(NDRange.of1D(globalSize,16),
                kc -> matrixMultiplyKernel1D(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }

    // Local (workgroup) size used by the 2D dispatches below.
    final static int BLOCK_SIZE = 16;

    /**
     * Compute entry point for the function-call kernel variant; 1D range with a
     * backend-chosen local size.
     */
    @Reflect
    public static void matrixMultiply1DWithFunctionCalls(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        cc.dispatchKernel(NDRange.of1D(size),
                kc -> matrixMultiplyKernel1DWithFunctionCalls(kc, matrixA, matrixB, matrixC, size)
        );
    }

    /**
     * Compute entry point: 2D dispatch of {@code matrixMultiplyKernel2D} with
     * BLOCK_SIZE x BLOCK_SIZE workgroups.
     */
    @Reflect
    public static void matrixMultiply2D(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int globalSize) {
        cc.dispatchKernel(NDRange.of2D(globalSize, globalSize,BLOCK_SIZE, BLOCK_SIZE),
                kc -> matrixMultiplyKernel2D(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }

    /**
     * Compute entry point: 2D dispatch of the index-interchanged kernel.
     */
    @Reflect
    public static void matrixMultiply2DLI(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int globalSize) {
        cc.dispatchKernel(NDRange.of2D(globalSize, globalSize,BLOCK_SIZE, BLOCK_SIZE),
                kc -> matrixMultiplyKernel2DLI(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }

    /**
     * Compute entry point: 2D dispatch of the FP16 index-interchanged kernel.
     */
    @Reflect
    public static void matrixMultiply2DLIF16(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @RW F16Array matrixC, int globalSize) {
        cc.dispatchKernel(NDRange.of2D(globalSize, globalSize,BLOCK_SIZE, BLOCK_SIZE),
                kc -> matrixMultiplyKernel2DLIF16(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }

    /**
     * Compute entry point: 2D dispatch of the shared-memory tiled kernel.
     * BLOCK_SIZE (16) must match the kernel's internal tileSize.
     */
    @Reflect
    public static void matrixMultiply2DTiling(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int globalSize) {
        cc.dispatchKernel(NDRange.of2D(globalSize, globalSize,BLOCK_SIZE, BLOCK_SIZE),
                kc -> matrixMultiplyKernel2DTiling(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }
234
235 private static void runSequential(F16Array matrixA, F16Array matrixB, F16Array matrixC, final int size) {
236 for (int i = 0; i < size; i++) {
237 for (int j = 0; j < size; j++) {
238 F16 sum = F16.of(0.0f);
239 for (int k = 0; k < size; k++) {
240 F16 a = matrixA.array((long) i * size + k);
241 F16 b = matrixB.array((long) k * size + j);
242 sum = F16.add(sum, F16.mul(a, b));
243 }
244 matrixC.array((long) i * size + j).value(sum.value());
245 }
246 }
247 }
248
249 private static void runSequential(F32Array matrixA, F32Array matrixB, F32Array matrixC, final int size) {
250 for (int i = 0; i < size; i++) {
251 for (int j = 0; j < size; j++) {
252 float sum = 0;
253 for (int k = 0; k < size; k++) {
254 float a = matrixA.array((long) i * size + k);
255 float b = matrixB.array((long) k * size + j);
256 sum += a * b;
257 }
258 matrixC.array((long) i * size + j, sum);
259 }
260 }
261 }
262
263 private static void runSequential(F32ArrayPadded matrixA, F32ArrayPadded matrixB, F32ArrayPadded matrixC, final int size) {
264 for (int i = 0; i < size; i++) {
265 for (int j = 0; j < size; j++) {
266 float sum = 0;
267 for (int k = 0; k < size; k++) {
268 float a = matrixA.array((long) i * size + k);
269 float b = matrixB.array((long) k * size + j);
270 sum += a * b;
271 }
272 matrixC.array((long) i * size + j, sum);
273 }
274 }
275 }
276
277 private static void runSequential(BF16Array matrixA, BF16Array matrixB, BF16Array matrixC, final int size) {
278 for (int i = 0; i < size; i++) {
279 for (int j = 0; j < size; j++) {
280 BF16 sum = BF16.of(0.0f);
281 for (int k = 0; k < size; k++) {
282 BF16 a = matrixA.array((long) i * size + k);
283 BF16 b = matrixB.array((long) k * size + j);
284 sum = BF16.add(sum, BF16.mul(a, b));
285 }
286 matrixC.array((long) i * size + j).value(sum.value());
287 }
288 }
289 }
290
    /**
     * Validates the 1D kernel against the sequential reference on a SIZE x SIZE
     * random matrix (fixed seed 19), with a per-element tolerance of 0.01f.
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply1D() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19);

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply1D(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        // Compare every element of the accelerator result against the reference.
        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }

    /**
     * Same validation as {@link #testMatrixMultiply1D} but through the kernel
     * variant that calls a reflected helper function.
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply1DWithFunctionCalls() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19);

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply1DWithFunctionCalls(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }
360
361
    /**
     * Validates the plain 2D kernel against the sequential reference on a
     * SIZE x SIZE random matrix (fixed seed 19), tolerance 0.01f per element.
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply2D() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19);

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2D(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }

    /**
     * Validates the index-interchanged 2D kernel against the sequential reference.
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply2DLI() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19);

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2DLI(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }
431
    /**
     * Validates the FP16 kernel against the FP16 sequential reference.
     * Comparisons are made after converting the raw half bits to float via
     * Float.float16ToFloat. FP16 accumulation over {@code size} products can lose
     * precision, so an assertion failure is re-thrown as HATExpectedPrecisionError
     * rather than a hard failure.
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply2DLIF16() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F16Array.create(accelerator, size * size);
        var matrixB = F16Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F16Array.create(accelerator, size * size);
        var resultSeq = F16Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19);

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j).value(F16.floatToF16(r.nextFloat()).value());
            matrixB.array(j).value(F16.floatToF16(r.nextFloat()).value());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2DLIF16(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                try {
                    HATAsserts.assertEquals(
                            Float.float16ToFloat(resultSeq.array(i * size + j).value()),
                            Float.float16ToFloat(matrixC.array(i * size + j).value()),
                            0.01f);
                } catch (HATAssertionError hatAssertionError) {
                    // Precision drift in half-precision accumulation is expected, not a test bug.
                    throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
                }
            }
        }
    }

    /**
     * Validates the shared-memory tiled kernel against the sequential reference.
     * SIZE (256) is a multiple of the kernel's tileSize (16), as the kernel requires.
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply2DTiling() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19);

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2DTiling(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }
508
    /**
     * Device-schema-backed float array of length 1024 used as shared (workgroup-local)
     * tile storage by the register-tiling kernels (64x16 / 16x64 tiles = 1024 floats).
     * Factory methods return null on the host; presumably the backend allocates the
     * real storage when the kernel is compiled — TODO confirm.
     */
    private interface SharedMemory extends NonMappableIface {
        void array(long index, float value);
        float array(long index);
        DeviceSchema<SharedMemory> schema = DeviceSchema.of(SharedMemory.class,
                arr -> arr.withArray("array", 1024));
        static SharedMemory create(Accelerator accelerator) {
            return null;
        }
        static SharedMemory createLocal() {
            return null;
        }
        // Vectorized (float4) store hook; no-op default on the host.
        default void storeFloat4View(Float4 float4, int index) {
        }
    }

    /**
     * Device-schema-backed float array of length 16 (TM * TN) holding each
     * work-item's private accumulator block in the register-tiling kernels.
     */
    private interface PrivateArray extends NonMappableIface {
        void array(long index, float value);
        float array(long index);
        DeviceSchema<PrivateArray> schema = DeviceSchema.of(PrivateArray.class,
                arr -> arr.withArray("array", 16));
        static PrivateArray create(Accelerator accelerator) {
            return null;
        }
        static PrivateArray createPrivate() {
            return null;
        }
    }

    /**
     * Device-schema-backed float array of length 4 (TM or TN) used as the per-item
     * register cache for one row/column slice in the register-tiling kernels.
     */
    private interface FlatPrivate extends NonMappableIface {
        void array(long index, float value);
        float array(long index);
        DeviceSchema<FlatPrivate> schema = DeviceSchema.of(FlatPrivate.class,
                arr -> arr.withArray("array", 4));
        static FlatPrivate create(Accelerator accelerator) {
            return null;
        }
        static FlatPrivate createPrivate() {
            return null;
        }
    }
549
550 // Code ported from the HAT example module.
551 @Reflect
552 public static void matrixMultiplyKernel2DRegisterTiling(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
553
554 // Configuration for the kernel: Keep in mind that if you change the following parameters,
555 // also change the scheduling (global and local work sizes).
556 final int BM = 64;
557 final int BN = 64;
558 final int BK = 16;
559 final int TM = 4;
560 final int TN = 4;
561
562 int bx = kc.bix;
563 int by = kc.biy;
564
565 int totalResultsBlockTile = BM * BN;
566 final int numThreadsBlockTile = totalResultsBlockTile / (TM * TN);
567
568 final int linearLocalId = kc.liy * kc.lsx + kc.lix;
569 final int threadCol = kc.lix;
570 final int threadRow = kc.liy;
571
572 SharedMemory tileA = SharedMemory.createLocal();
573 SharedMemory tileB = SharedMemory.createLocal();
574
575 int aFrom = by * BM * size;
576 int bFrom = bx * BN;
577 int v = bx * BN;
578 int cFrom = (by * BM * size) + (v);
579
580 final int innerRowA = linearLocalId / BK;
581 final int innerColA = linearLocalId % BK;
582
583 final int strideA = numThreadsBlockTile / BK;
584 final int innerRowB = linearLocalId / BN;
585 final int innerColB = linearLocalId % BN;
586
587 int strideB = numThreadsBlockTile / BN;
588
589 // Declarations of the arrays in private memory to perform register tiling
590 PrivateArray threadResults = PrivateArray.createPrivate();
591 FlatPrivate regM = FlatPrivate.createPrivate();
592 FlatPrivate regN = FlatPrivate.createPrivate();
593
594 // initialize values
595 for (int i = 0; i < (TN * TN); i++) {
596 threadResults.array(i, 0.0f);
597 }
598
599 // Each thread loops over the tiles
600 for (int bkIdx = 0; bkIdx < size; bkIdx += BK) {
601
602 // A) Load data into shared memory for array A
603 for (int loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
604 tileA.array((innerRowA + loadOffset) * BK + innerColA,
605 matrixA.array(((innerRowA + loadOffset) * size + innerColA) + aFrom));
606 }
607
608 // B) Load data matrixB into shared memory for array B
609 for (int loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
610 tileB.array((innerRowB + loadOffset) * BN + innerColB,
611 matrixB.array(((innerRowB + loadOffset) * size + innerColB) + bFrom));
612 }
613 kc.barrier();
614
615 aFrom += (BK);
616 int f = BK * size;
617 bFrom += f;
618
619 // Per-thread, we load the data from the shared memory into register for both
620 // array A and array B (matrix A and B), and then perform the reduction within
621 // the small region in private memory.
622 for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
623 // block into registers
624 for (int i = 0; i < TM; i++) {
625 regM.array(i, tileA.array((threadRow * TM + i) * BK + dotIdx));
626 }
627 for (int i = 0; i < TN; i++) {
628 regN.array(i, tileB.array(dotIdx * BN + threadCol * TN + i));
629 }
630 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
631 for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
632 float val = regM.array(resIdxM) * regN.array(resIdxN);
633 float acc = threadResults.array(resIdxM * TN + resIdxN);
634 acc += val;
635 threadResults.array((resIdxM * TN + resIdxN), (acc));
636 }
637 }
638 }
639 kc.barrier();
640 }
641
642 // Finally, we store the results of the reductions for the whole 2D register block into global memory.
643 // Essentially, each thread compute a small block of TM * TN sub-block size.
644 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
645 for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
646 float value = threadResults.array(resIdxM * TN + resIdxN);
647 matrixC.array((((threadRow * TM + resIdxM) * size + threadCol * TN + resIdxN) + (cFrom)), value);
648 }
649 }
650 }
651
652 // Code ported from the HAT example module.
653 @Reflect
654 public static void matrixMultiplyKernel2DRegisterTilingVectorized(@RO KernelContext kc, @RO F32ArrayPadded matrixA, @RO F32ArrayPadded matrixB, @RW F32ArrayPadded matrixC, int size) {
655
656 // Configuration for the kernel: Keep in mind that if you change the following parameters,
657 // also change the scheduling (global and local work sizes).
658 // final int M = size;
659 final int N = size;
660 final int K = size;
661 final int BM = 64;
662 final int BN = 64;
663 final int BK = 16;
664 final int TM = 4;
665 final int TN = 4;
666
667 int bx = kc.bix;
668 int by = kc.biy;
669
670 final int linearLocalId = kc.liy * kc.lsx + kc.lix;
671 final int threadCol = kc.lix;
672 final int threadRow = kc.liy;
673
674 SharedMemory tileA = SharedMemory.createLocal();
675 SharedMemory tileB = SharedMemory.createLocal();
676
677 int aFrom = by * BM * size;
678 int bFrom = bx * BN;
679 int v = bx * BN;
680 int cFrom = (by * BM * size) + (v);
681
682 final int innerRowA = linearLocalId / (BK / 4);
683 final int innerColA = linearLocalId % (BK / 4);
684 final int innerRowB = linearLocalId / (BN / 4);
685 final int innerColB = linearLocalId % (BN / 4);
686
687 // Declarations of the arrays in private memory to perform register tiling
688 PrivateArray threadResults = PrivateArray.createPrivate();
689 FlatPrivate regM = FlatPrivate.createPrivate();
690 FlatPrivate regN = FlatPrivate.createPrivate();
691
692 // initialize values
693 for (int i = 0; i < (TN * TN); i++) {
694 threadResults.array(i, 0.0f);
695 }
696
697 final int extraCols = 0;
698
699 // Each thread loops over the tiles
700 for (int bkIdx = 0; bkIdx < size; bkIdx += BK) {
701
702 Float4 loadA = matrixA.float4View((innerRowA * K + innerColA * 4) + aFrom);
703 tileA.array((innerColA * 4 + 0) * BM + innerRowA, loadA.x());
704 tileA.array((innerColA * 4 + 1) * BM + innerRowA, loadA.y());
705 tileA.array((innerColA * 4 + 2) * BM + innerRowA, loadA.z());
706 tileA.array((innerColA * 4 + 3) * BM + innerRowA, loadA.w());
707
708 Float4 loadB = matrixB.float4View((innerRowB * N + innerColB * 4) + bFrom);
709 tileB.array(innerRowB * (BN + extraCols) + innerColB * 4 + 0, loadB.x());
710 tileB.array(innerRowB * (BN + extraCols) + innerColB * 4 + 1, loadB.y());
711 tileB.array(innerRowB * (BN + extraCols) + innerColB * 4 + 2, loadB.z());
712 tileB.array(innerRowB * (BN + extraCols) + innerColB * 4 + 3, loadB.w());
713
714 kc.barrier();
715
716 aFrom += (BK);
717 int f = BK * size;
718 bFrom += f;
719
720 // Per-thread, we load the data from the shared memory into register for both
721 // array A and array B (matrix A and B), and then perform the reduction within
722 // the small region in private memory.
723 for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
724 // block into registers
725 for (int i = 0; i < TM; i++) {
726 regM.array(i, tileA.array(dotIdx * BM + threadRow * TM + i));
727 }
728 for (int i = 0; i < TN; i++) {
729 regN.array(i, tileB.array(dotIdx * (BN + extraCols) + threadCol * TN + i));
730 }
731 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
732 for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
733 float val = regM.array(resIdxM) * regN.array(resIdxN);
734 float acc = threadResults.array(resIdxM * TN + resIdxN);
735 acc += val;
736 threadResults.array((resIdxM * TN + resIdxN), (acc));
737 }
738 }
739 }
740 kc.barrier();
741 }
742
743 // Finally, we store the results of the reductions for the whole 2D register block into global memory.
744 // Essentially, each thread compute a small block of TM * TN sub-block size.
745 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
746 for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
747 float value = threadResults.array(resIdxM * TN + resIdxN);
748 matrixC.array((((threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN) + (cFrom)), value);
749 }
750 }
751 }
752
    /**
     * Dispatches the register-tiling kernel over a fixed 256x256 global range with
     * 16x16 workgroups. NOTE(review): the hard-coded 256 assumes size == 1024 with
     * each work-item producing a TM x TN = 4x4 block (1024 / 4 = 256) — confirm if
     * reused with other sizes.
     */
    @Reflect
    public static void matrixMultiply2DRegisterTiling(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, final int size) {
        cc.dispatchKernel(NDRange.of2D(256, 256,16, 16),
                kc -> matrixMultiplyKernel2DRegisterTiling(kc, matrixA, matrixB, matrixC, size)
        );
    }

    /**
     * Dispatches the vectorized register-tiling kernel; same fixed 256x256 / 16x16
     * scheduling and the same size == 1024 assumption as the scalar variant.
     */
    @Reflect
    public static void matrixMultiply2DRegisterTilingVectorized(@RO ComputeContext cc, @RO F32ArrayPadded matrixA, @RO F32ArrayPadded matrixB, @RW F32ArrayPadded matrixC, final int size) {
        cc.dispatchKernel(NDRange.of2D(256, 256,16, 16),
                kc -> matrixMultiplyKernel2DRegisterTilingVectorized(kc, matrixA, matrixB, matrixC, size)
        );
    }
766
    /**
     * Validates the register-tiling kernel on a 1024x1024 random matrix (seed 19)
     * against the sequential reference, tolerance 0.01f per element. size must be
     * 1024 to match the fixed 256x256 dispatch in matrixMultiply2DRegisterTiling.
     */
    @HatTest
    @Reflect
    public void testMatMul2DRegisterTiling() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = 1024;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19);

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2DRegisterTiling(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }

    /**
     * Validates the vectorized (float4) register-tiling kernel on a 1024x1024
     * random matrix using the padded array type required for float4 views.
     */
    @HatTest
    @Reflect
    public void testMatMul2DRegisterTilingVectorized() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = 1024;
        var matrixA = F32ArrayPadded.create(accelerator, size * size);
        var matrixB = F32ArrayPadded.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32ArrayPadded.create(accelerator, size * size);
        var resultSeq = F32ArrayPadded.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19);

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2DRegisterTilingVectorized(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }
836
    /**
     * Half-precision analogue of SharedMemory: a device-schema-backed array of 1024
     * F16 values used as workgroup-local tile storage. Elements are accessed as F16
     * views; the schema declares the F16 "value" field as a dependency so the
     * backend can lay it out. Factories return null on the host; presumably the
     * backend materializes storage at kernel-compile time — TODO confirm.
     */
    private interface SharedMemoryHalf extends NonMappableIface {
        F16 array(int index);

        DeviceSchema<SharedMemoryHalf> schema = DeviceSchema.of(SharedMemoryHalf.class,
                arr -> arr.withArray("array", 1024)
                        .withDeps(F16.class, half -> half.withField("value")));

        static SharedMemoryHalf create(Accelerator accelerator) {
            return null;
        }

        static SharedMemoryHalf createLocal() {
            return null;
        }
    }

    /**
     * Half-precision analogue of PrivateArray: 16 F16 values (TM * TN) holding a
     * work-item's private accumulator block.
     */
    private interface PrivateArrayHalf extends NonMappableIface {
        F16 array(int index);

        DeviceSchema<PrivateArrayHalf> schema = DeviceSchema.of(PrivateArrayHalf.class,
                arr -> arr.withArray("array", 16)
                        .withDeps(F16.class, half -> half.withField("value")));

        static PrivateArrayHalf create(Accelerator accelerator) {
            return null;
        }

        static PrivateArrayHalf createPrivate() {
            return null;
        }
    }

    /**
     * Half-precision analogue of FlatPrivate: 4 F16 values (TM or TN) used as the
     * per-item register cache for one row/column slice.
     */
    private interface FlatPrivateHalf extends NonMappableIface {
        F16 array(int index);

        DeviceSchema<FlatPrivateHalf> schema = DeviceSchema.of(FlatPrivateHalf.class,
                arr -> arr.withArray("array", 4)
                        .withDeps(F16.class, half -> half.withField("value")));

        static FlatPrivateHalf create(Accelerator accelerator) {
            return null;
        }

        static FlatPrivateHalf createPrivate() {
            return null;
        }
    }
884
885 // Taking from the HAT Examples module
886 @Reflect
887 public static void matrixMultiplyKernel2DRegisterTilingHalf(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @RW F16Array matrixC, int size) {
888 final int BM = 64;
889 final int BN = 64;
890 final int BK = 16;
891 final int TM = 4;
892 final int TN = 4;
893
894 int bx = kc.bix;
895 int by = kc.biy;
896
897 int totalResultsBlockTile = BM * BN;
898 final int numThreadsBlockTile = totalResultsBlockTile / (TM * TN);
899
900 final int linearLocalId = kc.liy * kc.lsx + kc.lix;
901 final int threadCol = kc.lix;
902 final int threadRow = kc.liy;
903
904 SharedMemoryHalf tileA = SharedMemoryHalf.createLocal();
905 SharedMemoryHalf tileB = SharedMemoryHalf.createLocal();
906
907 int aFrom = by * BM * size;
908 int bFrom = bx * BN;
909 int v = bx * BN;
910 int cFrom = (by * BM * size) + (v);
911
912 final int innerRowA = linearLocalId / BK;
913 final int innerColA = linearLocalId % BK;
914
915 final int strideA = numThreadsBlockTile / BK;
916 final int innerRowB = linearLocalId / BN;
917 final int innerColB = linearLocalId % BN;
918
919 int strideB = numThreadsBlockTile / BN;
920
921 PrivateArrayHalf threadResults = PrivateArrayHalf.createPrivate();
922 FlatPrivateHalf regM = FlatPrivateHalf.createPrivate();
923 FlatPrivateHalf regN = FlatPrivateHalf.createPrivate();
924
925 for (int i = 0; i < (TN * TN); i++) {
926 F16 init = F16.of(0.0f);
927 threadResults.array(i).value(init.value());
928 }
929
930 for (int bkIdx = 0; bkIdx < size; bkIdx += BK) {
931 for (int loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
932 F16 ha = matrixA.array(((innerRowA + loadOffset) * size + innerColA) + aFrom);
933 tileA.array((innerRowA + loadOffset) * BK + innerColA).value(ha.value());
934 }
935 for (int loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
936 F16 hb = matrixB.array(((innerRowB + loadOffset) * size + innerColB) + bFrom);
937 tileB.array((innerRowB + loadOffset) * BN + innerColB).value(hb.value());
938 }
939 kc.barrier();
940
941 aFrom += (BK);
942 int f = BK * size;
943 bFrom += f;
944
945 for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
946 for (int i = 0; i < TM; i++) {
947 F16 ha = tileA.array((threadRow * TM + i) * BK + dotIdx);
948 regM.array(i).value(ha.value());
949 }
950 for (int i = 0; i < TN; i++) {
951 F16 hb = tileB.array(dotIdx * BN + threadCol * TN + i);
952 regN.array(i).value(hb.value());
953 }
954 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
955 for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
956 F16 privA = regM.array(resIdxM);
957 F16 privB = regN.array(resIdxN);
958 F16 mul = F16.mul(privA, privB);
959 F16 acc = threadResults.array(resIdxM * TN + resIdxN);
960 acc = F16.add(acc, mul);
961 threadResults.array((resIdxM * TN + resIdxN)).value(acc.value());
962 }
963 }
964 }
965 kc.barrier();
966 }
967 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
968 for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
969 F16 result = threadResults.array(resIdxM * TN + resIdxN);
970 matrixC.array((((threadRow * TM + resIdxM) * size + threadCol * TN + resIdxN) + (cFrom))).value(result.value());
971 }
972 }
973 }
974
    // Work-group shared-memory tile of 1024 BF16 values, used by the BF16
    // register-tiling matmul kernel for its A (BM x BK = 64 x 16) and
    // B (BK x BN = 16 x 64) staging tiles. Not host-mappable.
    private interface SharedMemoryBfloat16 extends NonMappableIface {
        // Element accessor over the flattened tile.
        BF16 array(int index);

        // Device layout: 1024 BF16 elements, each exposing its "value" field.
        DeviceSchema<SharedMemoryBfloat16> schema = DeviceSchema.of(SharedMemoryBfloat16.class,
                arr -> arr.withArray("array", 1024)
                        .withDeps(BF16.class, half -> half.withField("value")));

        // NOTE(review): factory stubs intentionally return null on the host;
        // presumably replaced with a local-memory allocation by the HAT backend — confirm.
        static SharedMemoryBfloat16 create(Accelerator accelerator) {
            return null;
        }

        static SharedMemoryBfloat16 createLocal() {
            return null;
        }
    }
990
    // Per-thread private (register-space) array used by the BF16 register-tiling
    // matmul kernel to hold its TM x TN = 4 x 4 accumulator tile (16 elements).
    // Not host-mappable.
    private interface PrivateArrayBfloat16 extends NonMappableIface {
        // Element accessor; index is the flattened accumulator index
        // (resIdxM * TN + resIdxN in the kernel).
        BF16 array(int index);

        // Device layout: 16 BF16 elements, each exposing its "value" field.
        DeviceSchema<PrivateArrayBfloat16> schema = DeviceSchema.of(PrivateArrayBfloat16.class,
                arr -> arr.withArray("array", 16)
                        .withDeps(BF16.class, half -> half.withField("value")));

        // NOTE(review): factory stubs intentionally return null on the host;
        // presumably replaced with a device allocation by the HAT backend — confirm.
        static PrivateArrayBfloat16 create(Accelerator accelerator) {
            return null;
        }

        static PrivateArrayBfloat16 createPrivate() {
            return null;
        }
    }
1006
    // Per-thread private (register-space) array of 4 BF16 values; the kernel
    // creates two of these (regM, regN) as register fragments of the A and B
    // tiles for the inner dot-product loop. Not host-mappable.
    private interface FlatPrivateBfloat16 extends NonMappableIface {
        // Element accessor; index ranges over 0..3 (TM or TN) in the kernel.
        BF16 array(int index);

        // Device layout: 4 BF16 elements, each exposing its "value" field.
        DeviceSchema<FlatPrivateBfloat16> schema = DeviceSchema.of(FlatPrivateBfloat16.class,
                arr -> arr.withArray("array", 4)
                        .withDeps(BF16.class, half -> half.withField("value")));

        // NOTE(review): factory stubs intentionally return null on the host;
        // presumably replaced with a device allocation by the HAT backend — confirm.
        static FlatPrivateBfloat16 create(Accelerator accelerator) {
            return null;
        }

        static FlatPrivateBfloat16 createPrivate() {
            return null;
        }
    }
1022
1023 @Reflect
1024 public static void matrixMultiplyKernel2DRegisterTilingBFloat16(@RO KernelContext kc, @RO BF16Array matrixA, @RO BF16Array matrixB, @RW BF16Array matrixC, int size) {
1025 final int BM = 64;
1026 final int BN = 64;
1027 final int BK = 16;
1028 final int TM = 4;
1029 final int TN = 4;
1030
1031 int bx = kc.bix;
1032 int by = kc.biy;
1033
1034 int totalResultsBlockTile = BM * BN;
1035 final int numThreadsBlockTile = totalResultsBlockTile / (TM * TN);
1036
1037 final int linearLocalId = kc.liy * kc.lsx + kc.lix;
1038 final int threadCol = kc.lix;
1039 final int threadRow = kc.liy;
1040
1041 SharedMemoryBfloat16 tileA = SharedMemoryBfloat16.createLocal();
1042 SharedMemoryBfloat16 tileB = SharedMemoryBfloat16.createLocal();
1043
1044 int aFrom = by * BM * size;
1045 int bFrom = bx * BN;
1046 int v = bx * BN;
1047 int cFrom = (by * BM * size) + (v);
1048
1049 final int innerRowA = linearLocalId / BK;
1050 final int innerColA = linearLocalId % BK;
1051
1052 final int strideA = numThreadsBlockTile / BK;
1053 final int innerRowB = linearLocalId / BN;
1054 final int innerColB = linearLocalId % BN;
1055
1056 int strideB = numThreadsBlockTile / BN;
1057
1058 PrivateArrayBfloat16 threadResults = PrivateArrayBfloat16.createPrivate();
1059 FlatPrivateBfloat16 regM = FlatPrivateBfloat16.createPrivate();
1060 FlatPrivateBfloat16 regN = FlatPrivateBfloat16.createPrivate();
1061
1062 for (int i = 0; i < (TN * TN); i++) {
1063 BF16 init = BF16.of(0.0f);
1064 threadResults.array(i).value(init.value());
1065 }
1066
1067 for (int bkIdx = 0; bkIdx < size; bkIdx += BK) {
1068 for (int loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
1069 BF16 ha = matrixA.array(((innerRowA + loadOffset) * size + innerColA) + aFrom);
1070 tileA.array((innerRowA + loadOffset) * BK + innerColA).value(ha.value());
1071 }
1072 for (int loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
1073 BF16 hb = matrixB.array(((innerRowB + loadOffset) * size + innerColB) + bFrom);
1074 tileB.array((innerRowB + loadOffset) * BN + innerColB).value(hb.value());
1075 }
1076 kc.barrier();
1077
1078 aFrom += (BK);
1079 int f = BK * size;
1080 bFrom += f;
1081
1082 for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
1083 for (int i = 0; i < TM; i++) {
1084 BF16 ha = tileA.array((threadRow * TM + i) * BK + dotIdx);
1085 regM.array(i).value(ha.value());
1086 }
1087 for (int i = 0; i < TN; i++) {
1088 BF16 hb = tileB.array(dotIdx * BN + threadCol * TN + i);
1089 regN.array(i).value(hb.value());
1090 }
1091 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
1092 for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
1093 BF16 privA = regM.array(resIdxM);
1094 BF16 privB = regN.array(resIdxN);
1095 BF16 mul = BF16.mul(privA, privB);
1096 BF16 acc = threadResults.array(resIdxM * TN + resIdxN);
1097 acc = BF16.add(acc, mul);
1098 threadResults.array((resIdxM * TN + resIdxN)).value(acc.value());
1099 }
1100 }
1101 }
1102 kc.barrier();
1103 }
1104 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
1105 for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
1106 BF16 result = threadResults.array(resIdxM * TN + resIdxN);
1107 matrixC.array((((threadRow * TM + resIdxM) * size + threadCol * TN + resIdxN) + (cFrom))).value(result.value());
1108 }
1109 }
1110 }
1111
1112 @Reflect
1113 public static void matrixMultiply2DRegisterTilingHalf(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @RW F16Array matrixC, int globalSize) {
1114 cc.dispatchKernel(NDRange.of2D(256, 256,16, 16),
1115 kc -> matrixMultiplyKernel2DRegisterTilingHalf(kc, matrixA, matrixB, matrixC, globalSize)
1116 );
1117 }
1118
1119 @Reflect
1120 public static void matrixMultiply2DRegisterTilingBFloat16(@RO ComputeContext cc, @RO BF16Array matrixA, @RO BF16Array matrixB, @RW BF16Array matrixC, int globalSize) {
1121 cc.dispatchKernel(NDRange.of2D(256, 256,16, 16),
1122 kc -> matrixMultiplyKernel2DRegisterTilingBFloat16(kc, matrixA, matrixB, matrixC, globalSize)
1123 );
1124 }
1125
1126 @HatTest
1127 @Reflect
1128 public void matrixMultiply2DRegisterTilingHalf() {
1129 var lookup = MethodHandles.lookup();
1130 var accelerator = new Accelerator(lookup, Backend.FIRST);
1131
1132 final int size = 1024;
1133 var matrixA = F16Array.create(accelerator, size * size);
1134 var matrixB = F16Array.create(accelerator, size * size);
1135
1136 // Matrix for the results
1137 var matrixC = F16Array.create(accelerator, size * size);
1138 var resultSeq = F16Array.create(accelerator, size * size);
1139
1140 // Initialize matrices (A and B have the same size)
1141 Random r = new Random(19);
1142 for (int j = 0; j < matrixA.length(); j++) {
1143 matrixA.array(j).value(F16.floatToF16(r.nextFloat()).value());
1144 matrixB.array(j).value(F16.floatToF16(r.nextFloat()).value());
1145 }
1146
1147 accelerator.compute(cc ->
1148 TestMatMul.matrixMultiply2DRegisterTilingHalf(cc, matrixA, matrixB, matrixC, size));
1149
1150 // Run Seq for reference
1151 runSequential(matrixA, matrixB, resultSeq, size);
1152
1153 for (int i = 0; i < size; i++) {
1154 for (int j = 0; j < size; j++) {
1155 try {
1156 HATAsserts.assertEquals(F16.f16ToFloat(resultSeq.array(i * size + j)),
1157 F16.f16ToFloat(matrixC.array(i * size + j)),
1158 0.01f);
1159 } catch (HATAssertionError hatAssertionError) {
1160 throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
1161 }
1162 }
1163 }
1164 }
1165
1166 @HatTest
1167 @Reflect
1168 public void matrixMultiply2DRegisterTilingBFloat16() {
1169 var lookup = MethodHandles.lookup();
1170 var accelerator = new Accelerator(lookup, Backend.FIRST);
1171
1172 final int size = 1024;
1173 var matrixA = BF16Array.create(accelerator, size * size);
1174 var matrixB = BF16Array.create(accelerator, size * size);
1175
1176 // Matrix for the results
1177 var matrixC = BF16Array.create(accelerator, size * size);
1178 var resultSeq = BF16Array.create(accelerator, size * size);
1179
1180 // Initialize matrices (A and B have the same size)
1181 Random r = new Random(19);
1182 for (int j = 0; j < matrixA.length(); j++) {
1183 matrixA.array(j).value(BF16.float2bfloat16(r.nextFloat()).value());
1184 matrixB.array(j).value(BF16.float2bfloat16(r.nextFloat()).value());
1185 }
1186
1187 accelerator.compute(cc ->
1188 TestMatMul.matrixMultiply2DRegisterTilingBFloat16(cc, matrixA, matrixB, matrixC, size));
1189
1190 // Run Seq for reference
1191 runSequential(matrixA, matrixB, resultSeq, size);
1192
1193 for (int i = 0; i < size; i++) {
1194 for (int j = 0; j < size; j++) {
1195 try {
1196 HATAsserts.assertEquals(BF16.bfloat162float(resultSeq.array(i * size + j)),
1197 BF16.bfloat162float(matrixC.array(i * size + j)),
1198 0.01f);
1199 } catch (HATAssertionError hatAssertionError) {
1200 throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
1201 }
1202 }
1203 }
1204 }
1205 }