1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.test;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.NDRange;
 30 import hat.KernelContext;
 31 import hat.backend.Backend;
 32 import hat.buffer.*;
 33 import hat.device.DeviceSchema;
 34 import hat.device.DeviceType;
 35 import optkl.ifacemapper.Buffer;
 36 import optkl.ifacemapper.Schema;
 37 import jdk.incubator.code.Reflect;
 38 import hat.test.annotation.HatTest;
 39 import hat.test.exceptions.HATAsserts;
 40 import optkl.ifacemapper.MappableIface.RO;
 41 import optkl.ifacemapper.MappableIface.RW;
 42 import optkl.ifacemapper.MappableIface.WO;
 43 
 44 import java.lang.foreign.ValueLayout;
 45 import java.lang.invoke.MethodHandles;
 46 import java.util.Random;
 47 
 48 import static java.lang.foreign.ValueLayout.JAVA_BYTE;
 49 
 50 public class TestArrayView {
 51 
    /*
     * Simple square kernel example using S32Array's ArrayView.
     *
     * Each work item whose index kc.gix is below kc.gsx squares the element at that index
     * in place. The element is accessed through the int[] returned by s32Array.arrayView()
     * instead of the buffer's array(idx) accessors.
     *
     * kc       - kernel context; supplies the work-item index (gix) and its bound (gsx)
     * s32Array - buffer whose elements are read and overwritten with their squares
     */
    @Reflect
    public static void squareKernel(@RO  KernelContext kc, @RW S32Array s32Array) {
        // Guard: only indices below kc.gsx do any work.
        if (kc.gix < kc.gsx){
            // The view is bound to a local here; squareKernelNoVarOp covers the variant
            // that indexes the arrayView() result directly.
            int[] arr = s32Array.arrayView();
            arr[kc.gix] *= arr[kc.gix];
        }
    }
 62 
 63     @Reflect
 64     public static void square(@RO ComputeContext cc, @RW S32Array s32Array) {
 65         cc.dispatchKernel(NDRange.of1D(s32Array.length()),
 66                 kc -> squareKernel(kc, s32Array)
 67         );
 68     }
 69 
 70     @HatTest
 71     @Reflect
 72     public static void testSquare() {
 73 
 74         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
 75         var arr = S32Array.create(accelerator, 32);
 76         for (int i = 0; i < arr.length(); i++) {
 77             arr.array(i, i);
 78         }
 79         accelerator.compute(
 80                 cc -> square(cc, arr)
 81         );
 82         for (int i = 0; i < arr.length(); i++) {
 83             HATAsserts.assertEquals(i * i, arr.array(i));
 84         }
 85     }
 86 
    /*
     * Making sure ArrayViews aren't reliant on varOps.
     *
     * Same computation as squareKernel (element at kc.gix is squared in place), but the
     * result of s32Array.arrayView() is never stored in a local variable: it is indexed
     * directly on both sides of the compound assignment. The shape of the statement is
     * the point of the test, so it must not be "simplified" into a local.
     *
     * kc       - kernel context; supplies the work-item index (gix) and its bound (gsx)
     * s32Array - buffer whose elements are read and overwritten with their squares
     */
    @Reflect
    public static void squareKernelNoVarOp(@RO  KernelContext kc, @RW S32Array s32Array) {
        if (kc.gix<kc.gsx){
            // Deliberately two separate arrayView() calls and no intermediate variable.
            s32Array.arrayView()[kc.gix] *= s32Array.arrayView()[kc.gix];
        }
    }
 96 
 97     @Reflect
 98     public static void squareNoVarOp(@RO ComputeContext cc, @RW S32Array s32Array) {
 99         cc.dispatchKernel(NDRange.of1D(s32Array.length()),
100                 kc -> squareKernelNoVarOp(kc, s32Array)
101         );
102     }
103 
104     @HatTest
105     @Reflect
106     public static void testSquareNoVarOp() {
107         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
108         var arr = S32Array.create(accelerator, 32);
109         for (int i = 0; i < arr.length(); i++) {
110             arr.array(i, i);
111         }
112         accelerator.compute(
113                 cc -> squareNoVarOp(cc, arr)
114         );
115         for (int i = 0; i < arr.length(); i++) {
116             HATAsserts.assertEquals(i * i, arr.array(i));
117         }
118     }
119 
120     @Reflect
121     public static void square2DKernel(@RO  KernelContext kc, @RW S32Array2D s32Array2D) {
122         if (kc.gix < kc.gsx){
123             int[][] arr = s32Array2D.arrayView();
124             arr[kc.gix][kc.giy] *= arr[kc.gix][kc.giy];
125         }
126     }
127 
128     @Reflect
129     public static void square2D(@RO ComputeContext cc, @RW S32Array2D s32Array2D) {
130         cc.dispatchKernel(NDRange.of1D(s32Array2D.width() * s32Array2D.height()),
131                 kc -> square2DKernel(kc, s32Array2D)
132         );
133     }
134 
135     @HatTest
136     @Reflect
137     public static void testSquare2D() {
138 
139         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);//new JavaMultiThreadedBackend());
140         var arr = S32Array2D.create(accelerator, 5, 5);
141         for (int i = 0; i < arr.height(); i++) {
142             for (int j = 0; j < arr.width(); j++) {
143                 arr.set(i, j, i * 5 + j);
144             }
145         }
146         accelerator.compute(
147                 cc -> square2D(cc, arr)
148         );
149         for (int i = 0; i < arr.height(); i++) {
150             for (int j = 0; j < arr.width(); j++) {
151                 HATAsserts.assertEquals((i * 5 + j) * (i * 5 + j), arr.get(i, j));
152             }
153         }
154     }
155 
156     /*
157      * simplified version of Game of Life using ArrayView
158      */
159     public final static byte ALIVE = (byte) 0xff;
160     public final static byte DEAD = 0x00;
161 
    /*
     * Buffer interface describing the Game of Life grid used by testLife.
     *
     * The accessors below are bound by name in `schema` ("width", "height", "array"), so
     * their names and the schema strings must stay in sync.
     */
    public interface CellGrid extends Buffer {
        /*
         * struct CellGrid{
         *     int width;
         *     int height;
         *     byte[width*height*2] array;   // length presumably width*height*stride(2) - TODO confirm
         *  }
         */
        int width();

        int height();

        // Flat accessors over the cell bytes; callers in this file address cells as y * width + x.
        byte array(long idx);

        void array(long idx, byte b);

        // Layout: an array named "array" whose length is derived from width and height, with stride 2.
        Schema<CellGrid> schema = Schema.of(CellGrid.class, lifeData -> lifeData
                .arrayLen("width", "height").stride(2).array("array")
        );

        // Allocates a grid of the given dimensions through the schema.
        static CellGrid create(Accelerator accelerator, int width, int height) {
            return schema.allocate(accelerator, width, height);
        }

        // Element layout of the cells (one byte each).
        // NOTE(review): presumably read by the ArrayView machinery - confirm.
        ValueLayout valueLayout = JAVA_BYTE;

        // Placeholder: returns null when called as plain Java. Kernels index the result as
        // bytes[x][y]; presumably the HAT code transformation replaces this call with direct
        // buffer accesses - TODO confirm.
        default byte[][] arrayView() {
            return null;
        }
    }
192 
193     public interface Control extends Buffer {
194         /*
195          * struct Control{
196          *     int from;
197          *     int to;
198          *  }
199          */
200         int from();
201 
202         void from(int from);
203 
204         int to();
205 
206         void to(int to);
207 
208         Schema<Control> schema = Schema.of(
209                 Control.class, control ->
210                         control.fields("from", "to"));//, "generation", "requiredFrameRate", "maxGenerations"));
211 
212         static Control create(Accelerator accelerator, CellGrid cellGrid) {
213             var instance = schema.allocate(accelerator);
214             instance.from(cellGrid.width() * cellGrid.height());
215             instance.to(0);
216             return instance;
217         }
218     }
219 
220     public static byte[][] lifeCheck(CellGrid cellGrid) {
221         int w = cellGrid.width();
222         int h = cellGrid.height();
223 
224         byte[][] res = new byte[h][w];
225 
226         for (int y = 0; y < h; y++) {
227             for (int x = 0; x < w; x++) {
228                 int idx = y * w + x;
229                 byte cell = cellGrid.array(idx);
230                 if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
231                     int count =
232                             (cellGrid.array((y - 1) * w + (x - 1)) & 1)
233                                     + (cellGrid.array((y + 0) * w + (x - 1)) & 1)
234                                     + (cellGrid.array((y + 1) * w + (x - 1)) & 1)
235                                     + (cellGrid.array((y - 1) * w + (x + 0)) & 1)
236                                     + (cellGrid.array((y + 1) * w + (x + 0)) & 1)
237                                     + (cellGrid.array((y - 1) * w + (x + 1)) & 1)
238                                     + (cellGrid.array((y + 0) * w + (x + 1)) & 1)
239                                     + (cellGrid.array((y + 1) * w + (x + 1))& 1);
240                     cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD;// B3/S23.
241                 }
242                 res[x][y] = cell;
243             }
244         }
245         return res;
246     }
247 
    /*
     * Kernel-side Game of Life: per-cell rule (lifePerIdx), kernel entry (life) and the
     * compute method that dispatches it (compute).
     */
    public static class Compute {
        /*
         * Computes the next state of the cell with linear index idx and stores it in
         * cellGridRes.
         *
         * idx is split into x = idx % width and y = idx / width. Both grids are accessed
         * through their byte[][] arrayView() as [x][y]. Interior cells follow B3/S23;
         * cells on the border are copied unchanged. `control` is currently unused (its
         * reads are commented out below).
         */
        @Reflect
        // TODO: switch cellGridRes to WO
        public static void lifePerIdx(int idx, @RO Control control, @RO CellGrid cellGrid, @RW CellGrid cellGridRes) {
            int w = cellGrid.width();
            int h = cellGrid.height();
           // int from = control.from();
           // int to = control.to();
            int x = idx % w;
            int y = idx / w;

            // Disabled alternative kept for reference: a flat byte[] view with from/to
            // offsets and a lookup table instead of the conditional.
            // byte[] bytes = cellGrid.arrayView();
            // byte cell = bytes[idx + from];
            // byte[] lookup = new byte[]{};
            // if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
            //     int lookupIdx =
            //             (bytes[(y - 1) * w + x - 1 + from]&1 <<0)
            //                     |(bytes[(y + 0) * w + x - 1 + from]&1 <<1)
            //                     |(bytes[(y + 1) * w + x - 1 + from]&1 <<2)
            //                     |(bytes[(y - 1) * w + x + 0 + from]&1 <<3)
            //                     |(bytes[(y - 0) * w + x + 0 + from]&1 <<4) // current cell added
            //                     |(bytes[(y + 1) * w + x + 0 + from]&1 <<5)
            //                     |(bytes[(y + 0) * w + x + 1 + from]&1 <<6)
            //                     |(bytes[(y - 1) * w + x + 1 + from]&1 <<7)
            //                     |(bytes[(y + 1) * w + x + 1 + from]&1 <<8) ;
            //     // conditional removed!
            //     bytes[idx + to] = lookup[lookupIdx];
            // }

            byte[][] bytes = cellGrid.arrayView();
            byte cell = bytes[x][y];
            // Only interior cells are evolved; the one-cell border keeps its value.
            if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
                // (value & 1) is 1 for ALIVE (0xff) and 0 for DEAD: sum of the 8 neighbours.
                int count =
                        (bytes[x - 1][y - 1] & 1)
                                + (bytes[x - 1][y + 0] & 1)
                                + (bytes[x - 1][y + 1] & 1)
                                + (bytes[x + 0][y - 1] & 1)
                                + (bytes[x + 0][y + 1] & 1)
                                + (bytes[x + 1][y - 1] & 1)
                                + (bytes[x + 1][y + 0] & 1)
                                + (bytes[x + 1][y + 1] & 1);
                cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD;// B3/S23.
            }
            // Result goes to the separate output grid, so the input is never modified.
            byte[][] res = cellGridRes.arrayView();
            res[x][y] = cell;
        }

        /*
         * Kernel entry: for work items with kc.gix below kc.gsx, evolves the cell whose
         * linear index is kc.gix.
         */
        @Reflect
        public static void life(@RO KernelContext kc, @RO Control control, @RO CellGrid cellGrid, @RW CellGrid cellGridRes) {
            if (kc.gix < kc.gsx) {
                Compute.lifePerIdx(kc.gix, control, cellGrid, cellGridRes);
            }
        }

        /*
         * Compute entry point: dispatches `life` over a 1D range with one work item per
         * cell (width * height) of `grid`, writing the next generation into gridRes.
         */
        @Reflect
        static public void compute(final @RO ComputeContext cc, @RO Control ctrl, @RO CellGrid grid, @RW CellGrid gridRes) {
            int range = grid.width() * grid.height();
            cc.dispatchKernel(NDRange.of1D(range), kc -> Compute.life(kc, ctrl, grid, gridRes));
        }
    }
308 
309     @HatTest
310     @Reflect
311     public static void testLife() {
312         Accelerator accelerator = new Accelerator(MethodHandles.lookup());
313 
314         // int w = 20;
315         // int h = 20;
316         // // We oversize the grid by adding 1 to n,e,w and s
317         // CellGrid cellGrid = CellGrid.create(accelerator, w, h);
318         // CellGrid cellGridRes = CellGrid.create(accelerator, w, h);
319         //
320         // Random rand = new Random();
321         // byte[][] actualGrid = new byte[w][h];
322         // for (int y = 0; y < h; y++) {
323         //     for (int x = 0; x < w; x++) {
324         //         actualGrid[x][y] = rand.nextBoolean() ? ALIVE : DEAD;
325         //     }
326         // }
327 
328         // We oversize the grid by adding 1 to n,e,w and s
329         CellGrid cellGrid = CellGrid.create(accelerator, 17, 17);
330         CellGrid cellGridRes = CellGrid.create(accelerator, 17, 17);
331 
332         byte[][] actualGrid = new byte[][]{
333                 {DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD},
334                 {DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD},
335                 {DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, ALIVE, ALIVE, DEAD,  DEAD,  DEAD,  ALIVE, ALIVE, ALIVE, DEAD,  DEAD,  DEAD,  DEAD},
336                 {DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD},
337                 {DEAD,  DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  DEAD},
338                 {DEAD,  DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  DEAD},
339                 {DEAD,  DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  DEAD},
340                 {DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, ALIVE, ALIVE, DEAD,  DEAD,  DEAD,  ALIVE, ALIVE, ALIVE, DEAD,  DEAD,  DEAD,  DEAD},
341                 {DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD},
342                 {DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, ALIVE, ALIVE, DEAD,  DEAD,  DEAD,  ALIVE, ALIVE, ALIVE, DEAD,  DEAD,  DEAD,  DEAD},
343                 {DEAD,  DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  DEAD},
344                 {DEAD,  DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  DEAD},
345                 {DEAD,  DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  ALIVE, DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, DEAD,  DEAD},
346                 {DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD},
347                 {DEAD,  DEAD,  DEAD,  DEAD,  ALIVE, ALIVE, ALIVE, DEAD,  DEAD,  DEAD,  ALIVE, ALIVE, ALIVE, DEAD,  DEAD,  DEAD,  DEAD},
348                 {DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD},
349                 {DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD,  DEAD},
350         };
351 
352         // By shifting all cells +1,+1 so we only need to scan 1..width-1, 1..height-1
353         // we don't worry about possibly finding cells in 0,n width,n or n,0 height,n
354         for (int i = 0; i < cellGrid.height(); i++) {
355             for (int j = 0; j < cellGrid.width(); j++) {
356                 cellGrid.array(((long) i * cellGrid.width()) + j, actualGrid[i][j]);
357             }
358         }
359 
360         Control control = Control.create(accelerator, cellGrid);
361 
362         accelerator.compute(cc -> Compute.compute(cc, control, cellGrid, cellGridRes));
363 
364         byte[][] resultGrid = lifeCheck(cellGrid);
365 
366         for (int i = 0; i < cellGrid.height(); i++) {
367             for (int j = 0; j < cellGrid.width(); j++) {
368                 HATAsserts.assertEquals(resultGrid[i][j], cellGridRes.array(((long) i * cellGrid.width()) + j));
369             }
370         }
371     }
372 
373     /*
374      * simplified version of mandel using ArrayView
375      */
376     @Reflect
377     public static int mandelCheck(int i, int j, float width, float height, int[] pallette, float offsetx, float offsety, float scale) {
378         float x = (i * scale - (scale / 2f * width)) / width + offsetx;
379         float y = (j * scale - (scale / 2f * height)) / height + offsety;
380         float zx = x;
381         float zy = y;
382         float new_zx;
383         int colorIdx = 0;
384         while ((colorIdx < pallette.length) && (((zx * zx) + (zy * zy)) < 4f)) {
385             new_zx = ((zx * zx) - (zy * zy)) + x;
386             zy = (2f * zx * zy) + y;
387             zx = new_zx;
388             colorIdx++;
389         }
390         return colorIdx < pallette.length ? pallette[colorIdx] : 0;
391     }
392 
    /*
     * Mandelbrot kernel using ArrayViews for both the palette (int[]) and the image (int[][]).
     *
     * The kernel is dispatched over a 1D range, so the pixel coordinate is derived from
     * the linear index: column = gix % width, row = gix / width. The iteration mirrors
     * mandelCheck: z = z^2 + c from z = c, until |z|^2 reaches 4 or the count reaches
     * pal.length. The pixel is set to pal[count], or 0 if the palette was exhausted.
     *
     * kc               - kernel context; supplies the linear work-item index (gix) and bound (gsx)
     * s32Array2D       - output image, written through its arrayView()
     * pallette         - colour table, read through its arrayView()
     * offsetx, offsety - origin added to the scaled coordinates
     * scale            - extent of the sampled region
     */
    @Reflect
    public static void mandel(@RO KernelContext kc, @RW S32Array2D s32Array2D, @RO S32Array pallette, float offsetx, float offsety, float scale) {
        if (kc.gix < kc.gsx) {
            int[] pal = pallette.arrayView();
            int[][] s32 = s32Array2D.arrayView();
            // Dimensions as float so the coordinate maths below is done in floating point.
            float width = s32Array2D.width();
            float height = s32Array2D.height();
            float x = ((kc.gix % s32Array2D.width()) * scale - (scale / 2f * width)) / width + offsetx;
            float y = ((kc.gix / s32Array2D.width()) * scale - (scale / 2f * height)) / height + offsety;
            float zx = x;
            float zy = y;
            float new_zx;
            int colorIdx = 0;
            // pal.length on the array view doubles as the iteration limit.
            while ((colorIdx < pal.length) && (((zx * zx) + (zy * zy)) < 4f)) {
                new_zx = ((zx * zx) - (zy * zy)) + x;
                zy = (2f * zx * zy) + y;
                zx = new_zx;
                colorIdx++;
            }
            int color = colorIdx < pal.length ? pal[colorIdx] : 0;
            // View is indexed [column][row].
            s32[kc.gix % s32Array2D.width()][kc.gix / s32Array2D.width()] = color;
        }
    }
416 
417 
418     @Reflect
419     static public void compute(final ComputeContext computeContext, S32Array pallete, S32Array2D s32Array2D, float x, float y, float scale) {
420 
421         computeContext.dispatchKernel(
422                 NDRange.of1D(s32Array2D.width()*s32Array2D.height()), //0..S32Array2D.size()
423                 kc -> mandel(kc, s32Array2D, pallete, x, y, scale));
424     }
425 
426     @HatTest
427     @Reflect
428     public static void testMandel() {
429         final int width = 1024;
430         final int height = 1024;
431         final float defaultScale = 3f;
432         final float originX = -1f;
433         final float originY = 0;
434         final int maxIterations = 64;
435 
436         Accelerator accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
437 
438         S32Array2D s32Array2D = S32Array2D.create(accelerator, width, height);
439 
440         int[] palletteArray = new int[maxIterations];
441 
442         for (int i = 1; i < maxIterations; i++) {
443             palletteArray[i]=(i/8+1);
444         }
445         palletteArray[0]=0;
446         S32Array pallette = S32Array.createFrom(accelerator, palletteArray);
447 
448         accelerator.compute(cc -> compute(cc, pallette, s32Array2D, originX, originY, defaultScale));
449 
450         int subsample = 16;
451         char[] charPallette9 = new char []{' ', '.', ',',':', '-', '+','*', '#', '@', '%'};
452         for (int y = 0; y<height/subsample; y++) {
453             for (int x = 0; x<width/subsample; x++) {
454                 int palletteValue = s32Array2D.get(x*subsample,y*subsample); // so 0->8
455                 int paletteCheck = mandelCheck(x*subsample, y*subsample, width, height, palletteArray, originX, originY, defaultScale);
456                 HATAsserts.assertEquals(paletteCheck, palletteValue);
457             }
458         }
459     }
460 
    /*
     * Simplified version of BlackScholes using ArrayView.
     *
     * For the work item at index kc.gix, computes d1, d2 and the discount factor
     * exp(-r * t) from the s/x/t inputs at that index, then stores
     *     call = s * CND(d1) - exp(-r t) * x * CND(d2)
     *     put  = exp(-r t) * x * (1 - CND(d2)) - s * (1 - CND(d1))
     * All five buffers are accessed through float[] ArrayViews.
     *
     * kc     - kernel context; supplies the work-item index (gix) and its bound (gsx)
     * call   - output: call values
     * put    - output: put values
     * sArray - input s values (presumably spot prices - confirm)
     * xArray - input x values (presumably strike prices - confirm)
     * tArray - input t values (presumably time to expiry - confirm)
     * r, v   - scalar parameters (presumably interest rate and volatility - confirm)
     */
    @Reflect
    public static void blackScholesKernel(@RO KernelContext kc,
                                          @WO F32Array call,
                                          @WO F32Array put,
                                          @RO F32Array sArray,
                                          @RO F32Array xArray,
                                          @RO F32Array tArray,
                                          float r,
                                          float v) {
        if (kc.gix<kc.gsx){
            float[] callArr = call.arrayView();
            float[] putArr = put.arrayView();
            float[] sArr = sArray.arrayView();
            float[] xArr = xArray.arrayView();
            float[] tArr = tArray.arrayView();

            // Same arithmetic as blackScholesKernelSeq, which the test uses as the reference.
            float expNegRt = (float) Math.exp(-r * tArr[kc.gix]);
            float d1 = (float) ((Math.log(sArr[kc.gix] / xArr[kc.gix]) + (r + v * v * .5f) * tArr[kc.gix]) / (v * Math.sqrt(tArr[kc.gix])));
            float d2 = (float) (d1 - v * Math.sqrt(tArr[kc.gix]));
            float cnd1 = CND(d1);
            float cnd2 = CND(d2);
            float value = sArr[kc.gix] * cnd1 - expNegRt * xArr[kc.gix] * cnd2;
            callArr[kc.gix] = value;
            putArr[kc.gix] = expNegRt * xArr[kc.gix] * (1 - cnd2) - sArr[kc.gix] * (1 - cnd1);
        }
    }
490 
491     @Reflect
492     public static float CND(float input) {
493         float x = input;
494         if (input < 0f) { // input = Math.abs(input)?
495             x = -input;
496         }
497 
498         float term = 1f / (1f + (0.2316419f * x));
499         float term_pow2 = term * term;
500         float term_pow3 = term_pow2 * term;
501         float term_pow4 = term_pow2 * term_pow2;
502         float term_pow5 = term_pow2 * term_pow3;
503 
504         float part1 = (1f / (float)Math.sqrt(2f * 3.1415926535f)) * (float)Math.exp((-x * x) * 0.5f);
505 
506         float part2 = (0.31938153f * term) +
507                 (-0.356563782f * term_pow2) +
508                 (1.781477937f * term_pow3) +
509                 (-1.821255978f * term_pow4) +
510                 (1.330274429f * term_pow5);
511 
512         if (input >= 0f) {
513             return 1f - part1 * part2;
514         }
515         return part1 * part2;
516 
517     }
518 
519     @Reflect
520     public static void blackScholes(@RO ComputeContext cc, @WO F32Array call, @WO F32Array put, @RO F32Array S, @RO F32Array X, @RO F32Array T, float r, float v) {
521         cc.dispatchKernel(NDRange.of1D(call.length()),
522                 kc -> blackScholesKernel(kc, call, put, S, X, T, r, v)
523         );
524     }
525 
526     static F32Array floatArray(Accelerator accelerator, int size, float low, float high, Random rand) {
527         F32Array array = F32Array.create(accelerator, size);
528         for (int i = 0; i <size; i++) {
529             array.array(i, rand.nextFloat() * (high - low) + low);
530         }
531         return array;
532     }
533 
534     public static void blackScholesKernelSeq(F32Array call, F32Array put, F32Array sArray, F32Array xArray, F32Array tArray, float r, float v) {
535         for (int i = 0; i <call.length() ; i++) {
536             float S = sArray.array(i);
537             float X = xArray.array(i);
538             float T = tArray.array(i);
539             float expNegRt = (float) Math.exp(-r * T);
540             float d1 = (float) ((Math.log(S / X) + (r + v * v * .5f) * T) / (v * Math.sqrt(T)));
541             float d2 = (float) (d1 - v * Math.sqrt(T));
542             float cnd1 = CND(d1);
543             float cnd2 = CND(d2);
544             float value = S * cnd1 - expNegRt * X * cnd2;
545             call.array(i, value);
546             put.array(i, expNegRt * X * (1 - cnd2) - S * (1 - cnd1));
547         }
548     }
549 
550     @HatTest
551     @Reflect
552     public static void testBlackScholes() {
553         int size = 1024;
554         Random rand = new Random();
555         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
556         var call = F32Array.create(accelerator, size);
557         var put = F32Array.create(accelerator, size);
558         for (int i = 0; i < size; i++) {
559             call.array(i, i);
560             put.array(i, i);
561         }
562 
563         var S = floatArray(accelerator, size,1f, 100f, rand);
564         var X = floatArray(accelerator, size,1f, 100f, rand);
565         var T = floatArray(accelerator,size, 0.25f, 10f, rand);
566         float r = 0.02f;
567         float v = 0.30f;
568 
569         accelerator.compute(cc -> blackScholes(cc, call, put, S, X, T, r, v));
570 
571         var seqCall = F32Array.create(accelerator, size);
572         var seqPut = F32Array.create(accelerator, size);
573         for (int i = 0; i < seqCall.length(); i++) {
574             seqCall.array(i, i);
575             seqPut.array(i, i);
576         }
577 
578         blackScholesKernelSeq(seqCall, seqPut, S, X, T, r, v);
579 
580         for (int i = 0; i < call.length(); i++) {
581             HATAsserts.assertEquals(seqCall.array(i), call.array(i), 0.01f);
582             HATAsserts.assertEquals(seqPut.array(i), put.array(i), 0.01f);
583         }
584     }
585 
586     /*
587      * basic test of local and private buffer ArrayViews
588      */
589     private interface SharedMemory extends DeviceType {
590         void array(long index, int value);
591         int array(long index);
592         DeviceSchema<SharedMemory> schema = DeviceSchema.of(SharedMemory.class,
593                 arr -> arr.withArray("array", 1024));
594 
595         static SharedMemory createLocal() { return null; }
596 
597         default int[] localArrayView() {
598             int[] view = new int[1024];
599             for (int i = 0; i < 1024; i++) {
600                 view[i] = this.array(i);
601             }
602             return view;
603         }
604     }
605 
606     public interface PrivateArray extends DeviceType {
607         void array(long index, int value);
608         int array(long index);
609         DeviceSchema<PrivateArray> schema = DeviceSchema.of(PrivateArray.class,
610                 arr -> arr.withArray("array", 16));
611 
612         static PrivateArray createPrivate() { return null; }
613 
614         default int[] privateArrayView() {
615             int[] view = new int[16];
616             for (int i = 0; i < 16; i++) {
617                 view[i] = this.array(i);
618             }
619             return view;
620         }
621     }
622 
    /*
     * Kernel mixing three ArrayViews: the global buffer, a PrivateArray and a SharedMemory.
     *
     * Despite the name, nothing is squared. For index gix the element ends up as
     *     2 * original + 1 + 16
     * (doubled, plus privView[0] == 1, plus sharedView[0] == 16), which is what
     * testPrivateAndLocal asserts.
     *
     * kc       - kernel context; supplies the work-item index (gix), bound (gsx) and barrier()
     * s32Array - buffer updated in place
     */
    @Reflect
    public static void squareKernelWithPrivateAndLocal(@RO  KernelContext kc, @RW S32Array s32Array) {
        // Created outside the guard, unlike the private array below.
        SharedMemory shared = SharedMemory.createLocal();
        if (kc.gix < kc.gsx){
            int[] arr = s32Array.arrayView();
            arr[kc.gix] += arr[kc.gix];

            PrivateArray priv = PrivateArray.createPrivate();
            int[] privView = priv.privateArrayView();
            privView[0] = 1;
            arr[kc.gix] += privView[0];

            int[] sharedView = shared.localArrayView();
            // Every work item stores the same value to slot 0, then synchronises before reading it.
            sharedView[0] = 16;
            kc.barrier();
            arr[kc.gix] += sharedView[0];
        }
    }
641 
642     @Reflect
643     public static void privateAndLocal(@RO ComputeContext cc, @RW S32Array s32Array) {
644         cc.dispatchKernel(NDRange.of1D(s32Array.length()),
645                 kc -> squareKernelWithPrivateAndLocal(kc, s32Array)
646         );
647     }
648 
649     @HatTest
650     @Reflect
651     public static void testPrivateAndLocal() {
652 
653         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);//new JavaMultiThreadedBackend());
654         var arr = S32Array.create(accelerator, 32);
655         for (int i = 0; i < arr.length(); i++) {
656             arr.array(i, i);
657         }
658         accelerator.compute(
659                 cc -> privateAndLocal(cc, arr)
660         );
661         for (int i = 0; i < arr.length(); i++) {
662             HATAsserts.assertEquals(2 * i + 17, arr.array(i));
663         }
664     }
665 
    /*
     * Testing basic DeviceTypes.
     */

    /*
     * DeviceType holding a 1024-element int array named "array". Unlike SharedMemory
     * above, every method body here is a bare `return null` placeholder.
     * NOTE(review): presumably the HAT code transformation supplies the real behaviour
     * when these are used inside a kernel - confirm.
     */
    public interface SharedDeviceType extends DeviceType {
        void array(long index, int value);
        int array(long index);
        DeviceSchema<SharedDeviceType> schema = DeviceSchema.of(SharedDeviceType.class,
                arr -> arr.withArray("array", 1024));
        // Placeholder; not called anywhere in this file.
        static SharedDeviceType create(Accelerator accelerator) {
            return null;
        }
        // Placeholder; the kernels in this file call this to obtain the local instance.
        static SharedDeviceType createLocal() {
            return null;
        }

        // Placeholder; kernels index the result as an int[].
        default int[] localArrayView() {
            return null;
        }
    }
686 
    /*
     * DeviceType holding a 32-element int array named "array". Every method body is a
     * bare `return null` placeholder.
     * NOTE(review): presumably the HAT code transformation supplies the real behaviour
     * when these are used inside a kernel - confirm.
     */
    public interface PrivateDeviceType extends DeviceType {
        void array(long index, int value);
        int array(long index);
        // 32 elements: the kernels in this file index the private view with kc.gix over a
        // 32-element dispatch.
        DeviceSchema<PrivateDeviceType> schema = DeviceSchema.of(PrivateDeviceType.class,
                arr -> arr.withArray("array", 32));
        // Placeholder; not called anywhere in this file.
        static PrivateDeviceType create(Accelerator accelerator) {
            return null;
        }
        // Placeholder; the kernels in this file call this to obtain the private instance.
        static PrivateDeviceType createPrivate() {
            return null;
        }

        // Placeholder; kernels index the result as an int[].
        default int[] privateArrayView() {
            return null;
        }
    }
703 
    /*
     * Round-trips each element through a private and a local DeviceType ArrayView.
     *
     * The element at gix is copied into privView[gix] and sharedView[gix]; after the
     * barrier the global element is overwritten with their sum, i.e. twice its original
     * value (testBasicDeviceType asserts 2 * i).
     *
     * kc       - kernel context; supplies the work-item index (gix), bound (gsx) and barrier()
     * s32Array - buffer updated in place
     */
    @Reflect
    public static void kernelBasicDeviceType(@RO  KernelContext kc, @RW S32Array s32Array) {
        SharedDeviceType shared = SharedDeviceType.createLocal();
        if (kc.gix < kc.gsx){
            PrivateDeviceType priv = PrivateDeviceType.createPrivate();

            int[] arr = s32Array.arrayView();
            int[] privView = priv.privateArrayView();
            int[] sharedView = shared.localArrayView();

            // The private view has 32 slots; indexing it with gix assumes the dispatch
            // is no larger than that (the test dispatches 32 work items).
            privView[kc.gix] = arr[kc.gix];
            sharedView[kc.gix] = arr[kc.gix];
            kc.barrier();
            arr[kc.gix] = privView[kc.gix] + sharedView[kc.gix];
        }
    }
720 
721     @Reflect
722     public static void basicDeviceType(@RO ComputeContext cc, @RW S32Array s32Array) {
723         cc.dispatchKernel(NDRange.of1D(s32Array.length()),
724                 kc -> kernelBasicDeviceType(kc, s32Array)
725         );
726     }
727 
728     @HatTest
729     @Reflect
730     public static void testBasicDeviceType() {
731         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);//new JavaMultiThreadedBackend());
732         var arr = S32Array.create(accelerator, 32);
733         for (int i = 0; i < arr.length(); i++) {
734             arr.array(i, i);
735         }
736         accelerator.compute(cc -> basicDeviceType(cc, arr));
737         for (int i = 0; i < arr.length(); i++) {
738             HATAsserts.assertEquals(2 * i, arr.array(i));
739         }
740     }
741 
    /*
     * DeviceType kernel that uses a value read from the private view as an index into
     * the local view.
     *
     * Despite the name, nothing is squared. With the test's input (arr[i] == i):
     * privView[gix] = gix, sharedView[gix] = 16 * gix, and after the barrier the element
     * becomes gix + gix + 16 * gix = 18 * gix, which testDeviceType asserts.
     *
     * kc       - kernel context; supplies the work-item index (gix), bound (gsx) and barrier()
     * s32Array - buffer updated in place
     */
    @Reflect
    public static void squareKernelDeviceType(@RO  KernelContext kc, @RW S32Array s32Array) {
        SharedDeviceType shared = SharedDeviceType.createLocal();
        if (kc.gix < kc.gsx){
            PrivateDeviceType priv = PrivateDeviceType.createPrivate();

            int[] arr = s32Array.arrayView();
            int[] privView = priv.privateArrayView();
            int[] sharedView = shared.localArrayView();

            privView[kc.gix] = arr[kc.gix];
            // Index comes from the data: the slot written is arr's original value at gix.
            // It lines up with the read below only because the test stores i at index i.
            sharedView[privView[kc.gix]] = 16 * privView[kc.gix];
            kc.barrier();
            arr[kc.gix] += privView[kc.gix] + sharedView[kc.gix];
        }
    }
758 
759     @Reflect
760     public static void deviceType(@RO ComputeContext cc, @RW S32Array s32Array) {
761         cc.dispatchKernel(NDRange.of1D(s32Array.length()),
762                 kc -> squareKernelDeviceType(kc, s32Array)
763         );
764     }
765 
766     @HatTest
767     @Reflect
768     public static void testDeviceType() {
769         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);//new JavaMultiThreadedBackend());
770         var arr = S32Array.create(accelerator, 32);
771         for (int i = 0; i < arr.length(); i++) {
772             arr.array(i, i);
773         }
774         accelerator.compute(cc -> deviceType(cc, arr));
775         for (int i = 0; i < arr.length(); i++) {
776             HATAsserts.assertEquals(18 * i, arr.array(i));
777         }
778     }
779 }