## Leverage CodeReflection to expose array views of buffers in kernels

Here is the canonical HAT example:

```java
import jdk.incubator.code.Reflect;

@Reflect
public class Square {
    @Reflect
    public static void kernel(KernelContext kc, S32Arr s32Arr) {
        s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
    }

    @Reflect
    public static void compute(ComputeContext cc, S32Arr s32Arr) {
        cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
    }
}
```

This code in the kernel has always bothered me. One downside of using MemorySegment-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
}
```

But what if we added a method, `int[] arrayView()`, to `S32Arr` that exposes a Java `int[]` view of simple arrays?
The kernel then becomes:

```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
}
```

IMHO this is much more readable.

For the GPU this is fine. Thanks to code reflection we can prove that the array is indeed just a view,
remove all references to `int[] arr`, and replace the array accesses with get/set accessor calls
on the original `S32Arr`.
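
The rewrite described above is mechanical: each load from the view becomes a getter call on the buffer and each store becomes a setter call. As a quick check that the two forms agree, here is a plain-Java sketch. `S32ArrStandIn` is a hypothetical stand-in for `S32Arr` (a direct `IntBuffer`, not HAT's MemorySegment-backed implementation) whose `array(idx)` / `array(idx, value)` accessors mirror the kernel above, and a plain loop stands in for `dispatchKernel`.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;

// Hypothetical stand-in for S32Arr: an off-heap int buffer with get/set accessors.
class S32ArrStandIn {
    private final IntBuffer buf;

    S32ArrStandIn(int length) {
        buf = ByteBuffer.allocateDirect(length * Integer.BYTES)
                .order(ByteOrder.nativeOrder())
                .asIntBuffer();
    }

    int length() { return buf.capacity(); }
    int array(long idx) { return buf.get((int) idx); }
    void array(long idx, int value) { buf.put((int) idx, value); }
}

class ArrayViewRewriteCheck {
    // What the author writes: loads and stores on an int[] view.
    static void kernelWithView(int x, int[] arr) {
        arr[x] *= arr[x];
    }

    // What the transformation produces: the view is gone, loads -> get, stores -> set.
    static void kernelRewritten(int x, S32ArrStandIn s32Arr) {
        s32Arr.array(x, s32Arr.array(x) * s32Arr.array(x));
    }

    // Runs both forms over the same input (a loop stands in for dispatchKernel)
    // and reports whether they produce identical results.
    static boolean sameResult(int[] input) {
        int[] arr = input.clone();
        S32ArrStandIn s32Arr = new S32ArrStandIn(input.length);
        for (int i = 0; i < input.length; i++) {
            s32Arr.array(i, input[i]);
        }
        for (int x = 0; x < input.length; x++) {
            kernelWithView(x, arr);
            kernelRewritten(x, s32Arr);
        }
        for (int i = 0; i < input.length; i++) {
            if (arr[i] != s32Arr.array(i)) {
                return false;
            }
        }
        return true;
    }
}
```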

But what about Java performance? Won't it suck, because we are copying the array in each kernel? ;)

Well, we can use the same trick we used for the GPU: take the transformed model (with the array
references removed), generate bytecode from it, and run that.

Historically, the Java MT/Seq backends have just run the original bytecode, but they don't have to.
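
To make the concern concrete: if `arrayView()` were implemented naively as a copy, every work item would copy the whole buffer, and the kernel's store would land in the copy rather than in the buffer. The sketch below uses a hypothetical `CopyingBuf` class (not a HAT API) that counts the ints it copies.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;

// Hypothetical buffer whose arrayView() copies instead of aliasing (the naive approach).
class CopyingBuf {
    final IntBuffer data;
    long copiedInts = 0; // total ints copied by arrayView()

    CopyingBuf(int length) {
        data = ByteBuffer.allocateDirect(length * Integer.BYTES)
                .order(ByteOrder.nativeOrder())
                .asIntBuffer();
    }

    int[] arrayView() {
        int[] copy = new int[data.capacity()];
        for (int i = 0; i < copy.length; i++) {
            copy[i] = data.get(i);
        }
        copiedInts += copy.length;
        return copy;
    }
}

class CopyingViewPitfall {
    // The readable kernel, run as-is (untransformed) on a copying view.
    static void kernel(int x, CopyingBuf buf) {
        int[] arr = buf.arrayView(); // O(length) copy per work item
        arr[x] *= arr[x];            // the store goes to the copy; the buffer never sees it
    }

    // Sequential stand-in for a dispatch over every element.
    static void dispatch(CopyingBuf buf) {
        for (int x = 0; x < buf.data.capacity(); x++) {
            kernel(x, buf);
        }
    }
}
```

With n work items that is n² ints copied and no results written back, which is why running bytecode generated from the transformed model is the more attractive option for the Java backends.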

This also helps with Game of Life:

```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function, `val(grid, from, w, x, y)`, to read each of the eight neighbours:

```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` itself is a bit verbose:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a `byteView()` method on `CellGrid` (the `byte[]` counterpart of `arrayView()`), it becomes:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // a bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` altogether and just write:

```java
byte[] bytes = cellGrid.byteView();
// Each masked term needs parentheses: in Java, + binds tighter than &.
int count =
        (bytes[(y - 1) * w + x - 1 + from] & 1)
      + (bytes[(y + 0) * w + x - 1 + from] & 1)
      + (bytes[(y + 1) * w + x - 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 0) * w + x + 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 1 + from] & 1)
      + (bytes[(y + 1) * w + x + 1 + from] & 1);
```
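
A plain-Java sketch of the whole per-cell step written in this array style, with an ordinary `byte[]` standing in for `cellGrid.byteView()` and the `from`/`to` generation offsets passed in as parameters (in HAT they come from `Control`). Each masked term is parenthesized because `+` binds tighter than `&` in Java. The `ALIVE`/`DEAD` values are assumptions; the code only relies on `ALIVE` having bit 0 set and `DEAD` not.

```java
// Plain-Java sketch of lifePerIdx() over a byte[] standing in for cellGrid.byteView().
// The array holds two w*h generations: the current one at offset `from`, the next written at `to`.
class LifeByteView {
    static final byte ALIVE = 1; // assumed values: the code only relies on bit 0
    static final byte DEAD = 0;

    static void lifePerIdx(int idx, int w, int h, int from, int to, byte[] bytes) {
        int x = idx % w;
        int y = idx / w;
        byte cell = bytes[idx + from];
        if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // border cells are just copied
            int count =
                    (bytes[(y - 1) * w + x - 1 + from] & 1)
                  + (bytes[(y + 0) * w + x - 1 + from] & 1)
                  + (bytes[(y + 1) * w + x - 1 + from] & 1)
                  + (bytes[(y - 1) * w + x + 0 + from] & 1)
                  + (bytes[(y + 1) * w + x + 0 + from] & 1)
                  + (bytes[(y + 0) * w + x + 1 + from] & 1)
                  + (bytes[(y - 1) * w + x + 1 + from] & 1)
                  + (bytes[(y + 1) * w + x + 1 + from] & 1);
            cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23
        }
        bytes[idx + to] = cell;
    }
}
```

For example, a horizontal blinker in the middle of a 5x5 grid turns into a vertical one after one step over all 25 indices.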

BTW, my inner Verilog programmer has always wondered whether shifting and OR-ing each of the nine cells' bits
into a 9-bit value, and using that to index the live/dead result in a prepopulated 512-entry table
(GPU constant memory), would let us sidestep the wave-divergent conditional :)

```java
// idx, to, from, w, x and y as in lifePerIdx(); lookup is the prepopulated 512-entry byte[] table.
byte[] bytes = cellGrid.byteView();
// Parentheses matter: in Java, << binds tighter than &, and & binds tighter than |.
int lookupIdx =
        ((bytes[(y - 1) * w + x - 1 + from] & 1) << 0)
      | ((bytes[(y + 0) * w + x - 1 + from] & 1) << 1)
      | ((bytes[(y + 1) * w + x - 1 + from] & 1) << 2)
      | ((bytes[(y - 1) * w + x + 0 + from] & 1) << 3)
      | ((bytes[(y + 0) * w + x + 0 + from] & 1) << 4) // current cell added
      | ((bytes[(y + 1) * w + x + 0 + from] & 1) << 5)
      | ((bytes[(y + 0) * w + x + 1 + from] & 1) << 6)
      | ((bytes[(y - 1) * w + x + 1 + from] & 1) << 7)
      | ((bytes[(y + 1) * w + x + 1 + from] & 1) << 8);
// B3/S23 conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```
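
A sketch of how such a table could be precomputed and used, in plain Java over a `byte[]` standing in for `byteView()`. The index packs the nine cells in the same order as the snippet above, with bit 4 holding the cell itself, so each entry depends only on that bit and on how many of the other eight bits are set. The `ALIVE`/`DEAD` values and the `LifeLookup` names are assumptions for illustration.

```java
// Sketch of the 512-entry lookup-table variant (assumed ALIVE/DEAD values; only bit 0 matters).
class LifeLookup {
    static final byte ALIVE = 1;
    static final byte DEAD = 0;
    static final int SELF = 1 << 4; // bit 4 holds the cell itself

    // Precompute the B3/S23 outcome for every 9-bit neighbourhood pattern.
    static byte[] buildLookup() {
        byte[] lookup = new byte[512];
        for (int pattern = 0; pattern < 512; pattern++) {
            boolean alive = (pattern & SELF) != 0;
            int count = Integer.bitCount(pattern & ~SELF); // the eight neighbours
            lookup[pattern] = ((count == 3) || (count == 2 && alive)) ? ALIVE : DEAD;
        }
        return lookup;
    }

    // Next state of an interior cell (x, y) of a w-wide grid stored at offset `from`,
    // with no data-dependent conditional: just a packed index and one table read.
    static byte next(byte[] bytes, byte[] lookup, int w, int x, int y, int from) {
        int lookupIdx =
                  ((bytes[(y - 1) * w + x - 1 + from] & 1) << 0)
                | ((bytes[(y + 0) * w + x - 1 + from] & 1) << 1)
                | ((bytes[(y + 1) * w + x - 1 + from] & 1) << 2)
                | ((bytes[(y - 1) * w + x + 0 + from] & 1) << 3)
                | ((bytes[(y + 0) * w + x + 0 + from] & 1) << 4)
                | ((bytes[(y + 1) * w + x + 0 + from] & 1) << 5)
                | ((bytes[(y + 0) * w + x + 1 + from] & 1) << 6)
                | ((bytes[(y - 1) * w + x + 1 + from] & 1) << 7)
                | ((bytes[(y + 1) * w + x + 1 + from] & 1) << 8);
        return lookup[lookupIdx];
    }
}
```

The table is built once on the host; on the GPU it would be the read-only array placed in constant memory.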

So the task here is to process kernel code models, perform the analysis needed to track where each
primitive array originates, and apply the transformations that map the array code back to buffer get/set calls.
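
To illustrate the shape of that analysis and transformation (and only the shape), here is a toy pass over a made-up, flat instruction list (Java 16+ for records and `instanceof` patterns). The nested record types are hypothetical and have nothing to do with the `jdk.incubator.code` API. The pass records which buffer each view came from, drops the view op, and turns loads and stores on tracked views into buffer get/set ops.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class ArrayViewPass {
    // Toy IR for illustration only (not the jdk.incubator.code op model).
    interface Op {}
    record ArrayView(String result, String buffer) implements Op {}               // result = buffer.arrayView()
    record ArrayLoad(String result, String array, String index) implements Op {}  // result = array[index]
    record ArrayStore(String array, String index, String value) implements Op {}  // array[index] = value
    record BufferGet(String result, String buffer, String index) implements Op {} // result = buffer.array(index)
    record BufferSet(String buffer, String index, String value) implements Op {}  // buffer.array(index, value)
    record Mul(String result, String a, String b) implements Op {}                // result = a * b

    static List<Op> rewrite(List<Op> ops) {
        Map<String, String> origin = new HashMap<>(); // view value -> buffer it was taken from
        List<Op> out = new ArrayList<>();
        for (Op op : ops) {
            if (op instanceof ArrayView v) {
                origin.put(v.result(), v.buffer()); // track the origin; the view op itself disappears
            } else if (op instanceof ArrayLoad l && origin.containsKey(l.array())) {
                out.add(new BufferGet(l.result(), origin.get(l.array()), l.index()));
            } else if (op instanceof ArrayStore s && origin.containsKey(s.array())) {
                out.add(new BufferSet(origin.get(s.array()), s.index(), s.value()));
            } else {
                out.add(op); // anything else is left untouched
            }
        }
        return out;
    }
}
```

Feeding it the ops for `int[] arr = s32Arr.arrayView(); arr[x] *= arr[x];` (two loads, a multiply, a store) yields two `BufferGet`s, the `Mul`, and a `BufferSet` on `s32Arr`. A real pass would also have to deal with views that escape (passed to another method, stored, or returned), for example by rejecting the kernel; this sketch does not attempt that.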

----------

The arrayView trick actually opens up other possibilities.

Let's look at the current NBody code.

```java
@Reflect
public static void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL version, which uses `float4`:

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2.0f;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface-mapped segments we can approximate a `float4` vector type. In fact, the
existing `Universe` interface-mapped segment already embeds a `float6`-like
object holding the `x, y, z, vx, vy, vz` values (position and velocity) for each body.

What if we created generalized `float4` interface-mapped views `(x, y, z, w)` as placeholders for true vector types,
and modified `Universe` to hold `pos` and `vel` arrays of `float4`?

In Java we would then pass a MemorySegment-backed `F32x4Arr`.

Our Java code could then start to approximate the OpenCL code.

Except for our lack of operator overloading...

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0f, 0f, 0f, 0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale delta by the scalar s, then add
    }
    acc = float4.mul(acc, delT); // vector * scalar
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos; // stores through the views, mapped back to xyzPos/xyzVel
    xyzVelArr[kc.x] = myVel;
}
```
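
To sanity-check the `float4.of/add/sub/mul` style outside HAT, here is a plain-Java sketch (Java 16+) with a hypothetical `Float4` record standing in for the proposed `float4` type and ordinary `Float4[]` arrays standing in for the `float4View()` arrays. `nbody` runs one work item, `gid`, of the kernel above.

```java
// Hypothetical stand-in for the proposed float4 type (not a HAT API).
record Float4(float x, float y, float z, float w) {
    static Float4 of(float x, float y, float z, float w) { return new Float4(x, y, z, w); }
    static Float4 add(Float4 a, Float4 b) { return new Float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); }
    static Float4 sub(Float4 a, Float4 b) { return new Float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); }
    static Float4 mul(Float4 a, float s) { return new Float4(a.x * s, a.y * s, a.z * s, a.w * s); }
}

class Float4NBody {
    // One work item of the float4-style nbody kernel, in plain Java.
    static void nbody(int gid, Float4[] pos, Float4[] vel, float mass, float delT, float espSqr) {
        Float4 acc = Float4.of(0f, 0f, 0f, 0f);
        Float4 myPos = pos[gid];
        Float4 myVel = vel[gid];
        for (int i = 0; i < pos.length; i++) {
            Float4 delta = Float4.sub(pos[i], myPos); // zero for i == gid, so no self-force
            float invDist = (float) (1.0f / Math.sqrt(
                    delta.x() * delta.x() + delta.y() * delta.y() + delta.z() * delta.z() + espSqr));
            float s = mass * invDist * invDist * invDist;
            acc = Float4.add(acc, Float4.mul(delta, s));
        }
        acc = Float4.mul(acc, delT);
        myPos = Float4.add(Float4.add(myPos, Float4.mul(myVel, delT)), Float4.mul(acc, delT / 2));
        myVel = Float4.add(myVel, acc);
        pos[gid] = myPos;
        vel[gid] = myVel;
    }
}
```

Because positions are written back in place (as in the OpenCL kernel), running all work items one after another in a loop would let later items see already-updated positions.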
The code is more compact, but it's still awkward, because we can't overload operators.

Well, we sort of can.

What if we allowed a `floatView()` call on `float4` ;) which yields a `float` value that acts as a proxy
for the `float4` we fetched it from...

We would leverage the anonymity of `var`:

`var myVelF4 = myVel.floatView(); // psst, var is really a float`

From the code model we can track the relationship between each float view and its original vector...

Any operation we perform on the 'float' view is mapped back to calls on the original `float4`, and performed on it.

So `myVelF4 + myVelF4` becomes `float4.add(myVel, myVel)` behind the scenes.
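
A toy version of that mapping, again over a made-up instruction list with hypothetical record types (not the `jdk.incubator.code` API; Java 16+). `floatView()` ops are dropped after noting which `float4` they stand for, and a `+` whose operands are both views becomes a `float4.add` on the underlying vectors. Only addition is shown; scalar-times-view and the other operators would follow the same pattern.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class FloatViewPass {
    // Toy IR for illustration only (hypothetical, not the jdk.incubator.code op model).
    interface Op {}
    record FloatView(String result, String vector) implements Op {} // result = vector.floatView()
    record FAdd(String result, String a, String b) implements Op {}  // result = a + b, written on floats
    record F4Add(String result, String a, String b) implements Op {} // result = float4.add(a, b)

    static List<Op> rewrite(List<Op> ops) {
        Map<String, String> vectorOf = new HashMap<>(); // float view -> float4 it stands for
        List<Op> out = new ArrayList<>();
        for (Op op : ops) {
            if (op instanceof FloatView v) {
                vectorOf.put(v.result(), v.vector()); // the view op itself disappears
            } else if (op instanceof FAdd add
                    && vectorOf.containsKey(add.a()) && vectorOf.containsKey(add.b())) {
                out.add(new F4Add(add.result(), vectorOf.get(add.a()), vectorOf.get(add.b())));
                // The result now names a float4 and can itself be used as a view of it.
                vectorOf.put(add.result(), add.result());
            } else {
                out.add(op); // plain float arithmetic is left alone
            }
        }
        return out;
    }
}
```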

Yeah, it's a bit crazy. The code would look something like this. Perhaps this is a 'bridge too far'.

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0f, 0f, 0f, 0f);
    var accF4 = acc.floatView(); // var is really a float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;) var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); // ... and this one to myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4; // look ma, operator overloading ;)
        // delta below denotes the float4 that deltaF4 stands in for
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, then vector add, via the proxies
    }
    accF4 = accF4 * delT; // vector * scalar
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```