## Leverage CodeReflection to expose array view of buffers in kernels

----

* [Contents](hat-00.md)
* House Keeping
  * [Project Layout](hat-01-01-project-layout.md)
  * [Building Babylon](hat-01-02-building-babylon.md)
  * [Building HAT](hat-01-03-building-hat.md)
    * [Enabling the CUDA Backend](hat-01-05-building-hat-for-cuda.md)
* Programming Model
  * [Programming Model](hat-03-programming-model.md)
* Interface Mapping
  * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
  * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Implementation Detail
  * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
  * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)

----

Here is the canonical HAT example:

```java
import jdk.incubator.code.CodeReflection;

@CodeReflection
public class Square {
  @CodeReflection
  public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
  }

  @CodeReflection
  public static void compute(ComputeContext cc, S32Arr s32Arr) {
    cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
  }
}
```

This code in the kernel has always bothered me. One downside of using `MemorySegment`-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
  @CodeReflection
  public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
  }
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that exposes a Java `int[]` view of the buffer?

The kernel then becomes:

```java
  @CodeReflection
  public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
  }
```
IMHO this makes the code more readable.

For the GPU this is fine. Thanks to CodeReflection we can prove that the array is indeed just a view,
so we simply remove all references to `int[] arr` and replace the array accesses with get/set accessor
calls on the original `S32Arr`.

But what about Java performance? Won't it suck because we are copying the array in each kernel ;)

Well, we can use the same trick we used for the GPU: we take the transformed model (with the array
references removed), generate bytecode from it and run that.

Historically, we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.

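To make the copying concern concrete, here is a minimal plain-Java sketch. `FakeS32Arr` is a hypothetical stand-in for `S32Arr` (backed by an `int[]` rather than a `MemorySegment`), and its `arrayView()` deliberately returns a copy, on the assumption that a naive implementation over off-heap memory would have to. It compares running the array-style kernel as written with running the accessor form that the transformed model would correspond to.

```java
final class ArrayViewCopyDemo {
    /** Hypothetical stand-in for S32Arr: backed by an int[] here, not a MemorySegment. */
    static final class FakeS32Arr {
        private final int[] data;
        FakeS32Arr(int... values) { this.data = values.clone(); }
        int length() { return data.length; }
        int array(int idx) { return data[idx]; }
        void array(int idx, int value) { data[idx] = value; }
        /** Assumption: a naive view copies the buffer contents into a fresh array. */
        int[] arrayView() { return data.clone(); }
    }

    /** The kernel as written, run naively: one copy per call, and the write lands in the copy. */
    static void kernelNaive(int x, FakeS32Arr s32Arr) {
        int[] arr = s32Arr.arrayView();
        arr[x] *= arr[x];
    }

    /** The kernel after the transformation: array accesses become accessor calls. */
    static void kernelLowered(int x, FakeS32Arr s32Arr) {
        s32Arr.array(x, s32Arr.array(x) * s32Arr.array(x));
    }

    static int[] runNaive(int... values) {
        FakeS32Arr buf = new FakeS32Arr(values);
        for (int x = 0; x < buf.length(); x++) kernelNaive(x, buf);
        return buf.arrayView();
    }

    static int[] runLowered(int... values) {
        FakeS32Arr buf = new FakeS32Arr(values);
        for (int x = 0; x < buf.length(); x++) kernelLowered(x, buf);
        return buf.arrayView();
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(runNaive(1, 2, 3)));   // [1, 2, 3]
        System.out.println(java.util.Arrays.toString(runLowered(1, 2, 3))); // [1, 4, 9]
    }
}
```

With a copying view the naive form is not just slow: the squared values never reach the buffer. That is one more reason for the Java backends to run the transformed model rather than the original bytecode.
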
The `arrayView()` idea also helps with the Game of Life example:

```java
    public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
        int w = cellGrid.width();
        int h = cellGrid.height();
        int from = control.from();
        int to = control.to();
        int x = idx % w;
        int y = idx / w;
        byte cell = cellGrid.cell(idx + from);
        if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
            int count =
                    val(cellGrid, from, w, x - 1, y - 1)
                            + val(cellGrid, from, w, x - 1, y + 0)
                            + val(cellGrid, from, w, x - 1, y + 1)
                            + val(cellGrid, from, w, x + 0, y - 1)
                            + val(cellGrid, from, w, x + 0, y + 1)
                            + val(cellGrid, from, w, x + 1, y + 0)
                            + val(cellGrid, from, w, x + 1, y - 1)
                            + val(cellGrid, from, w, x + 1, y + 1);
            cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
        }
        cellGrid.cell(idx + to, cell);
    }
```

This code uses a helper function `val(grid, from, w, x, y)` to extract the state of each neighbour:

```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` itself is a bit verbose:

```java
    @CodeReflection
    public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
        return grid.cell(((long) y * w) + x + from) & 1;
    }
```

With a `byteView()` on `CellGrid` it could instead be written as:

```java
    @CodeReflection
    public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
        byte[] bytes = grid.byteView(); // bit view would be nice ;)
        return bytes[y * w + x + from] & 1;
    }
```

We could now dispense with `val()` and just write:

```java
byte[] bytes = cellGrid.byteView();
// Each term needs its own parentheses: in Java `+` binds tighter than `&`.
int count =
        (bytes[(y - 1) * w + x - 1 + from] & 1)
      + (bytes[(y + 0) * w + x - 1 + from] & 1)
      + (bytes[(y + 1) * w + x - 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 0) * w + x + 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 1 + from] & 1)
      + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

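As a runnable sketch of the inlined neighbour count, the following uses a plain `byte[]` in place of the proposed `byteView()` (the class and method names are made up for illustration). It also shows why every `& 1` term has to be parenthesised: Java gives `+` higher precedence than `&`.

```java
final class NeighbourCount {
    /** Counts the live neighbours of (x, y) in a row-major grid of width w starting at offset from. */
    static int count(byte[] bytes, int from, int w, int x, int y) {
        return (bytes[(y - 1) * w + x - 1 + from] & 1)
             + (bytes[(y + 0) * w + x - 1 + from] & 1)
             + (bytes[(y + 1) * w + x - 1 + from] & 1)
             + (bytes[(y - 1) * w + x + 0 + from] & 1)
             + (bytes[(y + 1) * w + x + 0 + from] & 1)
             + (bytes[(y + 0) * w + x + 1 + from] & 1)
             + (bytes[(y - 1) * w + x + 1 + from] & 1)
             + (bytes[(y + 1) * w + x + 1 + from] & 1);
    }

    /** Without parentheses this parses as a & (1 + b) & 1, which is not a sum of two bits. */
    static int unparenthesised(int a, int b) {
        return a & 1 + b & 1;
    }

    /** The intended meaning: add the low bit of a to the low bit of b. */
    static int parenthesised(int a, int b) {
        return (a & 1) + (b & 1);
    }
}
```

For two live cells, `parenthesised(1, 1)` is `2` while `unparenthesised(1, 1)` is `0`, so the unparenthesised form would silently miscount.
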
BTW my inner Verilog programmer has always wondered whether shifting and OR-ing each bit into a
9-bit value, which we then use to index the live/dead state in a prepopulated 512-entry array (GPU constant memory),
would allow us to sidestep the wave-divergent conditional :)

```java
// idx, from, to, w, x and y are as in lifePerIdx
byte[] bytes = cellGrid.byteView();
byte[] lookup = new byte[512]; // placeholder: prepopulated with the live/dead result per 9-bit pattern
// `<<` binds tighter than `&`, so each masked bit is parenthesised before shifting
int lookupIdx =
        (bytes[(y - 1) * w + x - 1 + from] & 1) << 0
      | (bytes[(y + 0) * w + x - 1 + from] & 1) << 1
      | (bytes[(y + 1) * w + x - 1 + from] & 1) << 2
      | (bytes[(y - 1) * w + x + 0 + from] & 1) << 3
      | (bytes[(y + 0) * w + x + 0 + from] & 1) << 4 // current cell added
      | (bytes[(y + 1) * w + x + 0 + from] & 1) << 5
      | (bytes[(y + 0) * w + x + 1 + from] & 1) << 6
      | (bytes[(y - 1) * w + x + 1 + from] & 1) << 7
      | (bytes[(y + 1) * w + x + 1 + from] & 1) << 8;
// conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```

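One way the 512-entry table might be prepopulated is sketched below. The names (`LifeLookup`, `buildTable`, `next`) and the `ALIVE`/`DEAD` values are assumptions for illustration. The only layout requirement taken from the snippet above is that bit 4 holds the current cell, so the other eight bits are the neighbours and their order does not affect the count.

```java
final class LifeLookup {
    static final byte ALIVE = 1; // assumption: the real constants may differ
    static final byte DEAD = 0;
    static final int CENTRE_BIT = 4;

    /** Builds the B3/S23 result for every possible 3x3 neighbourhood pattern. */
    static byte[] buildTable() {
        byte[] table = new byte[512];
        for (int pattern = 0; pattern < table.length; pattern++) {
            boolean alive = ((pattern >> CENTRE_BIT) & 1) == 1;
            int neighbours = Integer.bitCount(pattern & ~(1 << CENTRE_BIT));
            table[pattern] = (neighbours == 3 || (neighbours == 2 && alive)) ? ALIVE : DEAD;
        }
        return table;
    }

    /** Branch-free next state of an interior cell (x, y) in a row-major grid of width w. */
    static byte next(byte[] cells, int w, int x, int y, byte[] table) {
        int pattern = 0;
        int bit = 0;
        for (int dx = -1; dx <= 1; dx++) {
            for (int dy = -1; dy <= 1; dy++) {
                // (dx, dy) == (0, 0) is the fifth iteration, so the centre lands on bit 4
                pattern |= (cells[(y + dy) * w + (x + dx)] & 1) << bit;
                bit++;
            }
        }
        return table[pattern];
    }
}
```

The loops in `next` have fixed trip counts and could be unrolled into the nine-term expression shown earlier; the data-dependent conditional is replaced by a single table read.
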
So the task here is to process kernel code models, perform the appropriate analysis
(tracking the origin of each primitive array) and apply transformations that map the array code back to buffer gets/sets.

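The shape of that analysis and transformation can be sketched on a toy IR. None of the types below are the Babylon code model API; the records are hypothetical stand-ins chosen only to show the two steps: remember which buffer each array value came from, then rewrite loads and stores on that array into buffer accessor ops.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class ViewLowering {
    // Toy ops; operands and results are just value names.
    interface Op {}
    record ViewCall(String result, String buffer) implements Op {}                // result = buffer.arrayView()
    record ArrayLoad(String result, String array, String index) implements Op {}  // result = array[index]
    record ArrayStore(String array, String index, String value) implements Op {}  // array[index] = value
    record BufferGet(String result, String buffer, String index) implements Op {} // result = buffer.array(index)
    record BufferSet(String buffer, String index, String value) implements Op {}  // buffer.array(index, value)
    record Other(String text) implements Op {}                                    // anything we leave alone

    static List<Op> lower(List<Op> ops) {
        Map<String, String> origin = new HashMap<>(); // array value -> buffer it is a view of
        List<Op> out = new ArrayList<>();
        for (Op op : ops) {
            if (op instanceof ViewCall v) {
                origin.put(v.result(), v.buffer()); // analysis: record the origin, drop the call
            } else if (op instanceof ArrayLoad l && origin.containsKey(l.array())) {
                out.add(new BufferGet(l.result(), origin.get(l.array()), l.index()));
            } else if (op instanceof ArrayStore s && origin.containsKey(s.array())) {
                out.add(new BufferSet(origin.get(s.array()), s.index(), s.value()));
            } else {
                out.add(op); // arrays that are not views stay untouched
            }
        }
        return out;
    }
}
```

Feeding it the ops for `int[] arr = s32Arr.arrayView(); arr[kc.x] *= arr[kc.x];` yields a get, the multiply and a set on `s32Arr`, with no trace of `arr`. A real pass would also have to handle control flow and views that escape or are reassigned, which this sketch ignores.
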

----------

The arrayView trick actually leads us to other possibilities.

Let's look at the current NBody code.

```java
    @CodeReflection
    public static void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
        float accx = 0.0f;
        float accy = 0.0f;
        float accz = 0.0f;
        Universe.Body me = universe.body(kc.x);

        for (int i = 0; i < kc.maxX; i++) {
            Universe.Body otherBody = universe.body(i);
            float dx = otherBody.x() - me.x();
            float dy = otherBody.y() - me.y();
            float dz = otherBody.z() - me.z();
            float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
            float s = mass * invDist * invDist * invDist;
            accx = accx + (s * dx);
            accy = accy + (s * dy);
            accz = accz + (s * dz);
        }
        accx = accx * delT;
        accy = accy * delT;
        accz = accz * delT;
        me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
        me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
        me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
        me.vx(me.vx() + accx);
        me.vy(me.vy() + accy);
        me.vz(me.vz() + accz);
    }
```

Contrast the above code with the OpenCL code using `float4`:

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2.0f;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface-mapped segments we can approximate a `float4` vector type. In fact the
existing `Universe` interface-mapped segment already embeds a `float6` kind of
object holding the `x,y,z,vx,vy,vz` values for each `body` (so position and velocity).

What if we created generalized `float4` interface-mapped views `(x,y,z,w)` as placeholders for true vector types,
and modified `Universe` to hold `pos` and `vel` `float4` arrays?

In Java code we would then pass a `MemorySegment`-backed `F32x4Arr`.

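As a rough idea of the helper surface such a `float4` placeholder might offer, here is a plain-Java sketch. It is a hypothetical stand-in: a record holding four floats rather than an interface-mapped view over a segment, with static `of`, `add`, `sub` and `mul` helpers whose names and signatures are assumptions.

```java
/** Hypothetical float4 placeholder: a plain record, not an interface-mapped segment view. */
record float4(float x, float y, float z, float w) {
    static float4 of(float x, float y, float z, float w) {
        return new float4(x, y, z, w);
    }

    /** Component-wise a + b. */
    static float4 add(float4 a, float4 b) {
        return new float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
    }

    /** Component-wise a - b. */
    static float4 sub(float4 a, float4 b) {
        return new float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
    }

    /** Vector scaled by a scalar. */
    static float4 mul(float4 a, float s) {
        return new float4(a.x * s, a.y * s, a.z * s, a.w * s);
    }
}
```

For example, `float4.add(float4.of(1f, 2f, 3f, 0f), float4.mul(float4.of(1f, 1f, 1f, 0f), 2f))` gives `(3, 4, 5, 0)`. A backend that recognises these calls could map them directly to native vector operations.
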
Our Java code could then start to approximate the OpenCL code,
except for our lack of operator overloading...

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale delta by the scalar s, then accumulate
    }
    acc = float4.mul(acc, delT); // vector * scalar
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
The code is more compact, but it's still weird because we can't overload operators.

Well, we sort of can.

What if we allowed a `floatView()` call on `float4` ;) which yields a `float` value to be used as a proxy
for the `float4` we fetched it from...

So we would leverage the anonymity of `var`:

`var myVelF4 = myVel.floatView()` // psst, var is really a float

From the code model we can track the relationship from float views back to the original vector...

Any operation we perform on the `float` view is mapped back to a call on its origin and performed on the actual `float4`.

So
`myVelF4 + myVelF4`  -> `float4.add(myVel, myVel)` behind the scenes.

Yeah, it's a bit crazy, and perhaps this is a 'bridge too far'. The code would look something like this:

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    var accF4 = acc.floatView(); // var is really float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;) var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); // ... and this one to myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4; // look ma, operator overloading ;)
        // `delta` stands for the float4 behind deltaF4; how we get back to it is left open here
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, then accumulate
    }
    accF4 = accF4 * delT; // scalar * vector
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```