## Leverage CodeReflection to expose array view of buffers in kernels

----
* [Contents](hat-00.md)
* Build Babylon and HAT
    * [Quick Install](hat-01-quick-install.md)
    * [Building Babylon with jtreg](hat-01-02-building-babylon.md)
    * [Building HAT with jtreg](hat-01-03-building-hat.md)
        * [Enabling the NVIDIA CUDA Backend](hat-01-05-building-hat-for-cuda.md)
* [Testing Framework](hat-02-testing-framework.md)
* [Running Examples](hat-03-examples.md)
* [HAT Programming Model](hat-03-programming-model.md)
* Interface Mapping
    * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
    * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Development
    * [Project Layout](hat-01-01-project-layout.md)
* Implementation Details
    * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
    * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)
* [Running HAT with Docker on NVIDIA GPUs](hat-07-docker-build-nvidia.md)
---

Here is the canonical HAT example:

```java
import jdk.incubator.code.Reflect;

@Reflect
public class Square {
    @Reflect
    public static void kernel(KernelContext kc, S32Arr s32Arr) {
        s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
    }

    @Reflect
    public static void compute(ComputeContext cc, S32Arr s32Arr) {
        cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
    }
}
```

The code in this kernel has always bothered me. One downside of using MemorySegment-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
}
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that exposes the buffer as a plain Java `int[]` view?

The kernel becomes:
```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
}
```
IMHO this is far more readable.

For the GPU this is fine. Thanks to CodeReflection we can prove that the array is indeed just a view,
so we simply remove all references to `int[] arr` and replace the array accesses with get/set accessors
on the original `S32Arr`.

But what about Java performance? Won't it suck because we are copying the array in each kernel? ;)

Well, we can use the same trick we used for the GPU: take the transformed model (with the array
references removed), generate bytecode from it, and run that.

Historically, we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.

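To make the copying concern concrete, here is a minimal sketch in plain Java. `IntBuf` is a hypothetical stand-in for `S32Arr` (backed by a `ByteBuffer` rather than a MemorySegment so that it runs on any recent JDK), and its `arrayView()` is assumed to be a naive copy. Run as ordinary bytecode, the array-syntax kernel pays for a full copy per invocation and its write never reaches the buffer, whereas the accessor-only form (roughly what the transformed model would correspond to) touches a single element.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical stand-in for S32Arr: ints stored off-heap behind array(i) / array(i, v). */
final class IntBuf {
    private final ByteBuffer bytes;
    private final int length;

    IntBuf(int length) {
        this.length = length;
        this.bytes = ByteBuffer.allocateDirect(length * Integer.BYTES).order(ByteOrder.nativeOrder());
    }

    int length() { return length; }

    int array(int idx) { return bytes.getInt(idx * Integer.BYTES); }

    void array(int idx, int value) { bytes.putInt(idx * Integer.BYTES, value); }

    /** Assumption: a naive arrayView() that copies, so writes to the result are not seen by the buffer. */
    int[] arrayView() {
        int[] copy = new int[length];
        for (int i = 0; i < length; i++) {
            copy[i] = array(i);
        }
        return copy;
    }
}

final class SquareKernels {
    /** What the author writes: array syntax over the view. */
    static void asWritten(int x, IntBuf buf) {
        int[] arr = buf.arrayView(); // O(length) copy on every kernel invocation
        arr[x] *= arr[x];            // updates the copy only
    }

    /** What the transformed model would execute: the view is gone, only accessors remain. */
    static void asTransformed(int x, IntBuf buf) {
        buf.array(x, buf.array(x) * buf.array(x));
    }
}
```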
This also helps with the Game of Life example:
```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function, `val(grid, from, w, x, y)`, to read each neighbour:
```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` itself is a bit verbose:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a `byteView()` on `CellGrid` it could index a plain byte array instead:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` and just write

```java
byte[] bytes = cellGrid.byteView();
// Each term needs its own parentheses: in Java, + binds tighter than &.
int count =
        (bytes[(y - 1) * w + x - 1 + from] & 1)
      + (bytes[(y + 0) * w + x - 1 + from] & 1)
      + (bytes[(y + 1) * w + x - 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 0) * w + x + 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 1 + from] & 1)
      + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

BTW my inner Verilog programmer has always wondered whether shifting and OR-ing each bit into a
9-bit value, which we then use to index the live/dead state in a prepopulated 512-entry array (GPU constant memory),
would allow us to sidestep the wave-divergent conditional :)

```java
byte[] bytes = cellGrid.byteView();
// idx, from, to, w, x and y are as in lifePerIdx() above.
byte[] lookup = new byte[512]; // placeholder: assumed prepopulated with the next state per 9-bit pattern
// Each term needs its own parentheses: << binds tighter than &, and & tighter than |.
int lookupIdx =
        ((bytes[(y - 1) * w + x - 1 + from] & 1) << 0)
      | ((bytes[(y + 0) * w + x - 1 + from] & 1) << 1)
      | ((bytes[(y + 1) * w + x - 1 + from] & 1) << 2)
      | ((bytes[(y - 1) * w + x + 0 + from] & 1) << 3)
      | ((bytes[(y + 0) * w + x + 0 + from] & 1) << 4) // current cell added
      | ((bytes[(y + 1) * w + x + 0 + from] & 1) << 5)
      | ((bytes[(y + 0) * w + x + 1 + from] & 1) << 6)
      | ((bytes[(y - 1) * w + x + 1 + from] & 1) << 7)
      | ((bytes[(y + 1) * w + x + 1 + from] & 1) << 8);
// conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```

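As a sketch of the lookup-table idea, the 512-entry table can be generated from the B3/S23 rule itself. `LifeLookup`, `buildTable()` and `lookupIdx()` are hypothetical names, and the `ALIVE`/`DEAD` values are assumptions. The only layout requirement is that the current cell lands in bit 4; the order of the other eight bits does not matter because the rule only counts them.

```java
final class LifeLookup {
    static final byte ALIVE = (byte) 0xff; // assumed encoding; only the low bit is read
    static final byte DEAD = 0x00;
    static final int CENTRE_BIT = 4;       // position of the current cell in the 9-bit pattern

    /** Builds the table: index = 9-bit neighbourhood pattern, value = next state under B3/S23. */
    static byte[] buildTable() {
        byte[] table = new byte[512];
        for (int pattern = 0; pattern < 512; pattern++) {
            boolean alive = ((pattern >> CENTRE_BIT) & 1) == 1;
            int neighbours = Integer.bitCount(pattern & ~(1 << CENTRE_BIT));
            boolean next = neighbours == 3 || (neighbours == 2 && alive);
            table[pattern] = next ? ALIVE : DEAD;
        }
        return table;
    }

    /** Packs the 3x3 neighbourhood of (x, y) into 9 bits, column by column, centre at bit 4. */
    static int lookupIdx(byte[] cells, int from, int w, int x, int y) {
        int pattern = 0;
        int bit = 0;
        for (int dx = -1; dx <= 1; dx++) {
            for (int dy = -1; dy <= 1; dy++) {
                pattern |= (cells[(y + dy) * w + (x + dx) + from] & 1) << bit;
                bit++;
            }
        }
        return pattern;
    }
}
```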
So the task here is to process kernel code models, performing the appropriate analysis
(tracking each primitive array back to its origin) and the transformations that map the array code back to buffer gets/sets.

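A toy version of that analysis and transformation might look as follows. It deliberately does not use the Babylon code model API: `Op` and its records are hypothetical stand-ins for real ops, and the "analysis" is just a map from each view variable to the buffer it came from. A real pass would presumably also need to reject views that escape the kernel or get reassigned.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Toy stand-in for a code model: a flat list of ops. Not the Babylon API. */
interface Op {}
record ArrayViewOp(String result, String buffer) implements Op {}               // result = buffer.arrayView()
record ArrayLoadOp(String result, String array, String index) implements Op {}  // result = array[index]
record ArrayStoreOp(String array, String index, String value) implements Op {}  // array[index] = value
record BufferGetOp(String result, String buffer, String index) implements Op {} // result = buffer.array(index)
record BufferSetOp(String buffer, String index, String value) implements Op {}  // buffer.array(index, value)
record OtherOp(String text) implements Op {}                                    // anything else, kept untouched

final class ArrayViewLowering {
    /** Drops arrayView() ops and maps loads/stores on views back to buffer get/set ops. */
    static List<Op> lower(List<Op> ops) {
        Map<String, String> viewToBuffer = new HashMap<>(); // analysis result: view -> origin buffer
        List<Op> out = new ArrayList<>();
        for (Op op : ops) {
            if (op instanceof ArrayViewOp v) {
                viewToBuffer.put(v.result(), v.buffer()); // remember the origin, emit nothing
            } else if (op instanceof ArrayLoadOp l && viewToBuffer.containsKey(l.array())) {
                out.add(new BufferGetOp(l.result(), viewToBuffer.get(l.array()), l.index()));
            } else if (op instanceof ArrayStoreOp s && viewToBuffer.containsKey(s.array())) {
                out.add(new BufferSetOp(viewToBuffer.get(s.array()), s.index(), s.value()));
            } else {
                out.add(op); // ordinary arrays and all other ops pass through unchanged
            }
        }
        return out;
    }
}
```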
----------
The arrayView trick actually leads us to other possibilities.

Let's look at the current NBody code.

```java
@Reflect
public static void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL version using `float4`:

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2.0f;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface-mapped segments we can approximate a `float4` vector type. In fact, the
existing `Universe` interface-mapped segment effectively embeds a `float6` kind of
object holding the `x,y,z,vx,vy,vz` values for each `body` (so position and velocity).

What if we created generalized `float4` interface-mapped views `(x,y,z,w)` as placeholders for true vector types?

And modified `Universe` to hold `pos` and `vel` `float4` arrays?

So in Java code we would pass a MemorySegment-backed `F32x4Arr`.

Our Java code could then start to approximate the OpenCL code.

Except for our lack of operator overloading...

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0f / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale delta by a scalar, then vector add
    }
    acc = float4.mul(acc, delT); // vector * scalar
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
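For reference, the `float4` helpers used above could be backed by something as simple as a record. This is only a sketch: `Float4` (capitalized to follow Java naming), its static methods and `NBodyStep` are hypothetical, accessors replace the field reads, and plain Java arrays stand in for the `float4View()` arrays.

```java
/** Hypothetical float4 stand-in: an immutable (x, y, z, w) tuple, not a true vector type. */
record Float4(float x, float y, float z, float w) {
    static Float4 of(float x, float y, float z, float w) {
        return new Float4(x, y, z, w);
    }

    static Float4 add(Float4 a, Float4 b) {
        return new Float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
    }

    static Float4 sub(Float4 a, Float4 b) {
        return new Float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
    }

    /** Vector * scalar. */
    static Float4 mul(Float4 a, float s) {
        return new Float4(a.x * s, a.y * s, a.z * s, a.w * s);
    }
}

final class NBodyStep {
    /** Updates body `me` against every position, mirroring the kernel body above. */
    static void step(int me, Float4[] pos, Float4[] vel, float mass, float delT, float espSqr) {
        Float4 acc = Float4.of(0f, 0f, 0f, 0f);
        Float4 myPos = pos[me];
        Float4 myVel = vel[me];
        for (Float4 other : pos) {
            Float4 delta = Float4.sub(other, myPos);
            float distSqr = delta.x() * delta.x() + delta.y() * delta.y() + delta.z() * delta.z() + espSqr;
            float invDist = (float) (1.0 / Math.sqrt(distSqr));
            float s = mass * invDist * invDist * invDist;
            acc = Float4.add(acc, Float4.mul(delta, s));
        }
        acc = Float4.mul(acc, delT);
        pos[me] = Float4.add(Float4.add(myPos, Float4.mul(myVel, delT)), Float4.mul(acc, delT / 2));
        vel[me] = Float4.add(myVel, acc);
    }
}
```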
The code is more compact, but it is still awkward because we can't overload operators.

Well, we can, sort of.

What if we allowed a `floatView()` call on `float4` ;) which yields a float value to be used as a proxy
for the `float4` we fetched it from...

So we would leverage the anonymity of `var`

`var myVelF4 = myVel.floatView()` // psst: var is really a float

From the code model we can track the relationship from float views back to the original vector...

Any action we perform on the 'float' view will be mapped back to a call on its origin, and performed on the actual origin `float4`.

So
`myVelF4 + myVelF4`  -> `float4.add(myVel,myVel)` behind the scenes.

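One way to picture that mapping is as a rewrite over an expression tree. The sketch below is a toy, not the Babylon code model: `Expr`, `Var`, `BinOp`, `Call` and `FloatViewRewriter` are hypothetical, and the proxy-to-origin map is assumed to have been collected from the `floatView()` calls beforehand.

```java
import java.util.Map;

/** Toy expression tree; a stand-in for a real code model. */
interface Expr {}
record Var(String name) implements Expr {}
record BinOp(char op, Expr left, Expr right) implements Expr {}
record Call(String method, Expr first, Expr second) implements Expr {}

final class FloatViewRewriter {
    /** Maps each float-view variable to the float4 it was fetched from. */
    private final Map<String, String> proxyToOrigin;

    FloatViewRewriter(Map<String, String> proxyToOrigin) {
        this.proxyToOrigin = proxyToOrigin;
    }

    /** True if the expression stands for a float4, i.e. it mentions at least one proxy. */
    boolean isVector(Expr e) {
        if (e instanceof Var v) return proxyToOrigin.containsKey(v.name());
        if (e instanceof BinOp b) return isVector(b.left()) || isVector(b.right());
        return false;
    }

    /** Replaces proxies by their origins and operators on them by float4 helper calls. */
    Expr rewrite(Expr e) {
        if (e instanceof Var v) {
            return new Var(proxyToOrigin.getOrDefault(v.name(), v.name()));
        }
        if (e instanceof BinOp b) {
            boolean lv = isVector(b.left());
            boolean rv = isVector(b.right());
            Expr l = rewrite(b.left());
            Expr r = rewrite(b.right());
            if (lv && rv && b.op() == '+') return new Call("float4.add", l, r);
            if (lv && rv && b.op() == '-') return new Call("float4.sub", l, r);
            if (lv && !rv && b.op() == '*') return new Call("float4.mul", l, r);
            if (!lv && rv && b.op() == '*') return new Call("float4.mul", r, l);
            return new BinOp(b.op(), l, r); // scalar-only or unsupported mix: left as written
        }
        return e;
    }

    /** Renders an expression as source-like text. */
    static String show(Expr e) {
        if (e instanceof Var v) return v.name();
        if (e instanceof Call c) return c.method() + "(" + show(c.first()) + ", " + show(c.second()) + ")";
        BinOp b = (BinOp) e;
        return show(b.left()) + " " + b.op() + " " + show(b.right());
    }
}
```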
Yeah, it's a bit crazy. Perhaps this is a 'bridge too far'. The code would look something like this.

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    var accF4 = acc.floatView(); // var is really float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;)  var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); //... myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4;  // look ma operator overloading ;)
        // delta stands for the float4 that deltaF4 proxies
        float invDist = (float) (1.0f / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, then vector add
    }
    accF4 = accF4 * delT;  // scalar * vector
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```