1 /*
  2  * Copyright (c) 2025-2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.test;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.KernelContext;
 30 import hat.NDRange;
 31 import hat.backend.Backend;
 32 import hat.types.BF16;
 33 import hat.buffer.BF16Array;
 34 import hat.device.DeviceSchema;
 35 import hat.device.NonMappableIface;
 36 import hat.test.annotation.HatTest;
 37 import hat.test.exceptions.HATAssertionError;
 38 import hat.test.exceptions.HATAsserts;
 39 import hat.test.exceptions.HATExpectedPrecisionError;
 40 import jdk.incubator.code.Reflect;
 41 import optkl.ifacemapper.MappableIface.*;
 42 
 43 import java.lang.invoke.MethodHandles;
 44 import java.util.Random;
 45 
 46 public class TestBFloat16Type {
 47 
    /**
     * Kernel: copies the element at index {@code gix} from {@code a} into {@code b}.
     *
     * @param kernelContext provides {@code gix} and {@code gsx}; presumably the global work-item
     *                      index and the global size (confirm against KernelContext)
     * @param a source array, only read here
     * @param b destination array, only written here
     */
    @Reflect
    public static void kernel_copy(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
        // Only indices below gsx do any work.
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            // Store directly through the element returned by array(), without an intermediate local.
            b.array(kernelContext.gix).value(ha.value());
        }
    }
 55 
    /**
     * Kernel: element-wise addition, {@code c[gix] = a[gix] + b[gix]}, using the static
     * {@code BF16.add}.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a first operand array, only read here
     * @param b second operand array, only read here
     * @param c result array, only written here
     */
    @Reflect
    public static void bf16_02(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            BF16 hb = b.array(kernelContext.gix);
            BF16 result = BF16.add(ha, hb);
            // The destination element is first bound to a local and then written.
            BF16 hc = c.array(kernelContext.gix);
            hc.value(result.value());
        }
    }
 66 
    /**
     * Kernel: nested additions, {@code c[gix] = a[gix] + (b[gix] + b[gix])}, with the inner
     * {@code BF16.add} call passed directly as an argument of the outer one.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a first operand array, only read here
     * @param b second operand array, only read here (used twice)
     * @param c result array, only written here
     */
    @Reflect
    public static void bf16_03(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            BF16 hb = b.array(kernelContext.gix);

            // The inner add result is consumed without being stored in a local.
            BF16 result = BF16.add(ha, BF16.add(hb, hb));
            BF16 hC = c.array(kernelContext.gix);
            hC.value(result.value());
        }
    }
 78 
    /**
     * Kernel: combines all four arithmetic operations,
     * {@code c[gix] = (a*b + a/b) + (a - b)} evaluated step by step on {@code BF16} values.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a first operand array, only read here
     * @param b second operand array, only read here (also the divisor)
     * @param c result array, only written here
     */
    @Reflect
    public static void bf16_04(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            BF16 hb = b.array(kernelContext.gix);

            // Each intermediate result is kept in its own BF16 local.
            BF16 r1 = BF16.mul(ha, hb);
            BF16 r2 = BF16.div(ha, hb);
            BF16 r3 = BF16.sub(ha, hb);
            BF16 r4 = BF16.add(r1, r2);
            BF16 r5 = BF16.add(r4, r3);
            BF16 hC = c.array(kernelContext.gix);
            hC.value(r5.value());
        }
    }
 94 
    /**
     * Kernel: stores the constant {@code BF16.of(2.1f)} into every element of {@code a}.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a array that receives the constant, only written here
     */
    @Reflect
    public static void bf16_05(@RO KernelContext kernelContext, @WO BF16Array a) {
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            // BF16 value built from a float literal inside the kernel.
            BF16 initVal = BF16.of( 2.1f);
            ha.value(initVal.value());
        }
    }
103 
    /**
     * Kernel: stores {@code BF16.of(gix)} into element {@code gix} of {@code a}, so each element
     * is initialised from its own index.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a array that receives the index-derived values, only written here
     */
    @Reflect
    public static void bf16_06(@RO KernelContext kernelContext, @WO BF16Array a) {
        if (kernelContext.gix < kernelContext.gsx) {
            // The index itself is passed to BF16.of.
            BF16 initVal = BF16.of(kernelContext.gix);
            BF16 ha = a.array(kernelContext.gix);
            ha.value(initVal.value());
        }
    }
112 
    /**
     * Kernel: same as {@code bf16_06} but builds the value with
     * {@code BF16.float2bfloat16(gix)} instead of {@code BF16.of(gix)}.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a array that receives the index-derived values, only written here
     */
    @Reflect
    public static void bf16_08(@RO KernelContext kernelContext, @WO BF16Array a) {
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 initVal = BF16.float2bfloat16(kernelContext.gix);
            BF16 ha = a.array(kernelContext.gix);
            ha.value(initVal.value());
        }
    }
121 
    /**
     * Kernel: round-trips each element through {@code float}:
     * {@code b[gix] = float2bfloat16(bfloat162float(a[gix]))}.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a source array, only read here
     * @param b destination array, only written here
     */
    @Reflect
    public static void bf16_09(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            // Widen to float, then narrow back to bfloat16.
            float f = BF16.bfloat162float(ha);
            BF16 result = BF16.float2bfloat16(f);
            BF16 hb = b.array(kernelContext.gix);
            hb.value(result.value());
        }
    }
132 
    /**
     * Kernel: builds {@code BF16.of(1.1f)}, converts it to {@code float} and back to bfloat16,
     * and stores the result in every element of {@code a}.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a array that receives the round-tripped constant, only written here
     */
    @Reflect
    public static void bf16_10(@RO KernelContext kernelContext, @WO BF16Array a) {
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            BF16 f16 = BF16.of(1.1f);
            // Round trip of a kernel-local constant rather than of an array element.
            float f = BF16.bfloat162float(f16);
            BF16 result = BF16.float2bfloat16(f);
            ha.value(result.value());
        }
    }
143 
    /**
     * Device-side array of {@link BF16} elements used by {@code bf16_11} as scratch storage
     * obtained through {@link #createLocal()}.
     */
    public interface LocalArray extends NonMappableIface {
        /** Returns the element at {@code index}. */
        BF16 array(int index);

        // Layout description: an "array" member of 1024 entries whose BF16 element type
        // exposes a single "value" field.
        DeviceSchema<LocalArray> schema = DeviceSchema.of(LocalArray.class,
                builder -> builder.withArray("array", 1024)
                        .withDeps(BF16.class, bfloat16 -> bfloat16.withField("value")));

        // Host-side stub: always returns null.
        // NOTE(review): presumably a placeholder that the HAT code generator replaces - confirm.
        static LocalArray  create(Accelerator accelerator) {
            return null;
        }

        // Host-side stub: always returns null.
        // NOTE(review): presumably lowered to a device local-memory allocation - confirm.
        static LocalArray createLocal() {
            return null;
        }
    }
158 
    /**
     * Kernel: copies {@code a[gix]} to {@code b[gix]} through a {@link LocalArray} slot indexed by
     * {@code lix}, with a barrier between the store into and the load from that array.
     *
     * @param kernelContext provides {@code gix}, {@code gsx}, {@code lix} and {@code barrier()}
     * @param a source array, only read here
     * @param b destination array, only written here
     */
    @Reflect
    public static void bf16_11(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
        LocalArray sm = LocalArray.createLocal();
        if (kernelContext.gix < kernelContext.gsx) {
            // presumably lix is the index within the work-group - confirm against KernelContext
            int lix = kernelContext.lix;
            BF16 ha = a.array(kernelContext.gix);

            sm.array(lix).value(ha.value());
            // NOTE(review): the barrier sits inside the guarded branch; this relies on every
            // work-item that reaches the kernel passing the guard - confirm.
            kernelContext.barrier();

            BF16 hb = sm.array(lix);
            b.array(kernelContext.gix).value(hb.value());
        }
    }
173 
    /**
     * Kernel: element-wise addition written with the instance method {@code ha.add(hb)} rather
     * than the static {@code BF16.add}.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a first operand array, only read here
     * @param b second operand array, only read here
     * @param c result array, only written here
     */
    @Reflect
    public static void bf16_12(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
        // Test the fluent API style
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            BF16 hb = b.array(kernelContext.gix);
            BF16 result = ha.add(hb);
            c.array(kernelContext.gix).value(result.value());
        }
    }
184 
    /**
     * Kernel: chains four instance operations, {@code ((a + b) - b) * a / a}, and stores the
     * result in {@code c[gix]}.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a first operand array, only read here
     * @param b second operand array, only read here
     * @param c result array, only written here
     */
    @Reflect
    public static void bf16_13(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b,  @WO BF16Array c) {
        // Test the fluent API style
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            BF16 hb = b.array(kernelContext.gix);
            // Each call operates on the result of the previous one; no intermediate locals.
            BF16 result = ha.add(hb).sub(hb).mul(ha).div(ha);
            c.array(kernelContext.gix).value(result.value());
        }
    }
195 
    /**
     * Kernel: adds the {@code float} constant {@code 32.1f} to each element of {@code a} through
     * the {@code BF16.add(float, BF16)} form and stores the result in {@code b}.
     *
     * @param kernelContext provides {@code gix} and {@code gsx} used by the bounds guard
     * @param a source array, only read here
     * @param b destination array, only written here
     */
    @Reflect
    public static void bf16_14(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
        // Testing mixed float types
        if (kernelContext.gix < kernelContext.gsx) {
            BF16 ha = a.array(kernelContext.gix);
            float myFloat = 32.1f;
            // First operand is a plain float, second a BF16.
            BF16 result = BF16.add(myFloat, ha);
            b.array(kernelContext.gix).value(result.value());
        }
    }
206 
    /**
     * Device-side array of {@link BF16} elements used by {@code bf16_15} and {@code bf16_17} as
     * scratch storage obtained through {@link #createPrivate()}.
     */
    public interface PrivateArray extends NonMappableIface {
        /** Returns the element at {@code index}. */
        BF16 array(int index);

        // Layout description: an "array" member of 256 entries whose BF16 element type
        // exposes a single "value" field.
        DeviceSchema<PrivateArray> schema = DeviceSchema.of(PrivateArray.class,
                builder -> builder.withArray("array", 256)
                        .withDeps(BF16.class, bfloat16 -> bfloat16.withField("value")));

        // Host-side stub: always returns null.
        // NOTE(review): presumably a placeholder that the HAT code generator replaces - confirm.
        static PrivateArray  create(Accelerator accelerator) {
            return null;
        }

        // Host-side stub: always returns null.
        // NOTE(review): presumably lowered to a per-work-item private allocation - confirm.
        static PrivateArray createPrivate() {
            return null;
        }
    }
221 
    /**
     * Kernel: copies {@code a[gix]} to {@code b[gix]} through a {@link PrivateArray} slot indexed
     * by {@code lix}.
     *
     * @param kernelContext provides {@code gix}, {@code gsx} and {@code lix}
     * @param a source array, only read here
     * @param b destination array, only written here
     */
    @Reflect
    public static void bf16_15(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
        PrivateArray privateArray = PrivateArray.createPrivate();
        if (kernelContext.gix < kernelContext.gsx) {
            // NOTE(review): the schema declares 256 entries, so this assumes lix < 256 - confirm.
            int lix = kernelContext.lix;
            BF16 ha = a.array(kernelContext.gix);
            privateArray.array(lix).value(ha.value());
            // Read back the slot that was just written.
            BF16 hb = privateArray.array(lix);
            b.array(kernelContext.gix).value(hb.value());
        }
    }
233 
    /**
     * Kernel: quadruples {@code a[0]} in place by doubling it twice, reassigning the same local
     * ({@code hre}) for the second doubling. There is no bounds guard; only index 0 is touched.
     *
     * @param kernelContext kernel context (not read by this kernel)
     * @param a array whose first element is read and then overwritten
     */
    @Reflect
    public static void bf16_16(@RO KernelContext kernelContext, @RW BF16Array a) {
        BF16 ha = a.array(0);
        BF16 hre = BF16.add(ha, ha);
        // The accumulator is both an operand and the target of the assignment.
        hre = BF16.add(hre, hre);
        a.array(0).value(hre.value());
    }
241 
    /**
     * Kernel: doubles {@code a[0]} in place, staging the value in slot 0 of a
     * {@link PrivateArray} and accumulating into a local read from that slot. There is no bounds
     * guard; only index 0 is touched.
     *
     * @param kernelContext kernel context (not read by this kernel)
     * @param a array whose first element is read and then overwritten
     */
    @Reflect
    public static void bf16_17(@RO KernelContext kernelContext, @RW BF16Array a) {

        BF16 ha = a.array(0);
        PrivateArray privateArray = PrivateArray.createPrivate();
        privateArray.array(0).value(ha.value());

        // Obtain the value from private memory
        BF16 acc = privateArray.array(0);

        // compute
        acc = BF16.add(acc, acc);

        // store the result
        a.array(0).value(acc.value());
    }
258 
    /**
     * Compute entry point: dispatches {@code kernel_copy} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a source array
     * @param b destination array
     */
    @Reflect
    public static void compute01(@RO ComputeContext computeContext, @RO BF16Array a, @WO BF16Array b) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.kernel_copy(kernelContext, a, b));
    }
263 
    /**
     * Compute entry point: dispatches {@code bf16_02} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a first operand array
     * @param b second operand array
     * @param c result array
     */
    @Reflect
    public static void compute02(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()),
                kernelContext -> TestBFloat16Type.bf16_02(kernelContext, a, b, c));
    }
269 
    /**
     * Compute entry point: dispatches {@code bf16_03} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a first operand array
     * @param b second operand array
     * @param c result array
     */
    @Reflect
    public static void compute03(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()),
                kernelContext -> TestBFloat16Type.bf16_03(kernelContext, a, b, c));
    }
275 
    /**
     * Compute entry point: dispatches {@code bf16_04} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a first operand array
     * @param b second operand array
     * @param c result array
     */
    @Reflect
    public static void compute04(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()),
                kernelContext -> TestBFloat16Type.bf16_04(kernelContext, a, b, c));
    }
281 
    /**
     * Compute entry point: dispatches {@code bf16_05} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a array filled by the kernel
     */
    @Reflect
    public static void compute05(@RO ComputeContext computeContext, @WO BF16Array a) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_05(kernelContext, a));
    }
286 
    /**
     * Compute entry point: dispatches {@code bf16_06} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a array filled by the kernel
     */
    @Reflect
    public static void compute06(@RO ComputeContext computeContext, @WO BF16Array a) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_06(kernelContext, a));
    }
291 
    /**
     * Compute entry point: dispatches {@code bf16_08} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a array filled by the kernel
     */
    @Reflect
    public static void compute08(@RO ComputeContext computeContext, @WO BF16Array a) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_08(kernelContext, a));
    }
296 
    /**
     * Compute entry point: dispatches {@code bf16_09} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a source array
     * @param b destination array
     */
    // NOTE(review): `a` is declared @RW here while bf16_09 declares it @RO - confirm it is intended.
    @Reflect
    public static void compute09(@RO ComputeContext computeContext, @RW BF16Array a, @WO BF16Array b) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_09(kernelContext, a, b));
    }
301 
    /**
     * Compute entry point: dispatches {@code bf16_10} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a array filled by the kernel
     */
    @Reflect
    public static void compute10(@RO ComputeContext computeContext, @WO BF16Array a) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_10(kernelContext, a));
    }
306 
    /**
     * Compute entry point: dispatches {@code bf16_11} with {@code NDRange.of1D(a.length(), 16)}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a source array
     * @param b destination array
     */
    @Reflect
    public static void compute11(@RO ComputeContext computeContext, @RO BF16Array a, @WO BF16Array b) {
        // NOTE(review): the second of1D argument (16) is presumably the local group size - confirm.
        computeContext.dispatchKernel(NDRange.of1D(a.length(),16), kernelContext -> TestBFloat16Type.bf16_11(kernelContext, a, b));
    }
311 
    /**
     * Compute entry point: dispatches {@code bf16_12} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a first operand array
     * @param b second operand array
     * @param c result array
     */
    @Reflect
    public static void compute12(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_12(kernelContext, a, b, c));
    }
316 
    /**
     * Compute entry point: dispatches {@code bf16_13} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a first operand array
     * @param b second operand array
     * @param c result array
     */
    @Reflect
    public static void compute13(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_13(kernelContext, a, b, c));
    }
321 
    /**
     * Compute entry point: dispatches {@code bf16_14} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a source array
     * @param b destination array
     */
    @Reflect
    public static void compute14(@RO ComputeContext computeContext, @RO BF16Array a, @WO BF16Array b) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_14(kernelContext, a, b));
    }
326 
    /**
     * Compute entry point: dispatches {@code bf16_15} over a 1D range of {@code a.length()}.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a source array
     * @param b destination array
     */
    @Reflect
    public static void compute15(@RO ComputeContext computeContext, @RO BF16Array a, @WO BF16Array b) {
        computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_15(kernelContext, a, b));
    }
331 
    /**
     * Compute entry point: dispatches {@code bf16_16} over a 1D range of size 1.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a array whose first element is updated in place
     */
    @Reflect
    public static void compute16(@RO ComputeContext computeContext, @RW BF16Array a) {
        computeContext.dispatchKernel(NDRange.of1D(1), kernelContext -> TestBFloat16Type.bf16_16(kernelContext, a));
    }
336 
    /**
     * Compute entry point: dispatches {@code bf16_17} over a 1D range of size 1.
     *
     * @param computeContext context used to dispatch the kernel
     * @param a array whose first element is updated in place
     */
    @Reflect
    public static void compute17(@RO ComputeContext computeContext, @RW BF16Array a) {
        computeContext.dispatchKernel(NDRange.of1D(1), kernelContext -> TestBFloat16Type.bf16_17(kernelContext, a));
    }
341 
342     @HatTest
343     @Reflect
344     public void test_bfloat16_01() {
345         final int size = 256;
346         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
347 
348         BF16Array arrayA = BF16Array.create(accelerator, size);
349         BF16Array arrayB = BF16Array.create(accelerator, size);
350         for (int i = 0; i < size; i++) {
351             arrayA.array(i).value(BF16.float2bfloat16(i).value());
352         }
353 
354         accelerator.compute(computeContext -> TestBFloat16Type.compute01(computeContext, arrayA, arrayB));
355 
356         for (int i = 0; i < size; i++) {
357             BF16 result = arrayB.array(i);
358             HATAsserts.assertEquals((float)i, BF16.bfloat162float(result), 0.001f);
359         }
360     }
361 
362     @HatTest
363     @Reflect
364     public void test_bfloat16_02() {
365         final int size = 256;
366         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
367 
368         BF16Array arrayA = BF16Array.create(accelerator, size);
369         BF16Array arrayB = BF16Array.create(accelerator, size);
370         BF16Array arrayC = BF16Array.create(accelerator, size);
371 
372         Random r = new Random(19);
373         for (int i = 0; i < size; i++) {
374             arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
375             arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
376         }
377 
378         accelerator.compute(computeContext -> TestBFloat16Type.compute02(computeContext, arrayA, arrayB, arrayC));
379 
380         for (int i = 0; i < size; i++) {
381             BF16 result = arrayC.array(i);
382             BF16 a = arrayA.array(i);
383             BF16 b = arrayB.array(i);
384             float res = BF16.bfloat162float(a) + BF16.bfloat162float(b);
385             HATAsserts.assertEquals(res, BF16.bfloat162float(result), 0.001f);
386         }
387     }
388     @HatTest
389     @Reflect
390     public void test_bfloat16_03() {
391         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
392 
393         final int size = 256;
394         BF16Array arrayA = BF16Array.create(accelerator, size);
395         BF16Array arrayB = BF16Array.create(accelerator, size);
396         BF16Array arrayC = BF16Array.create(accelerator, size);
397 
398         Random random = new Random();
399         for (int i = 0; i < arrayA.length(); i++) {
400             arrayA.array(i).value(BF16.float2bfloat16(random.nextFloat()).value());
401             arrayB.array(i).value(BF16.float2bfloat16(random.nextFloat()).value());
402         }
403 
404         accelerator.compute(computeContext -> TestBFloat16Type.compute03(computeContext, arrayA, arrayB, arrayC));
405 
406         for (int i = 0; i < arrayC.length(); i++) {
407             BF16 val = arrayC.array(i);
408             float fa = BF16.bfloat162float(arrayA.array(i));
409             float fb = BF16.bfloat162float(arrayB.array(i));
410             HATAsserts.assertEquals((fa + fb + fb), BF16.bfloat162float(val), 0.01f);
411         }
412     }
413 
414     @HatTest
415     @Reflect
416     public void test_bfloat16_04() {
417         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
418 
419         final int size = 256;
420         BF16Array arrayA = BF16Array.create(accelerator, size);
421         BF16Array arrayB = BF16Array.create(accelerator, size);
422         BF16Array arrayC = BF16Array.create(accelerator, size);
423 
424         Random random = new Random();
425         for (int i = 0; i < arrayA.length(); i++) {
426             arrayA.array(i).value(BF16.float2bfloat16(random.nextFloat()).value());
427             arrayB.array(i).value(BF16.float2bfloat16(random.nextFloat()).value());
428         }
429 
430         accelerator.compute(computeContext -> {
431             TestBFloat16Type.compute04(computeContext, arrayA, arrayB, arrayC);
432         });
433 
434         for (int i = 0; i < arrayC.length(); i++) {
435             BF16 gotResult = arrayC.array(i);
436 
437             // CPU Computation
438             BF16 ha = arrayA.array(i);
439             BF16 hb = arrayB.array(i);
440             BF16 r1 = BF16.mul(ha, hb);
441             BF16 r2 = BF16.div(ha, hb);
442             BF16 r3 = BF16.sub(ha, hb);
443             BF16 r4 = BF16.add(r1, r2);
444             BF16 r5 = BF16.add(r4, r3);
445 
446             HATAsserts.assertEquals(BF16.bfloat162float(r5), BF16.bfloat162float(gotResult), 0.01f);
447         }
448     }
449 
450     @HatTest
451     @Reflect
452     public void test_bfloat16_05() {
453         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
454 
455         final int size = 16;
456         BF16Array arrayA = BF16Array.create(accelerator, size);
457         for (int i = 0; i < arrayA.length(); i++) {
458             arrayA.array(i).value(BF16.float2bfloat16(0.0f).value());
459         }
460 
461         accelerator.compute(computeContext -> {
462             TestBFloat16Type.compute05(computeContext, arrayA);
463         });
464 
465         for (int i = 0; i < arrayA.length(); i++) {
466             BF16 val = arrayA.array(i);
467             HATAsserts.assertEquals(2.1f, BF16.bfloat162float(val), 0.01f);
468         }
469     }
470 
471     @HatTest
472     @Reflect
473     public void test_bfloat16_06() {
474         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
475 
476         final int size = 512;
477         BF16Array arrayA = BF16Array.create(accelerator, size);
478         for (int i = 0; i < arrayA.length(); i++) {
479             arrayA.array(i).value(BF16.float2bfloat16(0.0f).value());
480         }
481 
482         accelerator.compute(computeContext -> {
483             TestBFloat16Type.compute06(computeContext, arrayA);
484         });
485 
486         for (int i = 0; i < arrayA.length(); i++) {
487             BF16 val = arrayA.array(i);
488             try {
489                 HATAsserts.assertEquals(i, BF16.bfloat162float(val), 0.01f);
490             } catch (HATAssertionError hatAssertionError) {
491                 throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
492             }
493 
494         }
495     }
496 
497     @HatTest
498     @Reflect
499     public void test_bfloat16_07() {
500         // Test CPU Implementation of BF16
501         BF16 a = BF16.of(2.5f);
502         BF16 b = BF16.of(3.5f);
503         BF16 c = BF16.add(a, b);
504         HATAsserts.assertEquals((2.5f + 3.5f), BF16.bfloat162float(c), 0.01f);
505 
506         BF16 d = BF16.sub(a, b);
507         HATAsserts.assertEquals((2.5f - 3.5f), BF16.bfloat162float(d), 0.01f);
508 
509         BF16 e = BF16.mul(a, b);
510         HATAsserts.assertEquals((2.5f * 3.5f), BF16.bfloat162float(e), 0.01f);
511 
512         BF16 f = BF16.div(a, b);
513         HATAsserts.assertEquals((2.5f / 3.5f), BF16.bfloat162float(f), 0.01f);
514     }
515 
516     @HatTest
517     @Reflect
518     public void test_bfloat16_08() {
519         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
520 
521         final int size = 256;
522         BF16Array arrayA = BF16Array.create(accelerator, size);
523         for (int i = 0; i < arrayA.length(); i++) {
524             arrayA.array(i).value(BF16.float2bfloat16(0.0f).value());
525         }
526 
527         accelerator.compute(computeContext -> {
528             TestBFloat16Type.compute08(computeContext, arrayA);
529         });
530 
531         for (int i = 0; i < arrayA.length(); i++) {
532             BF16 val = arrayA.array(i);
533             HATAsserts.assertEquals(i, BF16.bfloat162float(val), 0.01f);
534         }
535     }
536 
537     @HatTest
538     @Reflect
539     public void test_bfloat16_09() {
540         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
541 
542         final int size = 16;
543         BF16Array arrayA = BF16Array.create(accelerator, size);
544         BF16Array arrayB = BF16Array.create(accelerator, size);
545 
546         Random r = new Random(73);
547         for (int i = 0; i < arrayA.length(); i++) {
548             arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
549         }
550 
551         accelerator.compute(computeContext -> TestBFloat16Type.compute09(computeContext, arrayA, arrayB));
552 
553         for (int i = 0; i < arrayB.length(); i++) {
554             BF16 val = arrayB.array(i);
555             HATAsserts.assertEquals(BF16.bfloat162float(arrayA.array(i)), BF16.bfloat162float(val), 0.01f);
556         }
557     }
558 
559     @HatTest
560     @Reflect
561     public void test_bfloat16_10() {
562         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
563         final int size = 256;
564         BF16Array arrayA = BF16Array.create(accelerator, size);
565 
566         accelerator.compute(computeContext -> TestBFloat16Type.compute10(computeContext, arrayA));
567 
568         for (int i = 0; i < arrayA.length(); i++) {
569             BF16 val = arrayA.array(i);
570             HATAsserts.assertEquals(1.1f, BF16.bfloat162float(val), 0.01f);
571         }
572     }
573 
574     @HatTest
575     @Reflect
576     public void test_bfloat16_11() {
577         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
578         final int size = 256;
579         BF16Array arrayA = BF16Array.create(accelerator, size);
580         BF16Array arrayB = BF16Array.create(accelerator, size);
581 
582         Random r = new Random(73);
583         for (int i = 0; i < arrayA.length(); i++) {
584             arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
585         }
586 
587         accelerator.compute(computeContext -> TestBFloat16Type.compute11(computeContext, arrayA, arrayB));
588 
589         for (int i = 0; i < arrayB.length(); i++) {
590             BF16 val = arrayB.array(i);
591             HATAsserts.assertEquals(arrayA.array(i).value(), val.value());
592         }
593     }
594 
595     @HatTest
596     @Reflect
597     public void test_bfloat16_12() {
598         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
599         final int size = 1024;
600         BF16Array arrayA = BF16Array.create(accelerator, size);
601         BF16Array arrayB = BF16Array.create(accelerator, size);
602         BF16Array arrayC = BF16Array.create(accelerator, size);
603 
604         Random r = new Random(73);
605         for (int i = 0; i < arrayA.length(); i++) {
606             arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
607             arrayB.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
608         }
609 
610         accelerator.compute(computeContext -> TestBFloat16Type.compute12(computeContext, arrayA, arrayB, arrayC));
611 
612         for (int i = 0; i < arrayB.length(); i++) {
613             BF16 result = arrayC.array(i);
614             HATAsserts.assertEquals(BF16.bfloat162float(BF16.add(arrayA.array(i), arrayB.array(i))), BF16.bfloat162float(result), 0.01f);
615         }
616     }
617 
618     @HatTest
619     @Reflect
620     public void test_bfloat16_13() {
621         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
622         final int size = 1024;
623         BF16Array arrayA = BF16Array.create(accelerator, size);
624         BF16Array arrayB = BF16Array.create(accelerator, size);
625         BF16Array arrayC = BF16Array.create(accelerator, size);
626 
627         Random r = new Random(73);
628         for (int i = 0; i < arrayA.length(); i++) {
629             arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
630             arrayB.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
631         }
632 
633         accelerator.compute(computeContext -> TestBFloat16Type.compute13(computeContext, arrayA, arrayB, arrayC));
634 
635         for (int i = 0; i < arrayB.length(); i++) {
636             BF16 result = arrayC.array(i);
637             HATAsserts.assertEquals(BF16.bfloat162float(arrayA.array(i)), BF16.bfloat162float(result), 0.01f);
638         }
639     }
640 
641     @HatTest
642     @Reflect
643     public void test_bfloat16_14() {
644         // Testing mixed types
645         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
646         final int size = 1024;
647         BF16Array arrayA = BF16Array.create(accelerator, size);
648         BF16Array arrayB = BF16Array.create(accelerator, size);
649 
650         Random r = new Random(73);
651         for (int i = 0; i < arrayA.length(); i++) {
652             arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
653         }
654 
655         accelerator.compute(computeContext -> TestBFloat16Type.compute14(computeContext, arrayA, arrayB));
656 
657         for (int i = 0; i < arrayB.length(); i++) {
658             BF16 result = arrayB.array(i);
659             try {
660                 HATAsserts.assertEquals(BF16.bfloat162float(arrayA.array(i)) + 32.1f, BF16.bfloat162float(result), 0.1f);
661             } catch (HATAssertionError hatAssertionError) {
662                 throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
663             }
664         }
665     }
666 
667     @HatTest
668     @Reflect
669     public void test_bfloat16_15() {
670         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
671         final int size = 256;
672         BF16Array arrayA = BF16Array.create(accelerator, size);
673         BF16Array arrayB = BF16Array.create(accelerator, size);
674 
675         Random r = new Random(73);
676         for (int i = 0; i < arrayA.length(); i++) {
677             arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
678         }
679 
680         accelerator.compute(computeContext -> TestBFloat16Type.compute15(computeContext, arrayA, arrayB));
681 
682         for (int i = 0; i < arrayB.length(); i++) {
683             BF16 val = arrayB.array(i);
684             HATAsserts.assertEquals(arrayA.array(i).value(), val.value());
685         }
686     }
687 
688     // Check accumulators
689     @HatTest
690     @Reflect
691     public void test_bfloat16_16() {
692         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
693         final int size = 1;
694         BF16Array arrayA = BF16Array.create(accelerator, size);
695 
696         Random r = new Random(73);
697         arrayA.array(0).value(BF16.float2bfloat16(10).value());
698 
699         accelerator.compute(computeContext -> TestBFloat16Type.compute16(computeContext, arrayA));
700 
701         BF16 val = arrayA.array(0);
702         HATAsserts.assertEquals(40.0f, BF16.bfloat162float(val), 0.01f);
703     }
704 
705     // Check accumulators in private memory
706     @HatTest
707     @Reflect
708     public void test_bfloat16_17() {
709         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
710         final int size = 1;
711         BF16Array arrayA = BF16Array.create(accelerator, size);
712 
713         Random r = new Random(73);
714         arrayA.array(0).value(BF16.float2bfloat16(10).value());
715 
716         accelerator.compute(computeContext -> TestBFloat16Type.compute17(computeContext, arrayA));
717 
718         BF16 val = arrayA.array(0);
719         HATAsserts.assertEquals(20.0f, BF16.bfloat162float(val), 0.01f);
720     }
721 
722 }