#### Leverage CodeReflection to expose array view of buffers in kernels
----

* [Contents](hat-00.md)
* House Keeping
    * [Project Layout](hat-01-01-project-layout.md)
    * [Building Babylon](hat-01-02-building-babylon.md)
    * [Building HAT](hat-01-03-building-hat.md)
* Programming Model
    * [Programming Model](hat-03-programming-model.md)
* Interface Mapping
    * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
    * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Implementation Detail
    * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
    * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)

----

Here is the canonical HAT example:

```java
import jdk.incubator.code.CodeReflection;

@CodeReflection
public class Square {
  @CodeReflection
  public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
  }

  @CodeReflection
  public static void compute(ComputeContext cc, S32Arr s32Arr) {
    cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
  }
}
```

This code in the kernel has always bothered me. One downside of using `MemorySegment`-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
  @CodeReflection
  public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
  }
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that exposes the buffer as a plain Java `int[]` view?

The kernel becomes way more readable.
```java
  @CodeReflection
  public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
  }
```
IMHO this makes the code more readable.

For the GPU this is fine. We can (thanks to CodeReflection) prove that the array is indeed just a view,
so we simply remove all references to `int[] arr` and replace the array accesses with get/set accessors
on the original `S32Arr`.

But what about Java performance? Won't it suck because we are copying the array in each kernel ;)

Well, we can use the same trick we used for the GPU: we take the transformed model (with the array
references removed), create bytecode from it, and run that.

Historically, we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.
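
The rewrite described above (array accesses on the view replaced by get/set accessors on the buffer) can be sketched in plain Java. This is only an illustration under assumptions: `IntBuf` is a hypothetical stand-in for `S32Arr`, not the HAT class, and `kernelAsTransformed` is written by hand to show what the transformed model is assumed to be equivalent to. Nothing here uses the real code model API.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

public class ArrayViewSketch {
    /** Hypothetical stand-in for S32Arr: ints stored in a byte buffer rather than an int[]. */
    public static final class IntBuf {
        private final ByteBuffer bytes;
        private final int length;

        public IntBuf(int length) {
            this.length = length;
            this.bytes = ByteBuffer.allocate(length * Integer.BYTES).order(ByteOrder.nativeOrder());
        }
        public int length() { return length; }
        public int array(int idx) { return bytes.getInt(idx * Integer.BYTES); }
        public void array(int idx, int value) { bytes.putInt(idx * Integer.BYTES, value); }
    }

    // What the kernel author writes: ordinary array code against the view.
    public static void kernelAsWritten(int x, int[] arr) {
        arr[x] *= arr[x];
    }

    // What the transformed model is assumed to behave like:
    // every array load/store becomes a get/set on the original buffer.
    public static void kernelAsTransformed(int x, IntBuf buf) {
        buf.array(x, buf.array(x) * buf.array(x));
    }

    // Runs the array form over the whole range.
    public static int[] squareViaArray(int[] input) {
        int[] arr = input.clone();
        for (int x = 0; x < arr.length; x++) kernelAsWritten(x, arr);
        return arr;
    }

    // Runs the accessor form over the whole range and copies the result out.
    public static int[] squareViaBuffer(int[] input) {
        IntBuf buf = new IntBuf(input.length);
        for (int i = 0; i < input.length; i++) buf.array(i, input[i]);
        for (int x = 0; x < buf.length(); x++) kernelAsTransformed(x, buf);
        int[] out = new int[input.length];
        for (int i = 0; i < out.length; i++) out[i] = buf.array(i);
        return out;
    }

    public static void main(String[] args) {
        int[] in = {1, 2, 3, 4};
        // Both forms should agree element by element.
        System.out.println(Arrays.equals(squareViaArray(in), squareViaBuffer(in)));
    }
}
```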

The array view also helps with the Game of Life kernel:
```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function `val(grid, from, w, x, y)` to read each neighbour:
```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` is a bit verbose:

```java
@CodeReflection
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a `byteView()` on `CellGrid` it could instead be written as:

```java
@CodeReflection
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` and just write:

```java
byte[] bytes = cellGrid.byteView();
// Parentheses are required: in Java `+` binds tighter than `&`.
int count =
        (bytes[(y - 1) * w + x - 1 + from] & 1)
      + (bytes[(y + 0) * w + x - 1 + from] & 1)
      + (bytes[(y + 1) * w + x - 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 0) * w + x + 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 1 + from] & 1)
      + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

BTW my inner Verilog programmer has always wondered whether shifting and OR-ing each bit into a
9-bit value, which we then use to index the live/dead state from a prepopulated 512-entry array (GPU constant memory),
would allow us to sidestep the wave-divergent conditional :)

```java
// idx, from, to, w, x and y are as in lifePerIdx above.
byte[] bytes = cellGrid.byteView();
// Placeholder: assumed to be prepopulated with the B3/S23 outcome for each 9-bit neighbourhood.
byte[] lookup = new byte[512];
// Parentheses are required: `<<` binds tighter than `&`, which binds tighter than `|`.
int lookupIdx =
        ((bytes[(y - 1) * w + x - 1 + from] & 1) << 0)
      | ((bytes[(y + 0) * w + x - 1 + from] & 1) << 1)
      | ((bytes[(y + 1) * w + x - 1 + from] & 1) << 2)
      | ((bytes[(y - 1) * w + x + 0 + from] & 1) << 3)
      | ((bytes[(y + 0) * w + x + 0 + from] & 1) << 4) // current cell added
      | ((bytes[(y + 1) * w + x + 0 + from] & 1) << 5)
      | ((bytes[(y + 0) * w + x + 1 + from] & 1) << 6)
      | ((bytes[(y - 1) * w + x + 1 + from] & 1) << 7)
      | ((bytes[(y + 1) * w + x + 1 + from] & 1) << 8);
// conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```
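
As a sketch of how such a table might be prepopulated: this assumes the bit layout used above (bit 4 is the current cell, the other eight bits are its neighbours) and assumes `ALIVE` and `DEAD` are encoded as 1 and 0. `LifeLookupSketch` and `buildLookup` are illustrative names, not HAT code.

```java
public class LifeLookupSketch {
    static final byte DEAD = 0;    // assumed encoding
    static final byte ALIVE = 1;   // assumed encoding
    static final int SELF_BIT = 4; // position of the current cell in the 9-bit index

    /** Builds a 512-entry B3/S23 table indexed by the packed 3x3 neighbourhood. */
    public static byte[] buildLookup() {
        byte[] lookup = new byte[512];
        for (int idx = 0; idx < lookup.length; idx++) {
            boolean self = ((idx >> SELF_BIT) & 1) == 1;
            // Count the eight neighbour bits, excluding the current cell.
            int count = Integer.bitCount(idx & ~(1 << SELF_BIT));
            boolean alive = count == 3 || (count == 2 && self);
            lookup[idx] = alive ? ALIVE : DEAD;
        }
        return lookup;
    }

    public static void main(String[] args) {
        byte[] lookup = buildLookup();
        int deadWithThree = 0b000000111; // dead cell, three live neighbours: birth
        int aliveWithTwo = 0b000010011;  // live cell, two live neighbours: survival
        System.out.println(lookup[deadWithThree]);
        System.out.println(lookup[aliveWithTwo]);
    }
}
```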

So the task here is to process kernel code models and perform the appropriate analysis
(tracking each primitive array back to its origin) and transformations to map the array code back to buffer gets/sets.
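
The shape of that analysis can be illustrated with a toy model. This is a sketch under loud assumptions: the `Op` records below are invented for illustration and are not the Babylon code model API, and a real pass would operate on actual code model ops. The toy pass records which variables were assigned from `arrayView()`, rewrites loads and stores on those variables into buffer accessor ops, and drops the view definition itself.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ViewRewriteSketch {
    // Toy ops, invented for this sketch (not the real code model).
    public interface Op {}
    public record ViewOf(String arrayVar, String buffer) implements Op {}                   // int[] arr = buf.arrayView()
    public record ArrayLoad(String result, String arrayVar, String index) implements Op {}  // result = arr[index]
    public record ArrayStore(String arrayVar, String index, String value) implements Op {}  // arr[index] = value
    public record BufferGet(String result, String buffer, String index) implements Op {}    // result = buf.array(index)
    public record BufferSet(String buffer, String index, String value) implements Op {}     // buf.array(index, value)
    public record Other(String text) implements Op {}                                       // anything else, left untouched

    /** Maps array ops on tracked views back to get/set ops on the originating buffer. */
    public static List<Op> rewrite(List<Op> ops) {
        Map<String, String> origin = new HashMap<>(); // array variable -> originating buffer
        List<Op> out = new ArrayList<>();
        for (Op op : ops) {
            if (op instanceof ViewOf v) {
                origin.put(v.arrayVar(), v.buffer()); // remember the origin, emit nothing
            } else if (op instanceof ArrayLoad l && origin.containsKey(l.arrayVar())) {
                out.add(new BufferGet(l.result(), origin.get(l.arrayVar()), l.index()));
            } else if (op instanceof ArrayStore s && origin.containsKey(s.arrayVar())) {
                out.add(new BufferSet(origin.get(s.arrayVar()), s.index(), s.value()));
            } else {
                out.add(op); // unrelated ops pass through unchanged
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // arr[kc.x] *= arr[kc.x], spelled out as toy ops.
        List<Op> kernel = List.of(
                new ViewOf("arr", "s32Arr"),
                new ArrayLoad("t0", "arr", "kc.x"),
                new Other("t1 = t0 * t0"),
                new ArrayStore("arr", "kc.x", "t1"));
        rewrite(kernel).forEach(System.out::println);
    }
}
```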

----------
The arrayView trick actually leads us to other possibilities.

Let's look at the current NBody code.

```java
@CodeReflection
static public void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL code using `float4`:

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); // OpenCL vector literal
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2.0f;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface mapped segments we can approximate the `float4` vector type. In fact, the
existing `Universe` interface mapped segment already embeds a `float6` kind of
object holding the `x,y,z,vx,vy,vz` values for each `body` (so position and velocity).

What if we created generalized `float4` interface mapped views `(x,y,z,w)` as placeholders for true vector types,
and modified `Universe` to hold `pos` and `vel` `float4` arrays?

So in Java code we would pass a MemorySegment version of a `F32x4Arr`.
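
As a rough illustration of the placeholder idea, here is a plain Java sketch. It assumes a hypothetical immutable `Float4` record with static `add`, `sub` and `mul` helpers. It is an ordinary heap object rather than an interface mapped segment, and none of these names are existing HAT types. `accel` is likewise an illustrative helper that computes one body's contribution to the acceleration, mirroring the loop body of the kernels above.

```java
public class Float4Sketch {
    /** Hypothetical stand-in for a float4 vector type (not a HAT class). */
    public record Float4(float x, float y, float z, float w) {
        public static Float4 of(float x, float y, float z, float w) {
            return new Float4(x, y, z, w);
        }
        public static Float4 add(Float4 a, Float4 b) {
            return new Float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
        }
        public static Float4 sub(Float4 a, Float4 b) {
            return new Float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
        }
        // vector * scalar, applied lane by lane
        public static Float4 mul(Float4 a, float s) {
            return new Float4(a.x * s, a.y * s, a.z * s, a.w * s);
        }
    }

    /** Acceleration contribution of `other` on `me`, using only x, y and z for the distance. */
    public static Float4 accel(Float4 me, Float4 other, float mass, float espSqr) {
        Float4 delta = Float4.sub(other, me);
        float invDist = (float) (1.0 / Math.sqrt(
                (delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        return Float4.mul(delta, s);
    }

    public static void main(String[] args) {
        Float4 me = Float4.of(0f, 0f, 0f, 0f);
        Float4 other = Float4.of(3f, 4f, 0f, 0f);
        System.out.println(accel(me, other, 1.0f, 0.0f));
    }
}
```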

Our Java code could then start to approximate the OpenCL code.

Except for our lack of operator overloading...

```java
// Sketch only: float4 and F32x4Arr.float4View() are proposed, not existing, types and methods.
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0f / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale the vector by s, then add vector to vector
    }
    acc = float4.mul(acc, delT); // scalar * vector
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
The code is more compact, but it's still weird, because we can't overload operators.

Well, we can, sort of.

What if we allowed a `floatView()` call on `float4` ;) which yields a `float` value to be used as a proxy
for the `float4` we fetched it from...

So we would leverage the anonymity of `var`:

`var myVelF4 = myVel.floatView()` // psst, var is really a float

From the code model we can track the relationship from float views back to the original vector...

Any action we perform on the 'float' view will be mapped back to calls on the origin, and performed on the actual origin `float4`.

So
`myVelF4 + myVelF4`  -> `float4.add(myVel, myVel)` behind the scenes.

Yeah, it's a bit crazy, and perhaps this is a 'bridge too far'. The code would look something like this:

```java
// Sketch only: floatView() is a proposed method; each *F4 variable is a float acting as a proxy for a float4.
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    var accF4 = acc.floatView(); // var is really float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;)  var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); //... myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4;  // look ma operator overloading ;)
        // `delta` stands for the float4 behind deltaF4; how a view exposes its origin's lanes is left open here.
        float invDist = (float) (1.0f / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, then vector add
    }
    accF4 = accF4 * delT;  // scalar * vector
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```

306