1 /*
   2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 package hat.test;
  26 
  27 import hat.Accelerator;
  28 import hat.ComputeContext;
  29 import hat.NDRange;
  30 import hat.KernelContext;
  31 import hat.backend.Backend;
  32 import hat.types.BF16;
  33 import hat.buffer.BF16Array;
  34 import hat.types.F16;
  35 import hat.buffer.F16Array;
  36 import hat.buffer.F32Array;
  37 import hat.buffer.F32ArrayPadded;
  38 import hat.types.Float4;
  39 import hat.device.DeviceSchema;
  40 import hat.device.NonMappableIface;
  41 import hat.test.annotation.HatTest;
  42 import hat.test.exceptions.HATAssertionError;
  43 import hat.test.exceptions.HATAsserts;
  44 import hat.test.exceptions.HATExpectedPrecisionError;
  45 import jdk.incubator.code.Reflect;
  46 
  47 import java.lang.invoke.MethodHandles;
  48 import java.util.Random;
  49 
  50 import static optkl.ifacemapper.MappableIface.RO;
  51 import static optkl.ifacemapper.MappableIface.RW;
  52 
  53 public class TestMatMul {
  54 
    // Edge length of the square test matrices (each buffer holds SIZE * SIZE floats).
    private static final int SIZE = 256;
  56 
    /**
     * GPU kernel: naive 2D matrix multiply C = A * B for square {@code size} x {@code size}
     * matrices stored row-major in flat arrays. One work-item per output element:
     * {@code kc.gix} selects the row of C and {@code kc.giy} the column.
     * The bounds guards drop work-items dispatched beyond the global sizes.
     */
    @Reflect
    public static void matrixMultiplyKernel2D(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        if (kc.gix < kc.gsx) {
            if (kc.giy < kc.gsy) {
                float acc = 0.0f;
                // Dot product of row gix of A with column giy of B.
                for (int k = 0; k < size; k++) {
                    acc += (matrixA.array(kc.gix * size + k) * matrixB.array(k * size + kc.giy));
                }
                matrixC.array(kc.gix * size + kc.giy, acc);
            }
        }
    }
  69 
    /**
     * Variant of {@link #matrixMultiplyKernel2D} with the global indices swapped:
     * {@code kc.giy} selects the row of C and {@code kc.gix} the column
     * ("LI" presumably stands for loop/index interchange — TODO confirm).
     * Produces the same matrix C as the non-LI kernel; only the thread-to-element
     * mapping (and hence the memory-access pattern) differs.
     */
    @Reflect
    public static void matrixMultiplyKernel2DLI(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        if (kc.gix < kc.gsx) {
            if (kc.giy < kc.gsy) {
                float acc = 0.0f;
                // Dot product of row giy of A with column gix of B.
                for (int k = 0; k < size; k++) {
                    acc += (matrixA.array(kc.giy * size + k) * matrixB.array(k * size + kc.gix));
                }
                matrixC.array(kc.giy * size + kc.gix, acc);
            }
        }
    }
  82 
    /**
     * Half-precision (F16) version of {@link #matrixMultiplyKernel2DLI}: same
     * index-swapped mapping, but the accumulation is done with {@link F16#add}
     * and {@link F16#mul}, so rounding follows float16 semantics at every step.
     * The result is written by copying the raw 16-bit value into the element
     * already resident in {@code matrixC}.
     */
    @Reflect
    public static void matrixMultiplyKernel2DLIF16(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @RW F16Array matrixC, int size) {
        if (kc.gix < kc.gsx) {
            if (kc.giy < kc.gsy) {
                F16 acc = F16.of(0.0f);
                for (int k = 0; k < size; k++) {
                    F16 valA = matrixA.array(kc.giy * size + k);
                    F16 valB = matrixB.array(k * size + kc.gix);
                    F16 valc = F16.mul(valA, valB);
                    acc = F16.add(acc, valc);
                }
                // Store by mutating the element in place rather than assigning a new value object.
                F16 resultC = matrixC.array(kc.giy * size + kc.gix);
                resultC.value(acc.value());
            }
        }
    }
  99 
    /**
     * Device-side scratch array of 256 floats (16x16 tile) used as local/shared
     * memory by {@link #matrixMultiplyKernel2DTiling}. The schema declares the
     * fixed capacity; the static factories return {@code null} on the host —
     * presumably the backend substitutes a real device allocation when the
     * kernel is compiled (TODO confirm against the HAT backend docs).
     */
    private interface MyLocalArrayFixedSize extends NonMappableIface {
        void array(long index, float value);
        float array(long index);

        DeviceSchema<MyLocalArrayFixedSize> schema = DeviceSchema.of(MyLocalArrayFixedSize.class,
                myPrivateArray -> myPrivateArray
                        .withArray("array", 256));

        static MyLocalArrayFixedSize create(Accelerator accelerator) {
            return null;
        }

        static MyLocalArrayFixedSize createLocal() {
            return null;
        }
    }
 116 
    /**
     * Tiled matrix multiply: each 16x16 work-group cooperatively stages one tile
     * of A and one tile of B into local memory, synchronizes, and accumulates a
     * partial dot product per output element. Assumes {@code size} is a multiple
     * of the tile size and that the dispatch uses a 16x16 local range (matching
     * both {@code tileSize} here and the 256-element schema of
     * {@link MyLocalArrayFixedSize}).
     */
    @Reflect
    public static void matrixMultiplyKernel2DTiling(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {

        final int tileSize = 16;
        MyLocalArrayFixedSize tileA = MyLocalArrayFixedSize.createLocal();
        MyLocalArrayFixedSize tileB = MyLocalArrayFixedSize.createLocal();

        int groupIndexX = kc.bix;
        int groupIndexY = kc.biy;
        int localIdx = kc.lix;
        int localIdy = kc.liy;

        // we identify the row and column
        int row = groupIndexY * tileSize + localIdy;
        int col = groupIndexX * tileSize + localIdx;

        // Compute matrix-vector and accumulate the result over the tiles
        float sum = 0.0f;
        for (int tile = 0; tile < (size/tileSize); tile++) {
            // Copy from global to shared memory
            tileA.array((long) localIdy * tileSize + localIdx, matrixA.array((long) row * size + tile * tileSize + localIdx));
            tileB.array((long) localIdy * tileSize + localIdx, matrixB.array((tile * tileSize + localIdy) * size + col));

            // Apply a barrier for the local group: we need to guarantee that all threads that belong
            // to the same group reach this point before doing the partial reduction
            kc.barrier();

            // compute partial reductions over the tile
            for (int k = 0; k < tileSize; k++) {
                sum += (tileA.array((long) localIdy * tileSize + k) * tileB.array(k * tileSize + localIdx));
            }

            // A new local barrier for all threads that belong to the same group before loading a new tile into
            // share memory. With the following barrier, we can ensure that all threads within the same workgroup
            // finished the compute for the partial reduction
            kc.barrier();
        }

        // copy result from shared memory to global memory
        matrixC.array((long) row * size + col, sum);
    }
 158 
    /**
     * Reflected helper: dot product of row {@code kc.gix} of {@code matrixA}
     * with column {@code j} of {@code matrixB}. Used by the 1D kernel that
     * exercises function calls inside reflected code.
     *
     * @return the accumulated float dot product
     */
    @Reflect
    public static float compute(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, int size, int j) {
        float acc = 0.0f;
        for (int k = 0; k < size; k++) {
            acc += (matrixA.array(kc.gix * size + k) * matrixB.array(k * size + j));
        }
        return acc;
    }
 167 
    /**
     * 1D-dispatched matrix multiply: one work-item per row of C ({@code kc.gix}),
     * looping over all columns {@code j} on the device. Guard drops excess
     * work-items beyond {@code gsx}.
     */
    @Reflect
    public static void matrixMultiplyKernel1D(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        if (kc.gix < kc.gsx) {
            for (int j = 0; j < size; j++) {
                float acc = 0.0f;
                for (int k = 0; k < size; k++) {
                    acc += (matrixA.array(kc.gix * size + k) * matrixB.array(k * size + j));
                }
                matrixC.array(kc.gix * size + j, acc);
            }
        }
    }
 180 
    /**
     * Same computation as {@link #matrixMultiplyKernel1D} but the inner dot
     * product is factored into {@link #compute} — exercises reflected
     * function-call support in the kernel compiler.
     */
    @Reflect
    public static void matrixMultiplyKernel1DWithFunctionCalls(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        if (kc.gix < kc.gsx) {
            for (int j = 0; j < size; j++) {
                float acc = compute(kc, matrixA, matrixB, size, j);
                matrixC.array(kc.gix * size + j, acc);
            }
        }
    }
 190 
 191     @Reflect
 192     public static void matrixMultiply1D(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int globalSize) {
 193         cc.dispatchKernel(NDRange.of1D(globalSize,16),
 194                 kc -> matrixMultiplyKernel1D(kc, matrixA, matrixB, matrixC, globalSize)
 195         );
 196     }
 197 
    // Work-group (local) size used by the kernel dispatches in this class.
    final static int BLOCK_SIZE = 16;
 199 
    /**
     * Compute entry point for the function-call kernel. Note: no local size is
     * passed to {@code of1D} here — presumably the backend picks a default;
     * verify this is intentional vs. the explicit local sizes used elsewhere.
     */
    @Reflect
    public static void matrixMultiply1DWithFunctionCalls(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
        cc.dispatchKernel(NDRange.of1D(size),
                kc -> matrixMultiplyKernel1DWithFunctionCalls(kc, matrixA, matrixB, matrixC, size)
        );
    }
 206 
    /**
     * Compute entry point: dispatches the naive 2D kernel over a
     * globalSize x globalSize range with BLOCK_SIZE x BLOCK_SIZE work-groups.
     */
    @Reflect
    public static void matrixMultiply2D(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int globalSize) {
        cc.dispatchKernel(NDRange.of2D(globalSize, globalSize,BLOCK_SIZE, BLOCK_SIZE),
                kc -> matrixMultiplyKernel2D(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }
 213 
    /**
     * Compute entry point for the index-interchanged (LI) 2D kernel; same
     * dispatch geometry as {@link #matrixMultiply2D}.
     */
    @Reflect
    public static void matrixMultiply2DLI(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int globalSize) {
          cc.dispatchKernel(NDRange.of2D(globalSize, globalSize,BLOCK_SIZE, BLOCK_SIZE),
                kc -> matrixMultiplyKernel2DLI(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }
 220 
    /**
     * Compute entry point for the half-precision (F16) LI kernel; same dispatch
     * geometry as {@link #matrixMultiply2D}.
     */
    @Reflect
    public static void matrixMultiply2DLIF16(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @RW F16Array matrixC, int globalSize) {
          cc.dispatchKernel(NDRange.of2D(globalSize, globalSize,BLOCK_SIZE, BLOCK_SIZE),
                kc -> matrixMultiplyKernel2DLIF16(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }
 227 
    /**
     * Compute entry point for the local-memory tiling kernel. The
     * BLOCK_SIZE x BLOCK_SIZE local range must match the kernel's hard-coded
     * tileSize of 16.
     */
    @Reflect
    public static void matrixMultiply2DTiling(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int globalSize) {
          cc.dispatchKernel(NDRange.of2D(globalSize, globalSize,BLOCK_SIZE, BLOCK_SIZE),
                kc -> matrixMultiplyKernel2DTiling(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }
 234 
 235     private static void runSequential(F16Array matrixA, F16Array matrixB, F16Array matrixC, final int size) {
 236         for (int i = 0; i < size; i++) {
 237             for (int j = 0; j < size; j++) {
 238                 F16 sum = F16.of(0.0f);
 239                 for (int k = 0; k < size; k++) {
 240                     F16 a = matrixA.array((long) i * size + k);
 241                     F16 b = matrixB.array((long) k * size + j);
 242                     sum = F16.add(sum, F16.mul(a, b));
 243                 }
 244                 matrixC.array((long) i * size + j).value(sum.value());
 245             }
 246         }
 247     }
 248 
 249     private static void runSequential(F32Array matrixA, F32Array matrixB, F32Array matrixC, final int size) {
 250         for (int i = 0; i < size; i++) {
 251             for (int j = 0; j < size; j++) {
 252                 float sum = 0;
 253                 for (int k = 0; k < size; k++) {
 254                     float a = matrixA.array((long) i * size + k);
 255                     float b = matrixB.array((long) k * size + j);
 256                     sum += a * b;
 257                 }
 258                 matrixC.array((long) i * size + j, sum);
 259             }
 260         }
 261     }
 262 
 263     private static void runSequential(F32ArrayPadded matrixA, F32ArrayPadded matrixB, F32ArrayPadded matrixC, final int size) {
 264         for (int i = 0; i < size; i++) {
 265             for (int j = 0; j < size; j++) {
 266                 float sum = 0;
 267                 for (int k = 0; k < size; k++) {
 268                     float a = matrixA.array((long) i * size + k);
 269                     float b = matrixB.array((long) k * size + j);
 270                     sum += a * b;
 271                 }
 272                 matrixC.array((long) i * size + j, sum);
 273             }
 274         }
 275     }
 276 
 277     private static void runSequential(BF16Array matrixA, BF16Array matrixB, BF16Array matrixC, final int size) {
 278         for (int i = 0; i < size; i++) {
 279             for (int j = 0; j < size; j++) {
 280                 BF16 sum = BF16.of(0.0f);
 281                 for (int k = 0; k < size; k++) {
 282                     BF16 a = matrixA.array((long) i * size + k);
 283                     BF16 b = matrixB.array((long) k * size + j);
 284                     sum = BF16.add(sum, BF16.mul(a, b));
 285                 }
 286                 matrixC.array((long) i * size + j).value(sum.value());
 287             }
 288         }
 289     }
 290 
    /**
     * End-to-end test: runs the 1D matmul kernel on the first available backend
     * and compares every element of C against the sequential host reference
     * within an absolute tolerance of 0.01f.
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply1D() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19); // fixed seed keeps the test deterministic

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply1D(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        // Element-wise comparison, absolute tolerance 0.01f.
        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }
 325 
    /**
     * End-to-end test for the 1D kernel that calls a reflected helper function,
     * checked against the sequential host reference (tolerance 0.01f).
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply1DWithFunctionCalls() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19); // fixed seed keeps the test deterministic

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply1DWithFunctionCalls(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        // Element-wise comparison, absolute tolerance 0.01f.
        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }
 360 
 361 
    /**
     * End-to-end test for the naive 2D kernel, checked against the sequential
     * host reference (tolerance 0.01f).
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply2D() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19); // fixed seed keeps the test deterministic

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2D(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        // Element-wise comparison, absolute tolerance 0.01f.
        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }
 396 
    /**
     * End-to-end test for the index-interchanged (LI) 2D kernel, checked against
     * the sequential host reference (tolerance 0.01f).
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply2DLI() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19); // fixed seed keeps the test deterministic

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2DLI(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        // Element-wise comparison, absolute tolerance 0.01f.
        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }
 431 
    /**
     * End-to-end test for the half-precision (F16) LI kernel. Raw float16 bits
     * are widened to float for comparison; an assertion failure is rethrown as
     * {@link HATExpectedPrecisionError} because float16 accumulation over
     * size=256 terms is expected to lose precision relative to the tolerance.
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply2DLIF16() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F16Array.create(accelerator, size * size);
        var matrixB = F16Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F16Array.create(accelerator, size * size);
        var resultSeq = F16Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19); // fixed seed keeps the test deterministic

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j).value(F16.floatToF16(r.nextFloat()).value());
            matrixB.array(j).value(F16.floatToF16(r.nextFloat()).value());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2DLIF16(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                try {
                    HATAsserts.assertEquals(
                            Float.float16ToFloat(resultSeq.array(i * size + j).value()),
                            Float.float16ToFloat(matrixC.array(i * size + j).value()),
                            0.01f);
                } catch (HATAssertionError hatAssertionError) {
                    // Reclassify: mismatches here are expected precision loss, not kernel bugs.
                    throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
                }
            }
        }
    }
 473 
    /**
     * End-to-end test for the local-memory tiling kernel, checked against the
     * sequential host reference (tolerance 0.01f). SIZE (256) is a multiple of
     * the kernel's tile size (16), as the kernel requires.
     */
    @HatTest
    @Reflect
    public void testMatrixMultiply2DTiling() {
        var lookup = MethodHandles.lookup();
        var accelerator = new Accelerator(lookup, Backend.FIRST);

        final int size = SIZE;
        var matrixA = F32Array.create(accelerator, size * size);
        var matrixB = F32Array.create(accelerator, size * size);

        // Matrix for the results
        var matrixC = F32Array.create(accelerator, size * size);
        var resultSeq = F32Array.create(accelerator, size * size);

        // Initialize matrices (A and B have the same size)
        Random r = new Random(19); // fixed seed keeps the test deterministic

        for (int j = 0; j < matrixA.length(); j++) {
            matrixA.array(j, r.nextFloat());
            matrixB.array(j, r.nextFloat());
        }

        accelerator.compute(cc ->
                TestMatMul.matrixMultiply2DTiling(cc, matrixA, matrixB, matrixC, size));

        // Run Seq for reference
        runSequential(matrixA, matrixB, resultSeq, size);

        // Element-wise comparison, absolute tolerance 0.01f.
        for (int j = 0; j < size; j++) {
            for (int i = 0; i < size; i++) {
                HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
            }
        }
    }
 508 
    /**
     * Device-side scratch array of 1024 floats used as local/shared memory by
     * the register-tiling kernels (sized for one 64x16 tile). Factories return
     * {@code null} on the host — presumably replaced by a device allocation by
     * the backend (TODO confirm). {@code storeFloat4View} is a no-op hook,
     * presumably for vectorized stores.
     */
    private interface SharedMemory extends NonMappableIface {
        void array(long index, float value);
        float array(long index);
        DeviceSchema<SharedMemory> schema = DeviceSchema.of(SharedMemory.class,
                arr -> arr.withArray("array", 1024));
        static SharedMemory create(Accelerator accelerator) {
            return null;
        }
        static SharedMemory createLocal() {
            return null;
        }
        default void storeFloat4View(Float4 float4, int index) {
        }
    }
 523 
    /**
     * Per-thread (private/register) array of 16 floats — holds the TM x TN
     * (4 x 4) accumulator block in the register-tiling kernels.
     */
    private interface PrivateArray extends NonMappableIface {
        void array(long index, float value);
        float array(long index);
        DeviceSchema<PrivateArray> schema = DeviceSchema.of(PrivateArray.class,
                arr -> arr.withArray("array", 16));
        static PrivateArray create(Accelerator accelerator) {
            return null;
        }
        static PrivateArray createPrivate() {
            return null;
        }
    }
 536 
    /**
     * Per-thread (private/register) array of 4 floats — holds one TM- or
     * TN-length register slice in the register-tiling kernels.
     */
    private interface FlatPrivate extends NonMappableIface {
        void array(long index, float value);
        float array(long index);
        DeviceSchema<FlatPrivate> schema = DeviceSchema.of(FlatPrivate.class,
                arr -> arr.withArray("array", 4));
        static FlatPrivate create(Accelerator accelerator) {
            return null;
        }
        static FlatPrivate createPrivate() {
            return null;
        }
    }
 549 
 550     // Code ported from the HAT example module.
 551     @Reflect
 552     public static void matrixMultiplyKernel2DRegisterTiling(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @RW F32Array matrixC, int size) {
 553 
 554         // Configuration for the kernel: Keep in mind that if you change the following parameters,
 555         // also change the scheduling (global and local work sizes).
 556         final int BM = 64;
 557         final int BN = 64;
 558         final int BK = 16;
 559         final int TM = 4;
 560         final int TN = 4;
 561 
 562         int bx = kc.bix;
 563         int by = kc.biy;
 564 
 565         int totalResultsBlockTile = BM * BN;
 566         final int numThreadsBlockTile = totalResultsBlockTile / (TM * TN);
 567 
 568         final int linearLocalId = kc.liy * kc.lsx + kc.lix;
 569         final int threadCol = kc.lix;
 570         final int threadRow = kc.liy;
 571 
 572         SharedMemory tileA = SharedMemory.createLocal();
 573         SharedMemory tileB = SharedMemory.createLocal();
 574 
 575         int aFrom = by * BM * size;
 576         int bFrom = bx * BN;
 577         int v = bx * BN;
 578         int cFrom = (by * BM * size) + (v);
 579 
 580         final int innerRowA = linearLocalId / BK;
 581         final int innerColA = linearLocalId % BK;
 582 
 583         final int strideA = numThreadsBlockTile / BK;
 584         final int innerRowB = linearLocalId / BN;
 585         final int innerColB = linearLocalId % BN;
 586 
 587         int strideB = numThreadsBlockTile / BN;
 588 
 589         // Declarations of the arrays in private memory to perform register tiling
 590         PrivateArray threadResults = PrivateArray.createPrivate();
 591         FlatPrivate regM = FlatPrivate.createPrivate();
 592         FlatPrivate regN = FlatPrivate.createPrivate();
 593 
 594         // initialize values
 595         for (int i = 0; i < (TN * TN); i++) {
 596             threadResults.array(i, 0.0f);
 597         }
 598 
 599         // Each thread loops over the tiles
 600         for (int bkIdx = 0; bkIdx < size; bkIdx += BK) {
 601 
 602             // A) Load data into shared memory for array A
 603             for (int loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
 604                 tileA.array((innerRowA + loadOffset) * BK + innerColA,
 605                         matrixA.array(((innerRowA + loadOffset) * size + innerColA) + aFrom));
 606             }
 607 
 608             // B) Load data matrixB into shared memory for array B
 609             for (int loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
 610                 tileB.array((innerRowB + loadOffset) * BN + innerColB,
 611                         matrixB.array(((innerRowB + loadOffset) * size + innerColB) + bFrom));
 612             }
 613             kc.barrier();
 614 
 615             aFrom += (BK);
 616             int f = BK * size;
 617             bFrom += f;
 618 
 619             // Per-thread, we load the data from the shared memory into register for both
 620             // array A and array B (matrix A and B), and then perform the reduction within
 621             // the small region in private memory.
 622             for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
 623                 // block into registers
 624                 for (int i = 0; i < TM; i++) {
 625                     regM.array(i,  tileA.array((threadRow * TM + i) * BK + dotIdx));
 626                 }
 627                 for (int i = 0; i < TN; i++) {
 628                     regN.array(i,  tileB.array(dotIdx * BN + threadCol * TN + i));
 629                 }
 630                 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
 631                     for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
 632                         float val = regM.array(resIdxM) * regN.array(resIdxN);
 633                         float acc = threadResults.array(resIdxM * TN + resIdxN);
 634                         acc += val;
 635                         threadResults.array((resIdxM * TN + resIdxN), (acc));
 636                     }
 637                 }
 638             }
 639             kc.barrier();
 640         }
 641 
 642         // Finally, we store the results of the reductions for the whole 2D register block into global memory.
 643         // Essentially, each thread compute a small block of TM * TN sub-block size.
 644         for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
 645             for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
 646                 float value = threadResults.array(resIdxM * TN + resIdxN);
 647                 matrixC.array((((threadRow * TM + resIdxM) * size + threadCol * TN + resIdxN) + (cFrom)), value);
 648             }
 649         }
 650     }
 651 
    // Code ported from the HAT example module.
    /**
     * Vectorized variant of the register-tiling kernel: tiles of A and B are
     * loaded from global memory via {@code float4View} (four floats per load),
     * and tileA is stored transposed (column-major within the tile) so the
     * inner loop reads it with stride BM. Requires the padded buffer type for
     * aligned float4 access. Same BM/BN/BK/TM/TN geometry as the scalar kernel.
     */
    @Reflect
    public static void matrixMultiplyKernel2DRegisterTilingVectorized(@RO KernelContext kc, @RO F32ArrayPadded matrixA, @RO F32ArrayPadded matrixB, @RW F32ArrayPadded matrixC, int size) {

        // Configuration for the kernel: Keep in mind that if you change the following parameters,
        // also change the scheduling (global and local work sizes).
      //  final int M = size;
        final int N = size;
        final int K = size;
        final int BM = 64;
        final int BN = 64;
        final int BK = 16;
        final int TM = 4;
        final int TN = 4;

        int bx = kc.bix;
        int by = kc.biy;

        // Flattened local id used to derive the per-thread float4 load coordinates.
        final int linearLocalId = kc.liy * kc.lsx + kc.lix;
        final int threadCol = kc.lix;
        final int threadRow = kc.liy;

        SharedMemory tileA = SharedMemory.createLocal();
        SharedMemory tileB = SharedMemory.createLocal();

        // Base offsets of this group's block within A, B and C (row-major).
        int aFrom = by * BM * size;
        int bFrom = bx * BN;
        int v = bx * BN;
        int cFrom = (by * BM * size) + (v);

        // Each thread loads one float4, hence the division by 4 in the column counts.
        final int innerRowA = linearLocalId / (BK / 4);
        final int innerColA = linearLocalId % (BK / 4);
        final int innerRowB = linearLocalId / (BN / 4);
        final int innerColB = linearLocalId % (BN / 4);

        // Declarations of the arrays in private memory to perform register tiling
        PrivateArray threadResults = PrivateArray.createPrivate();
        FlatPrivate regM = FlatPrivate.createPrivate();
        FlatPrivate regN = FlatPrivate.createPrivate();

        // initialize values
        // NOTE(review): bound is TN * TN; the accumulator logically holds TM * TN
        // entries — identical here since TM == TN == 4, but confirm if TM/TN diverge.
        for (int i = 0; i < (TN * TN); i++) {
            threadResults.array(i, 0.0f);
        }

        // Padding columns for tileB (0 = no padding; presumably a bank-conflict knob).
        final int extraCols = 0;

        // Each thread loops over the tiles
        for (int bkIdx = 0; bkIdx < size; bkIdx += BK) {

            // Load 4 consecutive elements of A and scatter them transposed into tileA.
            Float4 loadA = matrixA.float4View((innerRowA * K + innerColA * 4) + aFrom);
            tileA.array((innerColA * 4 + 0) * BM + innerRowA, loadA.x());
            tileA.array((innerColA * 4 + 1) * BM + innerRowA, loadA.y());
            tileA.array((innerColA * 4 + 2) * BM + innerRowA, loadA.z());
            tileA.array((innerColA * 4 + 3) * BM + innerRowA, loadA.w());

            // Load 4 consecutive elements of B and store them row-major into tileB.
            Float4 loadB = matrixB.float4View((innerRowB * N + innerColB * 4) + bFrom);
            tileB.array(innerRowB * (BN + extraCols) + innerColB * 4 + 0, loadB.x());
            tileB.array(innerRowB * (BN + extraCols) + innerColB * 4 + 1, loadB.y());
            tileB.array(innerRowB * (BN + extraCols) + innerColB * 4 + 2, loadB.z());
            tileB.array(innerRowB * (BN + extraCols) + innerColB * 4 + 3, loadB.w());

            kc.barrier();

            // Advance base offsets to the next BK-wide slice of A and B.
            aFrom += (BK);
            int f = BK * size;
            bFrom += f;

            // Per-thread, we load the data from the shared memory into register for both
            // array A and array B (matrix A and B), and then perform the reduction within
            // the small region in private memory.
            for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
                // block into registers
                for (int i = 0; i < TM; i++) {
                    regM.array(i,  tileA.array(dotIdx * BM + threadRow * TM + i));
                }
                for (int i = 0; i < TN; i++) {
                    regN.array(i,  tileB.array(dotIdx * (BN + extraCols) + threadCol * TN + i));
                }
                // Outer product of the two register slices, accumulated per element.
                for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
                    for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
                        float val = regM.array(resIdxM) * regN.array(resIdxN);
                        float acc = threadResults.array(resIdxM * TN + resIdxN);
                        acc += val;
                        threadResults.array((resIdxM * TN + resIdxN), (acc));
                    }
                }
            }
            kc.barrier();
        }

        // Finally, we store the results of the reductions for the whole 2D register block into global memory.
        // Essentially, each thread compute a small block of TM * TN sub-block size.
        for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
            for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
                float value = threadResults.array(resIdxM * TN + resIdxN);
                matrixC.array((((threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN) + (cFrom)), value);
            }
        }
    }
 752 
    // Compute entry point: dispatches the fp32 2D register-tiling kernel over a
    // 256x256 global range with 16x16 workgroups.
    // NOTE(review): the launch geometry is hard-coded; presumably it matches the
    // kernel's tile constants only for size == 1024 -- confirm before reusing
    // with other sizes.
    @Reflect
    public static void matrixMultiply2DRegisterTiling(@RO ComputeContext cc, @RO F32Array matrixA, @RO F32Array matrixB, @RW  F32Array matrixC, final int size) {
        cc.dispatchKernel(NDRange.of2D(256, 256,16, 16),
                kc -> matrixMultiplyKernel2DRegisterTiling(kc, matrixA, matrixB, matrixC, size)
        );
    }
 759 
    // Compute entry point: dispatches the vectorized (float4 loads, padded
    // buffers) 2D register-tiling kernel over a 256x256 global range with
    // 16x16 workgroups.
    // NOTE(review): launch geometry is hard-coded; presumably valid only for
    // size == 1024 -- confirm before reusing with other sizes.
    @Reflect
    public static void matrixMultiply2DRegisterTilingVectorized(@RO ComputeContext cc, @RO F32ArrayPadded matrixA, @RO F32ArrayPadded matrixB, @RW  F32ArrayPadded matrixC, final int size) {
        cc.dispatchKernel(NDRange.of2D(256, 256,16, 16),
                kc -> matrixMultiplyKernel2DRegisterTilingVectorized(kc, matrixA, matrixB, matrixC, size)
        );
    }
 766 
 767     @HatTest
 768     @Reflect
 769     public void testMatMul2DRegisterTiling() {
 770         var lookup = MethodHandles.lookup();
 771         var accelerator = new Accelerator(lookup, Backend.FIRST);
 772 
 773         final int size = 1024;
 774         var matrixA = F32Array.create(accelerator, size * size);
 775         var matrixB = F32Array.create(accelerator, size * size);
 776 
 777         // Matrix for the results
 778         var matrixC = F32Array.create(accelerator, size * size);
 779         var resultSeq = F32Array.create(accelerator, size * size);
 780 
 781         // Initialize matrices (A and B have the same size)
 782         Random r = new Random(19);
 783 
 784         for (int j = 0; j < matrixA.length(); j++) {
 785             matrixA.array(j, r.nextFloat());
 786             matrixB.array(j, r.nextFloat());
 787         }
 788 
 789         accelerator.compute(cc ->
 790                 TestMatMul.matrixMultiply2DRegisterTiling(cc, matrixA, matrixB, matrixC, size));
 791 
 792         // Run Seq for reference
 793         runSequential(matrixA, matrixB, resultSeq, size);
 794 
 795         for (int j = 0; j < size; j++) {
 796             for (int i = 0; i < size; i++) {
 797                 HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
 798             }
 799         }
 800     }
 801 
 802     @HatTest
 803     @Reflect
 804     public void testMatMul2DRegisterTilingVectorized() {
 805         var lookup = MethodHandles.lookup();
 806         var accelerator = new Accelerator(lookup, Backend.FIRST);
 807 
 808         final int size = 1024;
 809         var matrixA = F32ArrayPadded.create(accelerator, size * size);
 810         var matrixB = F32ArrayPadded.create(accelerator, size * size);
 811 
 812         // Matrix for the results
 813         var matrixC = F32ArrayPadded.create(accelerator, size * size);
 814         var resultSeq = F32ArrayPadded.create(accelerator, size * size);
 815 
 816         // Initialize matrices (A and B have the same size)
 817         Random r = new Random(19);
 818 
 819         for (int j = 0; j < matrixA.length(); j++) {
 820             matrixA.array(j, r.nextFloat());
 821             matrixB.array(j, r.nextFloat());
 822         }
 823 
 824         accelerator.compute(cc ->
 825                 TestMatMul.matrixMultiply2DRegisterTilingVectorized(cc, matrixA, matrixB, matrixC, size));
 826 
 827         // Run Seq for reference
 828         runSequential(matrixA, matrixB, resultSeq, size);
 829 
 830         for (int j = 0; j < size; j++) {
 831             for (int i = 0; i < size; i++) {
 832                 HATAsserts.assertEquals(resultSeq.array(i * size + j), matrixC.array(i * size + j), 0.01f);
 833             }
 834         }
 835     }
 836 
    // Non-mappable device-side array of 1024 F16 values. The name and the
    // createLocal() factory suggest it backs workgroup-local (shared) memory;
    // the schema below describes the layout the backend materializes.
    private interface SharedMemoryHalf extends NonMappableIface {
        // Accessor for the F16 element at {@code index}.
        F16 array(int index);

        // Layout: a 1024-element "array" of F16, each carrying a "value" field.
        DeviceSchema<SharedMemoryHalf> schema = DeviceSchema.of(SharedMemoryHalf.class,
                arr -> arr.withArray("array", 1024)
                        .withDeps(F16.class, half -> half.withField("value")));

        // Host-side placeholder: returns null here; presumably the HAT toolchain
        // substitutes a real device allocation during kernel compilation -- TODO confirm.
        static SharedMemoryHalf create(Accelerator accelerator) {
            return null;
        }

        // Host-side placeholder for a workgroup-local allocation (same caveat as create).
        static SharedMemoryHalf createLocal() {
            return null;
        }
    }
 852 
    // Non-mappable device-side array of 16 F16 values. The createPrivate()
    // factory suggests per-thread private (register) storage; sized to hold a
    // TM x TN = 4 x 4 accumulator tile.
    private interface PrivateArrayHalf extends NonMappableIface {
        // Accessor for the F16 element at {@code index}.
        F16 array(int index);

        // Layout: a 16-element "array" of F16, each carrying a "value" field.
        DeviceSchema<PrivateArrayHalf> schema = DeviceSchema.of(PrivateArrayHalf.class,
                arr -> arr.withArray("array", 16)
                        .withDeps(F16.class, half -> half.withField("value")));

        // Host-side placeholder: returns null here; presumably the HAT toolchain
        // substitutes a real device allocation during kernel compilation -- TODO confirm.
        static PrivateArrayHalf create(Accelerator accelerator) {
            return null;
        }

        // Host-side placeholder for a per-thread private allocation (same caveat as create).
        static PrivateArrayHalf createPrivate() {
            return null;
        }
    }
 868 
    // Non-mappable device-side array of 4 F16 values; per-thread private
    // storage (per the createPrivate() factory) sized to hold one TM- or
    // TN-length register slice.
    private interface FlatPrivateHalf extends NonMappableIface {
        // Accessor for the F16 element at {@code index}.
        F16 array(int index);

        // Layout: a 4-element "array" of F16, each carrying a "value" field.
        DeviceSchema<FlatPrivateHalf> schema = DeviceSchema.of(FlatPrivateHalf.class,
                arr -> arr.withArray("array", 4)
                        .withDeps(F16.class, half -> half.withField("value")));

        // Host-side placeholder: returns null here; presumably the HAT toolchain
        // substitutes a real device allocation during kernel compilation -- TODO confirm.
        static FlatPrivateHalf create(Accelerator accelerator) {
            return null;
        }

        // Host-side placeholder for a per-thread private allocation (same caveat as create).
        static FlatPrivateHalf createPrivate() {
            return null;
        }
    }
 884 
 885     // Taking from the HAT Examples module
 886     @Reflect
 887     public static void matrixMultiplyKernel2DRegisterTilingHalf(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @RW F16Array matrixC, int size) {
 888         final int BM = 64;
 889         final int BN = 64;
 890         final int BK = 16;
 891         final int TM = 4;
 892         final int TN = 4;
 893 
 894         int bx = kc.bix;
 895         int by = kc.biy;
 896 
 897         int totalResultsBlockTile = BM * BN;
 898         final int numThreadsBlockTile = totalResultsBlockTile / (TM * TN);
 899 
 900         final int linearLocalId = kc.liy * kc.lsx + kc.lix;
 901         final int threadCol = kc.lix;
 902         final int threadRow = kc.liy;
 903 
 904         SharedMemoryHalf tileA = SharedMemoryHalf.createLocal();
 905         SharedMemoryHalf tileB = SharedMemoryHalf.createLocal();
 906 
 907         int aFrom = by * BM * size;
 908         int bFrom = bx * BN;
 909         int v = bx * BN;
 910         int cFrom = (by * BM * size) + (v);
 911 
 912         final int innerRowA = linearLocalId / BK;
 913         final int innerColA = linearLocalId % BK;
 914 
 915         final int strideA = numThreadsBlockTile / BK;
 916         final int innerRowB = linearLocalId / BN;
 917         final int innerColB = linearLocalId % BN;
 918 
 919         int strideB = numThreadsBlockTile / BN;
 920 
 921         PrivateArrayHalf threadResults = PrivateArrayHalf.createPrivate();
 922         FlatPrivateHalf regM = FlatPrivateHalf.createPrivate();
 923         FlatPrivateHalf regN = FlatPrivateHalf.createPrivate();
 924 
 925         for (int i = 0; i < (TN * TN); i++) {
 926             F16 init = F16.of(0.0f);
 927             threadResults.array(i).value(init.value());
 928         }
 929 
 930         for (int bkIdx = 0; bkIdx < size; bkIdx += BK) {
 931             for (int loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
 932                 F16 ha = matrixA.array(((innerRowA + loadOffset) * size + innerColA) + aFrom);
 933                 tileA.array((innerRowA + loadOffset) * BK + innerColA).value(ha.value());
 934             }
 935             for (int loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
 936                 F16 hb = matrixB.array(((innerRowB + loadOffset) * size + innerColB) + bFrom);
 937                 tileB.array((innerRowB + loadOffset) * BN + innerColB).value(hb.value());
 938             }
 939             kc.barrier();
 940 
 941             aFrom += (BK);
 942             int f = BK * size;
 943             bFrom += f;
 944 
 945             for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
 946                 for (int i = 0; i < TM; i++) {
 947                     F16 ha = tileA.array((threadRow * TM + i) * BK + dotIdx);
 948                     regM.array(i).value(ha.value());
 949                 }
 950                 for (int i = 0; i < TN; i++) {
 951                     F16 hb = tileB.array(dotIdx * BN + threadCol * TN + i);
 952                     regN.array(i).value(hb.value());
 953                 }
 954                 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
 955                     for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
 956                         F16 privA = regM.array(resIdxM);
 957                         F16 privB = regN.array(resIdxN);
 958                         F16 mul = F16.mul(privA, privB);
 959                         F16 acc = threadResults.array(resIdxM * TN + resIdxN);
 960                         acc = F16.add(acc, mul);
 961                         threadResults.array((resIdxM * TN + resIdxN)).value(acc.value());
 962                     }
 963                 }
 964             }
 965             kc.barrier();
 966         }
 967         for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
 968             for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
 969                 F16 result = threadResults.array(resIdxM * TN + resIdxN);
 970                 matrixC.array((((threadRow * TM + resIdxM) * size + threadCol * TN + resIdxN) + (cFrom))).value(result.value());
 971             }
 972         }
 973     }
 974 
    // Non-mappable device-side array of 1024 BF16 values. The name and the
    // createLocal() factory suggest it backs workgroup-local (shared) memory;
    // the schema below describes the layout the backend materializes.
    private interface SharedMemoryBfloat16 extends NonMappableIface {
        // Accessor for the BF16 element at {@code index}.
        BF16 array(int index);

        // Layout: a 1024-element "array" of BF16, each carrying a "value" field.
        DeviceSchema<SharedMemoryBfloat16> schema = DeviceSchema.of(SharedMemoryBfloat16.class,
                arr -> arr.withArray("array", 1024)
                        .withDeps(BF16.class, half -> half.withField("value")));

        // Host-side placeholder: returns null here; presumably the HAT toolchain
        // substitutes a real device allocation during kernel compilation -- TODO confirm.
        static SharedMemoryBfloat16 create(Accelerator accelerator) {
            return null;
        }

        // Host-side placeholder for a workgroup-local allocation (same caveat as create).
        static SharedMemoryBfloat16 createLocal() {
            return null;
        }
    }
 990 
    // Non-mappable device-side array of 16 BF16 values. The createPrivate()
    // factory suggests per-thread private (register) storage; sized to hold a
    // TM x TN = 4 x 4 accumulator tile.
    private interface PrivateArrayBfloat16 extends NonMappableIface {
        // Accessor for the BF16 element at {@code index}.
        BF16 array(int index);

        // Layout: a 16-element "array" of BF16, each carrying a "value" field.
        DeviceSchema<PrivateArrayBfloat16> schema = DeviceSchema.of(PrivateArrayBfloat16.class,
                arr -> arr.withArray("array", 16)
                        .withDeps(BF16.class, half -> half.withField("value")));

        // Host-side placeholder: returns null here; presumably the HAT toolchain
        // substitutes a real device allocation during kernel compilation -- TODO confirm.
        static PrivateArrayBfloat16 create(Accelerator accelerator) {
            return null;
        }

        // Host-side placeholder for a per-thread private allocation (same caveat as create).
        static PrivateArrayBfloat16 createPrivate() {
            return null;
        }
    }
1006 
    // Non-mappable device-side array of 4 BF16 values; per-thread private
    // storage (per the createPrivate() factory) sized to hold one TM- or
    // TN-length register slice.
    private interface FlatPrivateBfloat16 extends NonMappableIface {
        // Accessor for the BF16 element at {@code index}.
        BF16 array(int index);

        // Layout: a 4-element "array" of BF16, each carrying a "value" field.
        DeviceSchema<FlatPrivateBfloat16> schema = DeviceSchema.of(FlatPrivateBfloat16.class,
                arr -> arr.withArray("array", 4)
                        .withDeps(BF16.class, half -> half.withField("value")));

        // Host-side placeholder: returns null here; presumably the HAT toolchain
        // substitutes a real device allocation during kernel compilation -- TODO confirm.
        static FlatPrivateBfloat16 create(Accelerator accelerator) {
            return null;
        }

        // Host-side placeholder for a per-thread private allocation (same caveat as create).
        static FlatPrivateBfloat16 createPrivate() {
            return null;
        }
    }
1022 
1023     @Reflect
1024     public static void matrixMultiplyKernel2DRegisterTilingBFloat16(@RO KernelContext kc, @RO BF16Array matrixA, @RO BF16Array matrixB, @RW BF16Array matrixC, int size) {
1025         final int BM = 64;
1026         final int BN = 64;
1027         final int BK = 16;
1028         final int TM = 4;
1029         final int TN = 4;
1030 
1031         int bx = kc.bix;
1032         int by = kc.biy;
1033 
1034         int totalResultsBlockTile = BM * BN;
1035         final int numThreadsBlockTile = totalResultsBlockTile / (TM * TN);
1036 
1037         final int linearLocalId = kc.liy * kc.lsx + kc.lix;
1038         final int threadCol = kc.lix;
1039         final int threadRow = kc.liy;
1040 
1041         SharedMemoryBfloat16 tileA = SharedMemoryBfloat16.createLocal();
1042         SharedMemoryBfloat16 tileB = SharedMemoryBfloat16.createLocal();
1043 
1044         int aFrom = by * BM * size;
1045         int bFrom = bx * BN;
1046         int v = bx * BN;
1047         int cFrom = (by * BM * size) + (v);
1048 
1049         final int innerRowA = linearLocalId / BK;
1050         final int innerColA = linearLocalId % BK;
1051 
1052         final int strideA = numThreadsBlockTile / BK;
1053         final int innerRowB = linearLocalId / BN;
1054         final int innerColB = linearLocalId % BN;
1055 
1056         int strideB = numThreadsBlockTile / BN;
1057 
1058         PrivateArrayBfloat16 threadResults = PrivateArrayBfloat16.createPrivate();
1059         FlatPrivateBfloat16 regM = FlatPrivateBfloat16.createPrivate();
1060         FlatPrivateBfloat16 regN = FlatPrivateBfloat16.createPrivate();
1061 
1062         for (int i = 0; i < (TN * TN); i++) {
1063             BF16 init = BF16.of(0.0f);
1064             threadResults.array(i).value(init.value());
1065         }
1066 
1067         for (int bkIdx = 0; bkIdx < size; bkIdx += BK) {
1068             for (int loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
1069                 BF16 ha = matrixA.array(((innerRowA + loadOffset) * size + innerColA) + aFrom);
1070                 tileA.array((innerRowA + loadOffset) * BK + innerColA).value(ha.value());
1071             }
1072             for (int loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
1073                 BF16 hb = matrixB.array(((innerRowB + loadOffset) * size + innerColB) + bFrom);
1074                 tileB.array((innerRowB + loadOffset) * BN + innerColB).value(hb.value());
1075             }
1076             kc.barrier();
1077 
1078             aFrom += (BK);
1079             int f = BK * size;
1080             bFrom += f;
1081 
1082             for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
1083                 for (int i = 0; i < TM; i++) {
1084                     BF16 ha = tileA.array((threadRow * TM + i) * BK + dotIdx);
1085                     regM.array(i).value(ha.value());
1086                 }
1087                 for (int i = 0; i < TN; i++) {
1088                     BF16 hb = tileB.array(dotIdx * BN + threadCol * TN + i);
1089                     regN.array(i).value(hb.value());
1090                 }
1091                 for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
1092                     for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
1093                         BF16 privA = regM.array(resIdxM);
1094                         BF16 privB = regN.array(resIdxN);
1095                         BF16 mul = BF16.mul(privA, privB);
1096                         BF16 acc = threadResults.array(resIdxM * TN + resIdxN);
1097                         acc = BF16.add(acc, mul);
1098                         threadResults.array((resIdxM * TN + resIdxN)).value(acc.value());
1099                     }
1100                 }
1101             }
1102             kc.barrier();
1103         }
1104         for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
1105             for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
1106                 BF16 result = threadResults.array(resIdxM * TN + resIdxN);
1107                 matrixC.array((((threadRow * TM + resIdxM) * size + threadCol * TN + resIdxN) + (cFrom))).value(result.value());
1108             }
1109         }
1110     }
1111 
    // Compute entry point: dispatches the fp16 2D register-tiling kernel over a
    // 256x256 global range with 16x16 workgroups.
    // NOTE(review): launch geometry is hard-coded; presumably valid only for
    // globalSize == 1024 -- confirm before reusing with other sizes.
    @Reflect
    public static void matrixMultiply2DRegisterTilingHalf(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @RW F16Array matrixC, int globalSize) {
        cc.dispatchKernel(NDRange.of2D(256, 256,16, 16),
                kc -> matrixMultiplyKernel2DRegisterTilingHalf(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }
1118 
    // Compute entry point: dispatches the bfloat16 2D register-tiling kernel
    // over a 256x256 global range with 16x16 workgroups.
    // NOTE(review): launch geometry is hard-coded; presumably valid only for
    // globalSize == 1024 -- confirm before reusing with other sizes.
    @Reflect
    public static void matrixMultiply2DRegisterTilingBFloat16(@RO ComputeContext cc, @RO BF16Array matrixA, @RO BF16Array matrixB, @RW BF16Array matrixC, int globalSize) {
        cc.dispatchKernel(NDRange.of2D(256, 256,16, 16),
                kc -> matrixMultiplyKernel2DRegisterTilingBFloat16(kc, matrixA, matrixB, matrixC, globalSize)
        );
    }
1125 
1126     @HatTest
1127     @Reflect
1128     public void matrixMultiply2DRegisterTilingHalf() {
1129         var lookup = MethodHandles.lookup();
1130         var accelerator = new Accelerator(lookup, Backend.FIRST);
1131 
1132         final int size = 1024;
1133         var matrixA = F16Array.create(accelerator, size * size);
1134         var matrixB = F16Array.create(accelerator, size * size);
1135 
1136         // Matrix for the results
1137         var matrixC = F16Array.create(accelerator, size * size);
1138         var resultSeq = F16Array.create(accelerator, size * size);
1139 
1140         // Initialize matrices (A and B have the same size)
1141         Random r = new Random(19);
1142         for (int j = 0; j < matrixA.length(); j++) {
1143             matrixA.array(j).value(F16.floatToF16(r.nextFloat()).value());
1144             matrixB.array(j).value(F16.floatToF16(r.nextFloat()).value());
1145         }
1146 
1147         accelerator.compute(cc ->
1148                 TestMatMul.matrixMultiply2DRegisterTilingHalf(cc, matrixA, matrixB, matrixC, size));
1149 
1150         // Run Seq for reference
1151         runSequential(matrixA, matrixB, resultSeq, size);
1152 
1153         for (int i = 0; i < size; i++) {
1154             for (int j = 0; j < size; j++) {
1155                 try {
1156                     HATAsserts.assertEquals(F16.f16ToFloat(resultSeq.array(i * size + j)),
1157                                         F16.f16ToFloat(matrixC.array(i * size + j)),
1158                                         0.01f);
1159                 } catch (HATAssertionError hatAssertionError) {
1160                     throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
1161                 }
1162             }
1163         }
1164     }
1165 
1166     @HatTest
1167     @Reflect
1168     public void matrixMultiply2DRegisterTilingBFloat16() {
1169         var lookup = MethodHandles.lookup();
1170         var accelerator = new Accelerator(lookup, Backend.FIRST);
1171 
1172         final int size = 1024;
1173         var matrixA = BF16Array.create(accelerator, size * size);
1174         var matrixB = BF16Array.create(accelerator, size * size);
1175 
1176         // Matrix for the results
1177         var matrixC = BF16Array.create(accelerator, size * size);
1178         var resultSeq = BF16Array.create(accelerator, size * size);
1179 
1180         // Initialize matrices (A and B have the same size)
1181         Random r = new Random(19);
1182         for (int j = 0; j < matrixA.length(); j++) {
1183             matrixA.array(j).value(BF16.float2bfloat16(r.nextFloat()).value());
1184             matrixB.array(j).value(BF16.float2bfloat16(r.nextFloat()).value());
1185         }
1186 
1187         accelerator.compute(cc ->
1188                 TestMatMul.matrixMultiply2DRegisterTilingBFloat16(cc, matrixA, matrixB, matrixC, size));
1189 
1190         // Run Seq for reference
1191         runSequential(matrixA, matrixB, resultSeq, size);
1192 
1193         for (int i = 0; i < size; i++) {
1194             for (int j = 0; j < size; j++) {
1195                 try {
1196                     HATAsserts.assertEquals(BF16.bfloat162float(resultSeq.array(i * size + j)),
1197                             BF16.bfloat162float(matrixC.array(i * size + j)),
1198                             0.01f);
1199                 } catch (HATAssertionError hatAssertionError) {
1200                     throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
1201                 }
1202             }
1203         }
1204     }
1205 }