## Leverage CodeReflection to expose array views of buffers in kernels

----
* [Contents](hat-00.md)
* Build Babylon and HAT
    * [Quick Install](hat-01-quick-install.md)
    * [Building Babylon with jtreg](hat-01-02-building-babylon.md)
    * [Building HAT with jtreg](hat-01-03-building-hat.md)
        * [Enabling the NVIDIA CUDA Backend](hat-01-05-building-hat-for-cuda.md)
* [Testing Framework](hat-02-testing-framework.md)
* [Running Examples](hat-03-examples.md)
* [HAT Programming Model](hat-03-programming-model.md)
* Interface Mapping
    * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
    * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Development
    * [Project Layout](hat-01-01-project-layout.md)
    * [IntelliJ Code Formatter](hat-development.md)
* Implementation Details
    * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
    * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)
* [Running HAT with Docker on NVIDIA GPUs](hat-07-docker-build-nvidia.md)
---

Here is the canonical HAT example:

```java
import jdk.incubator.code.Reflect;

@Reflect
public class Square {
    @Reflect
    public static void kernel(KernelContext kc, S32Arr s32Arr) {
        s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
    }

    @Reflect
    public static void compute(ComputeContext cc, S32Arr s32Arr) {
        cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
    }
}
```

This code in the kernel has always bothered me. One downside of using MemorySegment-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
}
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that exposes the buffer as a Java `int[]` view?

The kernel then becomes:
```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
}
```
IMHO this makes the kernel much easier to read.

For the GPU this is fine. Thanks to CodeReflection we can prove that the array is just a view of the buffer,
remove all references to `int[] arr`, and replace the array accesses with get/set accessor calls
on the original `S32Arr`.

But what about Java performance? Won't it suck, because we are copying the array in each kernel? ;)

Well, we can use the same trick we used for the GPU: take the transformed model (with the array
references removed), generate bytecode from it, and run that.

Historically, we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.
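
To make the rewrite concrete, here is a hand-written sketch of what the transformed `Square.kernel` amounts to, driven by a sequential loop and a parallel loop as stand-ins for the Java Seq and MT backends. It is a sketch under stated assumptions, not HAT's backend code: `S32ArrStandIn`, `kernelLowered`, `dispatchSeq` and `dispatchMT` are hypothetical names, a `java.nio.IntBuffer` stands in for the MemorySegment behind `S32Arr`, and a plain `int` index replaces `KernelContext`.

```java
import java.nio.IntBuffer;
import java.util.stream.IntStream;

// Hypothetical sketch, not HAT's API: an IntBuffer stands in for the MemorySegment
// behind S32Arr, and a plain int index stands in for KernelContext.x.
final class ArrayViewLowering {

    // Stand-in for S32Arr, exposing the same array(i) / array(i, v) accessor shape.
    static final class S32ArrStandIn {
        private final IntBuffer buf;

        S32ArrStandIn(int length) {
            buf = IntBuffer.allocate(length);
        }

        int length() {
            return buf.capacity();
        }

        int array(int i) {
            return buf.get(i);
        }

        void array(int i, int v) {
            buf.put(i, v);
        }
    }

    // Kernel as written:   int[] arr = s32Arr.arrayView();
    //                      arr[x] *= arr[x];
    // After the rewrite the array is gone; only buffer accessor calls remain.
    static void kernelLowered(int x, S32ArrStandIn s32Arr) {
        s32Arr.array(x, s32Arr.array(x) * s32Arr.array(x));
    }

    // Sequential dispatch, standing in for the Java Seq backend.
    static void dispatchSeq(S32ArrStandIn s32Arr) {
        for (int x = 0; x < s32Arr.length(); x++) {
            kernelLowered(x, s32Arr);
        }
    }

    // Parallel dispatch, standing in for the Java MT backend. Each x reads and
    // writes only its own slot, so no synchronization is needed.
    static void dispatchMT(S32ArrStandIn s32Arr) {
        IntStream.range(0, s32Arr.length()).parallel().forEach(x -> kernelLowered(x, s32Arr));
    }
}
```

Because the lowered kernel never materializes an `int[]`, neither loop copies anything; both operate on the buffer in place.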

This also helps with Game of Life:
```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function, `val(grid, from, w, x, y)`, to read the state of each neighbour:
```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` itself is a bit verbose:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a `byteView()` on `CellGrid` (the `byte[]` counterpart of `arrayView()`), it could become:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // a bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` entirely and just write:

```java
byte[] bytes = cellGrid.byteView();
// & binds more loosely than +, so each masked term needs its own parentheses
int count =
        (bytes[(y - 1) * w + x - 1 + from] & 1)
      + (bytes[(y + 0) * w + x - 1 + from] & 1)
      + (bytes[(y + 1) * w + x - 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 0) * w + x + 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 1 + from] & 1)
      + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

BTW, my inner Verilog programmer has always wondered whether shifting and ORing each cell's bit into a
9-bit value, and using that value to index the next live/dead state in a prepopulated 512-entry table
(in GPU constant memory), would let us sidestep the wave-divergent conditional :)

```java
// Inside lifePerIdx(): idx, x, y, w, from and to are computed as before.
// lookup is a prepopulated 512-entry table mapping each 3x3 pattern to ALIVE/DEAD (B3/S23).
byte[] bytes = cellGrid.byteView();
// << binds tighter than &, so mask each cell first and then shift it into place
int lookupIdx =
        (bytes[(y - 1) * w + x - 1 + from] & 1) << 0
      | (bytes[(y + 0) * w + x - 1 + from] & 1) << 1
      | (bytes[(y + 1) * w + x - 1 + from] & 1) << 2
      | (bytes[(y - 1) * w + x + 0 + from] & 1) << 3
      | (bytes[(y + 0) * w + x + 0 + from] & 1) << 4 // current cell added
      | (bytes[(y + 1) * w + x + 0 + from] & 1) << 5
      | (bytes[(y + 0) * w + x + 1 + from] & 1) << 6
      | (bytes[(y - 1) * w + x + 1 + from] & 1) << 7
      | (bytes[(y + 1) * w + x + 1 + from] & 1) << 8;
// B3/S23 conditional removed!
bytes[idx + to] = lookup[lookupIdx];
```
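
The 512-entry table itself can be precomputed once from the B3/S23 rule. The sketch below assumes the bit order used above (bit 4 is the current cell), assumes `ALIVE` is a byte with bit 0 set (`(byte) 0xff` here) and `DEAD` is `0`, and uses a plain `byte[]` in place of `byteView()`. `LifeLookup`, `buildLookup`, `nextByRule` and `nextByLookup` are hypothetical names; the point is only to check that the branch-free lookup agrees with the conditional form on interior cells.

```java
// Hypothetical sketch: precompute the B3/S23 lookup table and compare it with the
// conditional form. A plain byte[] stands in for cellGrid.byteView().
final class LifeLookup {
    static final byte ALIVE = (byte) 0xff; // assumption: bit 0 set means alive
    static final byte DEAD = 0;
    static final int CENTRE_BIT = 4;       // position of the current cell in the 9-bit index

    // Entry i holds the next state for the 3x3 neighbourhood whose bits are packed into i.
    static byte[] buildLookup() {
        byte[] lookup = new byte[512];
        for (int pattern = 0; pattern < 512; pattern++) {
            boolean alive = ((pattern >> CENTRE_BIT) & 1) == 1;
            int count = Integer.bitCount(pattern & ~(1 << CENTRE_BIT)); // the 8 neighbours
            lookup[pattern] = (count == 3 || (count == 2 && alive)) ? ALIVE : DEAD;
        }
        return lookup;
    }

    // State bit of (x, y); from is the offset of the current generation.
    static int bit(byte[] bytes, int w, int from, int x, int y) {
        return bytes[y * w + x + from] & 1;
    }

    // Conditional form, as in lifePerIdx (interior cells only).
    static byte nextByRule(byte[] bytes, int w, int from, int x, int y) {
        int count = bit(bytes, w, from, x - 1, y - 1) + bit(bytes, w, from, x - 1, y)
                + bit(bytes, w, from, x - 1, y + 1) + bit(bytes, w, from, x, y - 1)
                + bit(bytes, w, from, x, y + 1) + bit(bytes, w, from, x + 1, y)
                + bit(bytes, w, from, x + 1, y - 1) + bit(bytes, w, from, x + 1, y + 1);
        boolean alive = bit(bytes, w, from, x, y) == 1;
        return (count == 3 || (count == 2 && alive)) ? ALIVE : DEAD;
    }

    // Branch-free form: pack the 3x3 neighbourhood into 9 bits and index the table.
    static byte nextByLookup(byte[] lookup, byte[] bytes, int w, int from, int x, int y) {
        int lookupIdx = bit(bytes, w, from, x - 1, y - 1) << 0
                | bit(bytes, w, from, x - 1, y) << 1
                | bit(bytes, w, from, x - 1, y + 1) << 2
                | bit(bytes, w, from, x, y - 1) << 3
                | bit(bytes, w, from, x, y) << 4
                | bit(bytes, w, from, x, y + 1) << 5
                | bit(bytes, w, from, x + 1, y) << 6
                | bit(bytes, w, from, x + 1, y - 1) << 7
                | bit(bytes, w, from, x + 1, y + 1) << 8;
        return lookup[lookupIdx];
    }
}
```

The whole table is only 512 bytes, and 140 of its entries are `ALIVE` (56 births with a dead centre and exactly 3 neighbours, plus 84 survivals with 2 or 3).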

So the task here is to process kernel code models and perform the appropriate analysis (to track each
primitive array back to its origin) and transformations (to map the array code back to buffer get/set calls).
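
The analysis and transformation can be sketched on a toy expression IR. This is deliberately not Babylon's code model API (`jdk.incubator.code`); the node types and the `lower` pass are hypothetical. The pass remembers which local is a view of which buffer, drops the view local, and rewrites each `array[index]` load or store into a buffer `get`/`set`. If a view escapes (for example, is copied into another local), it refuses rather than guessing, because the "just a view" property can no longer be proven.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Toy IR for illustration only; this is not the jdk.incubator.code model API.
final class ViewRewrite {
    sealed interface Expr permits Param, Local, Mul, ArrayView, ArrayLoad, BufGet {}
    record Param(String name) implements Expr {}                 // e.g. s32Arr, kc.x
    record Local(String name) implements Expr {}                 // read of a local variable
    record Mul(Expr l, Expr r) implements Expr {}
    record ArrayView(Expr buffer) implements Expr {}             // buffer.arrayView()
    record ArrayLoad(Expr array, Expr index) implements Expr {}  // array[index]
    record BufGet(Expr buffer, Expr index) implements Expr {}    // buffer.array(index)

    sealed interface Stmt permits Let, ArrayStore, BufSet {}
    record Let(String name, Expr value) implements Stmt {}                   // var name = value;
    record ArrayStore(Expr array, Expr index, Expr value) implements Stmt {} // array[index] = value;
    record BufSet(Expr buffer, Expr index, Expr value) implements Stmt {}    // buffer.array(index, value);

    // Drops view locals and rewrites every access through a view into a buffer accessor call.
    static List<Stmt> lower(List<Stmt> body) {
        Map<String, Expr> viewOrigin = new HashMap<>(); // local name -> buffer it is a view of
        List<Stmt> out = new ArrayList<>();
        for (Stmt s : body) {
            if (s instanceof Let view && view.value() instanceof ArrayView av) {
                viewOrigin.put(view.name(), av.buffer()); // record the origin, emit nothing
            } else if (s instanceof Let other) {
                out.add(new Let(other.name(), lowerExpr(other.value(), viewOrigin)));
            } else if (s instanceof ArrayStore st) {
                out.add(new BufSet(origin(st.array(), viewOrigin),
                        lowerExpr(st.index(), viewOrigin), lowerExpr(st.value(), viewOrigin)));
            } else {
                out.add(s); // BufSet is already in lowered form
            }
        }
        return out;
    }

    static Expr lowerExpr(Expr e, Map<String, Expr> viewOrigin) {
        if (e instanceof ArrayLoad ld) {
            return new BufGet(origin(ld.array(), viewOrigin), lowerExpr(ld.index(), viewOrigin));
        }
        if (e instanceof Mul m) {
            return new Mul(lowerExpr(m.l(), viewOrigin), lowerExpr(m.r(), viewOrigin));
        }
        if (e instanceof ArrayView || (e instanceof Local l && viewOrigin.containsKey(l.name()))) {
            // The view is used as a value (aliased, passed on, ...): we cannot prove it is only a view.
            throw new IllegalStateException("array view escapes: " + e);
        }
        return e;
    }

    // Traces an array-valued expression back to the buffer it views.
    static Expr origin(Expr array, Map<String, Expr> viewOrigin) {
        if (array instanceof Local l && viewOrigin.containsKey(l.name())) {
            return viewOrigin.get(l.name());
        }
        if (array instanceof ArrayView av) {
            return av.buffer();
        }
        throw new IllegalStateException("cannot prove " + array + " is a buffer view");
    }
}
```

Desugaring `arr[kc.x] *= arr[kc.x]` into `ArrayStore(arr, kc.x, Mul(ArrayLoad(arr, kc.x), ArrayLoad(arr, kc.x)))` and lowering it yields `BufSet(s32Arr, kc.x, Mul(BufGet(s32Arr, kc.x), BufGet(s32Arr, kc.x)))`, which is exactly the original accessor-style kernel.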

----------
The `arrayView` trick actually opens up other possibilities.

Let's look at the current NBody code.

```java
@Reflect
static public void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL version, which uses `float4`:

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); // OpenCL vector literal
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface-mapped segments we can approximate a `float4` vector type. In fact, the
existing `Universe` interface-mapped segment already embeds a `float6`-like object holding the
`x, y, z, vx, vy, vz` values (position and velocity) for each body.

What if we created generalized `float4` interface-mapped views `(x, y, z, w)` as placeholders for true
vector types, and modified `Universe` to hold `pos` and `vel` arrays of `float4`?

In Java we would then pass MemorySegment-backed `F32x4Arr` buffers.
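
As a rough sketch of such a placeholder, here is a hypothetical `float4` record with static `of`/`add`/`sub`/`mul`/`div` helpers, plus `F32x4ArrStandIn`, a `float[]`-backed stand-in for `F32x4Arr` in which element `i` occupies floats `4*i` to `4*i + 3`. None of these are HAT types, and the lowercase record name only mirrors the OpenCL spelling (legal, if unconventional, Java).

```java
// Hypothetical placeholders, not HAT types: float4 mirrors OpenCL's vector type and
// F32x4ArrStandIn stores element i in floats 4*i .. 4*i+3 of a plain float[].
final class Float4Sketch {
    record float4(float x, float y, float z, float w) {
        static float4 of(float x, float y, float z, float w) {
            return new float4(x, y, z, w);
        }

        static float4 add(float4 a, float4 b) {
            return new float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
        }

        static float4 sub(float4 a, float4 b) {
            return new float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
        }

        static float4 mul(float4 a, float s) { // vector * scalar
            return new float4(a.x * s, a.y * s, a.z * s, a.w * s);
        }

        static float4 div(float4 a, float s) { // vector / scalar
            return new float4(a.x / s, a.y / s, a.z / s, a.w / s);
        }
    }

    static final class F32x4ArrStandIn {
        private final float[] data; // x, y, z, w of each element, back to back

        F32x4ArrStandIn(int length) {
            data = new float[length * 4];
        }

        int length() {
            return data.length / 4;
        }

        float4 get(int i) {
            int o = 4 * i;
            return new float4(data[o], data[o + 1], data[o + 2], data[o + 3]);
        }

        void set(int i, float4 v) {
            int o = 4 * i;
            data[o] = v.x();
            data[o + 1] = v.y();
            data[o + 2] = v.z();
            data[o + 3] = v.w();
        }
    }
}
```

A real interface-mapped view would read and write the MemorySegment directly instead of allocating a new `float4` per `get`, but the arithmetic helpers would look much the same.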

Our Java code could then start to approximate the OpenCL code.

Except for our lack of operator overloading...

```java
@Reflect
static void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0f, 0f, 0f, 0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0 / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // acc + s * delta, via a vector * scalar mul
    }
    acc = float4.mul(acc, delT); // vector * scalar
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos; // writes through the view, mapped back to a set on xyzPos
    xyzVelArr[kc.x] = myVel;
}
```
The code is more compact, but it is still awkward, because we can't overload operators.

Well, we sort of can.

What if we allowed a `floatView()` call on `float4` ;) which yields a `float` value that acts as a proxy for the `float4` it was fetched from...

We would leverage the anonymity of `var`:

`var myVelF4 = myVel.floatView(); // psst, var is really a float`

From the code model we can track the relationship between each float view and its original vector...

Any operation we perform on the `float` view is mapped back to calls on its origin, and performed on the actual `float4`.

So `myVelF4 + myVelF4` becomes `float4.add(myVel, myVel)` behind the scenes.
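
The mapping can be sketched as a small rewrite over expression trees. Everything here is hypothetical: a `View` node stands for a `float` that the code model has tied to a named `float4`, a `Scalar` node is an ordinary `float`, and the pass renders the `float4.add`/`sub`/`mul`/`div` calls it would emit as strings. Arithmetic between two scalars stays scalar; `scalar * view` becomes `float4.mul(view, scalar)`; `view * view` is rejected, since it is ambiguous between an element-wise and a dot product.

```java
// Hypothetical sketch: lowers arithmetic written on float views to float4 calls.
// The float4.add/sub/mul/div names are assumptions; calls are rendered as strings.
final class FloatViewLowering {
    sealed interface Expr permits View, Scalar, Bin {}
    record View(String origin) implements Expr {}          // float view, e.g. myVelF4 -> "myVel"
    record Scalar(String name) implements Expr {}          // ordinary float, e.g. delT or s
    record Bin(char op, Expr l, Expr r) implements Expr {} // l op r, op one of + - * /

    // An expression is vector-valued if any of its operands is a view.
    static boolean isVector(Expr e) {
        if (e instanceof View) {
            return true;
        }
        return e instanceof Bin b && (isVector(b.l()) || isVector(b.r()));
    }

    static String lower(Expr e) {
        if (e instanceof View v) {
            return v.origin();
        }
        if (e instanceof Scalar s) {
            return s.name();
        }
        Bin b = (Bin) e;
        String l = lower(b.l());
        String r = lower(b.r());
        boolean lv = isVector(b.l());
        boolean rv = isVector(b.r());
        if (!lv && !rv) {
            return "(" + l + " " + b.op() + " " + r + ")"; // plain float arithmetic is untouched
        }
        switch (b.op()) {
            case '+':
                if (lv && rv) return "float4.add(" + l + ", " + r + ")";
                break;
            case '-':
                if (lv && rv) return "float4.sub(" + l + ", " + r + ")";
                break;
            case '*':
                if (lv && !rv) return "float4.mul(" + l + ", " + r + ")"; // vector * scalar
                if (!lv && rv) return "float4.mul(" + r + ", " + l + ")"; // scalar * vector
                break;
            case '/':
                if (lv && !rv) return "float4.div(" + l + ", " + r + ")"; // vector / scalar
                break;
            default:
                break;
        }
        throw new IllegalArgumentException("no float4 mapping for " + b);
    }
}
```

For the position update, `myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2` lowers to `float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.div(float4.mul(acc, delT), 2))`, and `accF4 += s * deltaF4` (desugared to `accF4 = accF4 + s * deltaF4`) lowers to `float4.add(acc, float4.mul(delta, s))`.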

Yeah, it's a bit crazy, and perhaps a 'bridge too far'. The code would look something like this:

```java
@Reflect
static void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0f, 0f, 0f, 0f);
    var accF4 = acc.floatView(); // var is really a float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;) var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); // ... tied to myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        float4 delta = float4.of(0f, 0f, 0f, 0f); // the components are needed below
        var deltaF4 = delta.floatView();
        deltaF4 = bodyF4 - myPosF4; // look ma, operator overloading ;) (updates delta)
        float invDist = (float) (1.0 / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // float4.add(acc, float4.mul(delta, s)) behind the scenes
    }
    accF4 = accF4 * delT; // vector * scalar
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos; // myPos and myVel were updated through their float views
    xyzVelArr[kc.x] = myVel;
}
```