1 /*
2 * Copyright (c) 2025-2026, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.test;
26
27 import hat.Accelerator;
28 import hat.ComputeContext;
29 import hat.KernelContext;
30 import hat.NDRange;
31 import hat.backend.Backend;
32 import hat.types.BF16;
33 import hat.buffer.BF16Array;
34 import hat.device.DeviceSchema;
35 import hat.device.NonMappableIface;
36 import hat.test.annotation.HatTest;
37 import hat.test.exceptions.HATAssertionError;
38 import hat.test.exceptions.HATAsserts;
39 import hat.test.exceptions.HATExpectedPrecisionError;
40 import jdk.incubator.code.Reflect;
41 import optkl.ifacemapper.MappableIface.*;
42
43 import java.lang.invoke.MethodHandles;
44 import java.util.Random;
45
46 public class TestBFloat16Type {
47
48 @Reflect
49 public static void kernel_copy(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
50 if (kernelContext.gix < kernelContext.gsx) {
51 BF16 ha = a.array(kernelContext.gix);
52 b.array(kernelContext.gix).value(ha.value());
53 }
54 }
55
56 @Reflect
57 public static void bf16_02(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
58 if (kernelContext.gix < kernelContext.gsx) {
59 BF16 ha = a.array(kernelContext.gix);
60 BF16 hb = b.array(kernelContext.gix);
61 BF16 result = BF16.add(ha, hb);
62 BF16 hc = c.array(kernelContext.gix);
63 hc.value(result.value());
64 }
65 }
66
67 @Reflect
68 public static void bf16_03(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
69 if (kernelContext.gix < kernelContext.gsx) {
70 BF16 ha = a.array(kernelContext.gix);
71 BF16 hb = b.array(kernelContext.gix);
72
73 BF16 result = BF16.add(ha, BF16.add(hb, hb));
74 BF16 hC = c.array(kernelContext.gix);
75 hC.value(result.value());
76 }
77 }
78
79 @Reflect
80 public static void bf16_04(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
81 if (kernelContext.gix < kernelContext.gsx) {
82 BF16 ha = a.array(kernelContext.gix);
83 BF16 hb = b.array(kernelContext.gix);
84
85 BF16 r1 = BF16.mul(ha, hb);
86 BF16 r2 = BF16.div(ha, hb);
87 BF16 r3 = BF16.sub(ha, hb);
88 BF16 r4 = BF16.add(r1, r2);
89 BF16 r5 = BF16.add(r4, r3);
90 BF16 hC = c.array(kernelContext.gix);
91 hC.value(r5.value());
92 }
93 }
94
95 @Reflect
96 public static void bf16_05(@RO KernelContext kernelContext, @WO BF16Array a) {
97 if (kernelContext.gix < kernelContext.gsx) {
98 BF16 ha = a.array(kernelContext.gix);
99 BF16 initVal = BF16.of( 2.1f);
100 ha.value(initVal.value());
101 }
102 }
103
104 @Reflect
105 public static void bf16_06(@RO KernelContext kernelContext, @WO BF16Array a) {
106 if (kernelContext.gix < kernelContext.gsx) {
107 BF16 initVal = BF16.of(kernelContext.gix);
108 BF16 ha = a.array(kernelContext.gix);
109 ha.value(initVal.value());
110 }
111 }
112
113 @Reflect
114 public static void bf16_08(@RO KernelContext kernelContext, @WO BF16Array a) {
115 if (kernelContext.gix < kernelContext.gsx) {
116 BF16 initVal = BF16.float2bfloat16(kernelContext.gix);
117 BF16 ha = a.array(kernelContext.gix);
118 ha.value(initVal.value());
119 }
120 }
121
122 @Reflect
123 public static void bf16_09(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
124 if (kernelContext.gix < kernelContext.gsx) {
125 BF16 ha = a.array(kernelContext.gix);
126 float f = BF16.bfloat162float(ha);
127 BF16 result = BF16.float2bfloat16(f);
128 BF16 hb = b.array(kernelContext.gix);
129 hb.value(result.value());
130 }
131 }
132
133 @Reflect
134 public static void bf16_10(@RO KernelContext kernelContext, @WO BF16Array a) {
135 if (kernelContext.gix < kernelContext.gsx) {
136 BF16 ha = a.array(kernelContext.gix);
137 BF16 f16 = BF16.of(1.1f);
138 float f = BF16.bfloat162float(f16);
139 BF16 result = BF16.float2bfloat16(f);
140 ha.value(result.value());
141 }
142 }
143
/**
 * Kernel-side scratch buffer of BF16 elements, obtained in TestBFloat16Type.bf16_11 through
 * {@code createLocal()}.
 * <p>
 * NOTE(review): both factory methods return {@code null} when executed as plain Java, and
 * bf16_11 dereferences the result without a null check. Presumably the HAT code generator
 * recognises {@code createLocal()} inside {@code @Reflect} kernels and allocates work-group
 * local storage in its place. TODO confirm against the HAT backend.
 */
public interface LocalArray extends NonMappableIface {
    /** Accessor for element {@code index} of the buffer. */
    BF16 array(int index);

    // Layout description for device code: a fixed-size field "array" of 1024 elements, whose
    // element type BF16 is described as having a single field named "value".
    DeviceSchema<LocalArray> schema = DeviceSchema.of(LocalArray.class,
            builder -> builder.withArray("array", 1024)
                    .withDeps(BF16.class, bfloat16 -> bfloat16.withField("value")));

    /** Placeholder factory; always returns {@code null} when run as Java. */
    static LocalArray create(Accelerator accelerator) {
        return null;
    }

    /** Placeholder factory called from kernels; always returns {@code null} when run as Java. */
    static LocalArray createLocal() {
        return null;
    }
}
158
159 @Reflect
160 public static void bf16_11(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
161 LocalArray sm = LocalArray.createLocal();
162 if (kernelContext.gix < kernelContext.gsx) {
163 int lix = kernelContext.lix;
164 BF16 ha = a.array(kernelContext.gix);
165
166 sm.array(lix).value(ha.value());
167 kernelContext.barrier();
168
169 BF16 hb = sm.array(lix);
170 b.array(kernelContext.gix).value(hb.value());
171 }
172 }
173
174 @Reflect
175 public static void bf16_12(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
176 // Test the fluent API style
177 if (kernelContext.gix < kernelContext.gsx) {
178 BF16 ha = a.array(kernelContext.gix);
179 BF16 hb = b.array(kernelContext.gix);
180 BF16 result = ha.add(hb);
181 c.array(kernelContext.gix).value(result.value());
182 }
183 }
184
185 @Reflect
186 public static void bf16_13(@RO KernelContext kernelContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
187 // Test the fluent API style
188 if (kernelContext.gix < kernelContext.gsx) {
189 BF16 ha = a.array(kernelContext.gix);
190 BF16 hb = b.array(kernelContext.gix);
191 BF16 result = ha.add(hb).sub(hb).mul(ha).div(ha);
192 c.array(kernelContext.gix).value(result.value());
193 }
194 }
195
196 @Reflect
197 public static void bf16_14(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
198 // Testing mixed float types
199 if (kernelContext.gix < kernelContext.gsx) {
200 BF16 ha = a.array(kernelContext.gix);
201 float myFloat = 32.1f;
202 BF16 result = BF16.add(myFloat, ha);
203 b.array(kernelContext.gix).value(result.value());
204 }
205 }
206
/**
 * Kernel-side scratch buffer of BF16 elements, obtained in TestBFloat16Type.bf16_15 and
 * TestBFloat16Type.bf16_17 through {@code createPrivate()}.
 * <p>
 * NOTE(review): both factory methods return {@code null} when executed as plain Java, and the
 * kernels dereference the result without a null check. Presumably the HAT code generator
 * recognises {@code createPrivate()} inside {@code @Reflect} kernels and allocates per-work-item
 * (private) storage in its place. TODO confirm against the HAT backend.
 */
public interface PrivateArray extends NonMappableIface {
    /** Accessor for element {@code index} of the buffer. */
    BF16 array(int index);

    // Layout description for device code: a fixed-size field "array" of 256 elements, whose
    // element type BF16 is described as having a single field named "value".
    DeviceSchema<PrivateArray> schema = DeviceSchema.of(PrivateArray.class,
            builder -> builder.withArray("array", 256)
                    .withDeps(BF16.class, bfloat16 -> bfloat16.withField("value")));

    /** Placeholder factory; always returns {@code null} when run as Java. */
    static PrivateArray create(Accelerator accelerator) {
        return null;
    }

    /** Placeholder factory called from kernels; always returns {@code null} when run as Java. */
    static PrivateArray createPrivate() {
        return null;
    }
}
221
222 @Reflect
223 public static void bf16_15(@RO KernelContext kernelContext, @RO BF16Array a, @WO BF16Array b) {
224 PrivateArray privateArray = PrivateArray.createPrivate();
225 if (kernelContext.gix < kernelContext.gsx) {
226 int lix = kernelContext.lix;
227 BF16 ha = a.array(kernelContext.gix);
228 privateArray.array(lix).value(ha.value());
229 BF16 hb = privateArray.array(lix);
230 b.array(kernelContext.gix).value(hb.value());
231 }
232 }
233
234 @Reflect
235 public static void bf16_16(@RO KernelContext kernelContext, @RW BF16Array a) {
236 BF16 ha = a.array(0);
237 BF16 hre = BF16.add(ha, ha);
238 hre = BF16.add(hre, hre);
239 a.array(0).value(hre.value());
240 }
241
242 @Reflect
243 public static void bf16_17(@RO KernelContext kernelContext, @RW BF16Array a) {
244
245 BF16 ha = a.array(0);
246 PrivateArray privateArray = PrivateArray.createPrivate();
247 privateArray.array(0).value(ha.value());
248
249 // Obtain the value from private memory
250 BF16 acc = privateArray.array(0);
251
252 // compute
253 acc = BF16.add(acc, acc);
254
255 // store the result
256 a.array(0).value(acc.value());
257 }
258
259 @Reflect
260 public static void compute01(@RO ComputeContext computeContext, @RO BF16Array a, @WO BF16Array b) {
261 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.kernel_copy(kernelContext, a, b));
262 }
263
264 @Reflect
265 public static void compute02(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
266 computeContext.dispatchKernel(NDRange.of1D(a.length()),
267 kernelContext -> TestBFloat16Type.bf16_02(kernelContext, a, b, c));
268 }
269
270 @Reflect
271 public static void compute03(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
272 computeContext.dispatchKernel(NDRange.of1D(a.length()),
273 kernelContext -> TestBFloat16Type.bf16_03(kernelContext, a, b, c));
274 }
275
276 @Reflect
277 public static void compute04(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
278 computeContext.dispatchKernel(NDRange.of1D(a.length()),
279 kernelContext -> TestBFloat16Type.bf16_04(kernelContext, a, b, c));
280 }
281
282 @Reflect
283 public static void compute05(@RO ComputeContext computeContext, @WO BF16Array a) {
284 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_05(kernelContext, a));
285 }
286
287 @Reflect
288 public static void compute06(@RO ComputeContext computeContext, @WO BF16Array a) {
289 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_06(kernelContext, a));
290 }
291
292 @Reflect
293 public static void compute08(@RO ComputeContext computeContext, @WO BF16Array a) {
294 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_08(kernelContext, a));
295 }
296
297 @Reflect
298 public static void compute09(@RO ComputeContext computeContext, @RW BF16Array a, @WO BF16Array b) {
299 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_09(kernelContext, a, b));
300 }
301
302 @Reflect
303 public static void compute10(@RO ComputeContext computeContext, @WO BF16Array a) {
304 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_10(kernelContext, a));
305 }
306
307 @Reflect
308 public static void compute11(@RO ComputeContext computeContext, @RO BF16Array a, @WO BF16Array b) {
309 computeContext.dispatchKernel(NDRange.of1D(a.length(),16), kernelContext -> TestBFloat16Type.bf16_11(kernelContext, a, b));
310 }
311
312 @Reflect
313 public static void compute12(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
314 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_12(kernelContext, a, b, c));
315 }
316
317 @Reflect
318 public static void compute13(@RO ComputeContext computeContext, @RO BF16Array a, @RO BF16Array b, @WO BF16Array c) {
319 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_13(kernelContext, a, b, c));
320 }
321
322 @Reflect
323 public static void compute14(@RO ComputeContext computeContext, @RO BF16Array a, @WO BF16Array b) {
324 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_14(kernelContext, a, b));
325 }
326
327 @Reflect
328 public static void compute15(@RO ComputeContext computeContext, @RO BF16Array a, @WO BF16Array b) {
329 computeContext.dispatchKernel(NDRange.of1D(a.length()), kernelContext -> TestBFloat16Type.bf16_15(kernelContext, a, b));
330 }
331
332 @Reflect
333 public static void compute16(@RO ComputeContext computeContext, @RW BF16Array a) {
334 computeContext.dispatchKernel(NDRange.of1D(1), kernelContext -> TestBFloat16Type.bf16_16(kernelContext, a));
335 }
336
337 @Reflect
338 public static void compute17(@RO ComputeContext computeContext, @RW BF16Array a) {
339 computeContext.dispatchKernel(NDRange.of1D(1), kernelContext -> TestBFloat16Type.bf16_17(kernelContext, a));
340 }
341
342 @HatTest
343 @Reflect
344 public void test_bfloat16_01() {
345 final int size = 256;
346 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
347
348 BF16Array arrayA = BF16Array.create(accelerator, size);
349 BF16Array arrayB = BF16Array.create(accelerator, size);
350 for (int i = 0; i < size; i++) {
351 arrayA.array(i).value(BF16.float2bfloat16(i).value());
352 }
353
354 accelerator.compute(computeContext -> TestBFloat16Type.compute01(computeContext, arrayA, arrayB));
355
356 for (int i = 0; i < size; i++) {
357 BF16 result = arrayB.array(i);
358 HATAsserts.assertEquals((float)i, BF16.bfloat162float(result), 0.001f);
359 }
360 }
361
362 @HatTest
363 @Reflect
364 public void test_bfloat16_02() {
365 final int size = 256;
366 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
367
368 BF16Array arrayA = BF16Array.create(accelerator, size);
369 BF16Array arrayB = BF16Array.create(accelerator, size);
370 BF16Array arrayC = BF16Array.create(accelerator, size);
371
372 Random r = new Random(19);
373 for (int i = 0; i < size; i++) {
374 arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
375 arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
376 }
377
378 accelerator.compute(computeContext -> TestBFloat16Type.compute02(computeContext, arrayA, arrayB, arrayC));
379
380 for (int i = 0; i < size; i++) {
381 BF16 result = arrayC.array(i);
382 BF16 a = arrayA.array(i);
383 BF16 b = arrayB.array(i);
384 float res = BF16.bfloat162float(a) + BF16.bfloat162float(b);
385 HATAsserts.assertEquals(res, BF16.bfloat162float(result), 0.001f);
386 }
387 }
388 @HatTest
389 @Reflect
390 public void test_bfloat16_03() {
391 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
392
393 final int size = 256;
394 BF16Array arrayA = BF16Array.create(accelerator, size);
395 BF16Array arrayB = BF16Array.create(accelerator, size);
396 BF16Array arrayC = BF16Array.create(accelerator, size);
397
398 Random random = new Random();
399 for (int i = 0; i < arrayA.length(); i++) {
400 arrayA.array(i).value(BF16.float2bfloat16(random.nextFloat()).value());
401 arrayB.array(i).value(BF16.float2bfloat16(random.nextFloat()).value());
402 }
403
404 accelerator.compute(computeContext -> TestBFloat16Type.compute03(computeContext, arrayA, arrayB, arrayC));
405
406 for (int i = 0; i < arrayC.length(); i++) {
407 BF16 val = arrayC.array(i);
408 float fa = BF16.bfloat162float(arrayA.array(i));
409 float fb = BF16.bfloat162float(arrayB.array(i));
410 HATAsserts.assertEquals((fa + fb + fb), BF16.bfloat162float(val), 0.01f);
411 }
412 }
413
414 @HatTest
415 @Reflect
416 public void test_bfloat16_04() {
417 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
418
419 final int size = 256;
420 BF16Array arrayA = BF16Array.create(accelerator, size);
421 BF16Array arrayB = BF16Array.create(accelerator, size);
422 BF16Array arrayC = BF16Array.create(accelerator, size);
423
424 Random random = new Random();
425 for (int i = 0; i < arrayA.length(); i++) {
426 arrayA.array(i).value(BF16.float2bfloat16(random.nextFloat()).value());
427 arrayB.array(i).value(BF16.float2bfloat16(random.nextFloat()).value());
428 }
429
430 accelerator.compute(computeContext -> {
431 TestBFloat16Type.compute04(computeContext, arrayA, arrayB, arrayC);
432 });
433
434 for (int i = 0; i < arrayC.length(); i++) {
435 BF16 gotResult = arrayC.array(i);
436
437 // CPU Computation
438 BF16 ha = arrayA.array(i);
439 BF16 hb = arrayB.array(i);
440 BF16 r1 = BF16.mul(ha, hb);
441 BF16 r2 = BF16.div(ha, hb);
442 BF16 r3 = BF16.sub(ha, hb);
443 BF16 r4 = BF16.add(r1, r2);
444 BF16 r5 = BF16.add(r4, r3);
445
446 HATAsserts.assertEquals(BF16.bfloat162float(r5), BF16.bfloat162float(gotResult), 0.01f);
447 }
448 }
449
450 @HatTest
451 @Reflect
452 public void test_bfloat16_05() {
453 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
454
455 final int size = 16;
456 BF16Array arrayA = BF16Array.create(accelerator, size);
457 for (int i = 0; i < arrayA.length(); i++) {
458 arrayA.array(i).value(BF16.float2bfloat16(0.0f).value());
459 }
460
461 accelerator.compute(computeContext -> {
462 TestBFloat16Type.compute05(computeContext, arrayA);
463 });
464
465 for (int i = 0; i < arrayA.length(); i++) {
466 BF16 val = arrayA.array(i);
467 HATAsserts.assertEquals(2.1f, BF16.bfloat162float(val), 0.01f);
468 }
469 }
470
471 @HatTest
472 @Reflect
473 public void test_bfloat16_06() {
474 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
475
476 final int size = 512;
477 BF16Array arrayA = BF16Array.create(accelerator, size);
478 for (int i = 0; i < arrayA.length(); i++) {
479 arrayA.array(i).value(BF16.float2bfloat16(0.0f).value());
480 }
481
482 accelerator.compute(computeContext -> {
483 TestBFloat16Type.compute06(computeContext, arrayA);
484 });
485
486 for (int i = 0; i < arrayA.length(); i++) {
487 BF16 val = arrayA.array(i);
488 try {
489 HATAsserts.assertEquals(i, BF16.bfloat162float(val), 0.01f);
490 } catch (HATAssertionError hatAssertionError) {
491 throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
492 }
493
494 }
495 }
496
497 @HatTest
498 @Reflect
499 public void test_bfloat16_07() {
500 // Test CPU Implementation of BF16
501 BF16 a = BF16.of(2.5f);
502 BF16 b = BF16.of(3.5f);
503 BF16 c = BF16.add(a, b);
504 HATAsserts.assertEquals((2.5f + 3.5f), BF16.bfloat162float(c), 0.01f);
505
506 BF16 d = BF16.sub(a, b);
507 HATAsserts.assertEquals((2.5f - 3.5f), BF16.bfloat162float(d), 0.01f);
508
509 BF16 e = BF16.mul(a, b);
510 HATAsserts.assertEquals((2.5f * 3.5f), BF16.bfloat162float(e), 0.01f);
511
512 BF16 f = BF16.div(a, b);
513 HATAsserts.assertEquals((2.5f / 3.5f), BF16.bfloat162float(f), 0.01f);
514 }
515
516 @HatTest
517 @Reflect
518 public void test_bfloat16_08() {
519 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
520
521 final int size = 256;
522 BF16Array arrayA = BF16Array.create(accelerator, size);
523 for (int i = 0; i < arrayA.length(); i++) {
524 arrayA.array(i).value(BF16.float2bfloat16(0.0f).value());
525 }
526
527 accelerator.compute(computeContext -> {
528 TestBFloat16Type.compute08(computeContext, arrayA);
529 });
530
531 for (int i = 0; i < arrayA.length(); i++) {
532 BF16 val = arrayA.array(i);
533 HATAsserts.assertEquals(i, BF16.bfloat162float(val), 0.01f);
534 }
535 }
536
537 @HatTest
538 @Reflect
539 public void test_bfloat16_09() {
540 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
541
542 final int size = 16;
543 BF16Array arrayA = BF16Array.create(accelerator, size);
544 BF16Array arrayB = BF16Array.create(accelerator, size);
545
546 Random r = new Random(73);
547 for (int i = 0; i < arrayA.length(); i++) {
548 arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
549 }
550
551 accelerator.compute(computeContext -> TestBFloat16Type.compute09(computeContext, arrayA, arrayB));
552
553 for (int i = 0; i < arrayB.length(); i++) {
554 BF16 val = arrayB.array(i);
555 HATAsserts.assertEquals(BF16.bfloat162float(arrayA.array(i)), BF16.bfloat162float(val), 0.01f);
556 }
557 }
558
559 @HatTest
560 @Reflect
561 public void test_bfloat16_10() {
562 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
563 final int size = 256;
564 BF16Array arrayA = BF16Array.create(accelerator, size);
565
566 accelerator.compute(computeContext -> TestBFloat16Type.compute10(computeContext, arrayA));
567
568 for (int i = 0; i < arrayA.length(); i++) {
569 BF16 val = arrayA.array(i);
570 HATAsserts.assertEquals(1.1f, BF16.bfloat162float(val), 0.01f);
571 }
572 }
573
574 @HatTest
575 @Reflect
576 public void test_bfloat16_11() {
577 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
578 final int size = 256;
579 BF16Array arrayA = BF16Array.create(accelerator, size);
580 BF16Array arrayB = BF16Array.create(accelerator, size);
581
582 Random r = new Random(73);
583 for (int i = 0; i < arrayA.length(); i++) {
584 arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
585 }
586
587 accelerator.compute(computeContext -> TestBFloat16Type.compute11(computeContext, arrayA, arrayB));
588
589 for (int i = 0; i < arrayB.length(); i++) {
590 BF16 val = arrayB.array(i);
591 HATAsserts.assertEquals(arrayA.array(i).value(), val.value());
592 }
593 }
594
595 @HatTest
596 @Reflect
597 public void test_bfloat16_12() {
598 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
599 final int size = 1024;
600 BF16Array arrayA = BF16Array.create(accelerator, size);
601 BF16Array arrayB = BF16Array.create(accelerator, size);
602 BF16Array arrayC = BF16Array.create(accelerator, size);
603
604 Random r = new Random(73);
605 for (int i = 0; i < arrayA.length(); i++) {
606 arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
607 arrayB.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
608 }
609
610 accelerator.compute(computeContext -> TestBFloat16Type.compute12(computeContext, arrayA, arrayB, arrayC));
611
612 for (int i = 0; i < arrayB.length(); i++) {
613 BF16 result = arrayC.array(i);
614 HATAsserts.assertEquals(BF16.bfloat162float(BF16.add(arrayA.array(i), arrayB.array(i))), BF16.bfloat162float(result), 0.01f);
615 }
616 }
617
618 @HatTest
619 @Reflect
620 public void test_bfloat16_13() {
621 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
622 final int size = 1024;
623 BF16Array arrayA = BF16Array.create(accelerator, size);
624 BF16Array arrayB = BF16Array.create(accelerator, size);
625 BF16Array arrayC = BF16Array.create(accelerator, size);
626
627 Random r = new Random(73);
628 for (int i = 0; i < arrayA.length(); i++) {
629 arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
630 arrayB.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
631 }
632
633 accelerator.compute(computeContext -> TestBFloat16Type.compute13(computeContext, arrayA, arrayB, arrayC));
634
635 for (int i = 0; i < arrayB.length(); i++) {
636 BF16 result = arrayC.array(i);
637 HATAsserts.assertEquals(BF16.bfloat162float(arrayA.array(i)), BF16.bfloat162float(result), 0.01f);
638 }
639 }
640
641 @HatTest
642 @Reflect
643 public void test_bfloat16_14() {
644 // Testing mixed types
645 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
646 final int size = 1024;
647 BF16Array arrayA = BF16Array.create(accelerator, size);
648 BF16Array arrayB = BF16Array.create(accelerator, size);
649
650 Random r = new Random(73);
651 for (int i = 0; i < arrayA.length(); i++) {
652 arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
653 }
654
655 accelerator.compute(computeContext -> TestBFloat16Type.compute14(computeContext, arrayA, arrayB));
656
657 for (int i = 0; i < arrayB.length(); i++) {
658 BF16 result = arrayB.array(i);
659 try {
660 HATAsserts.assertEquals(BF16.bfloat162float(arrayA.array(i)) + 32.1f, BF16.bfloat162float(result), 0.1f);
661 } catch (HATAssertionError hatAssertionError) {
662 throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
663 }
664 }
665 }
666
667 @HatTest
668 @Reflect
669 public void test_bfloat16_15() {
670 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
671 final int size = 256;
672 BF16Array arrayA = BF16Array.create(accelerator, size);
673 BF16Array arrayB = BF16Array.create(accelerator, size);
674
675 Random r = new Random(73);
676 for (int i = 0; i < arrayA.length(); i++) {
677 arrayA.array(i).value(BF16.float2bfloat16(r.nextFloat()).value());
678 }
679
680 accelerator.compute(computeContext -> TestBFloat16Type.compute15(computeContext, arrayA, arrayB));
681
682 for (int i = 0; i < arrayB.length(); i++) {
683 BF16 val = arrayB.array(i);
684 HATAsserts.assertEquals(arrayA.array(i).value(), val.value());
685 }
686 }
687
688 // Check accumulators
689 @HatTest
690 @Reflect
691 public void test_bfloat16_16() {
692 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
693 final int size = 1;
694 BF16Array arrayA = BF16Array.create(accelerator, size);
695
696 Random r = new Random(73);
697 arrayA.array(0).value(BF16.float2bfloat16(10).value());
698
699 accelerator.compute(computeContext -> TestBFloat16Type.compute16(computeContext, arrayA));
700
701 BF16 val = arrayA.array(0);
702 HATAsserts.assertEquals(40.0f, BF16.bfloat162float(val), 0.01f);
703 }
704
705 // Check accumulators in private memory
706 @HatTest
707 @Reflect
708 public void test_bfloat16_17() {
709 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
710 final int size = 1;
711 BF16Array arrayA = BF16Array.create(accelerator, size);
712
713 Random r = new Random(73);
714 arrayA.array(0).value(BF16.float2bfloat16(10).value());
715
716 accelerator.compute(computeContext -> TestBFloat16Type.compute17(computeContext, arrayA));
717
718 BF16 val = arrayA.array(0);
719 HATAsserts.assertEquals(20.0f, BF16.bfloat162float(val), 0.01f);
720 }
721
722 }