## Leverage CodeReflection to expose array view of buffers in kernels
[Back to Index ../](../index.md)

Here is the canonical HAT example:

```java
import jdk.incubator.code.Reflect;

@Reflect
public class Square {
    @Reflect
    public static void kernel(KernelContext kc, S32Arr s32Arr) {
        s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
    }

    @Reflect
    public static void compute(ComputeContext cc, S32Arr s32Arr) {
        cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
    }
}
```

This kernel code has always bothered me. One downside of using `MemorySegment`-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
}
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that exposes a Java `int[]` *view* of the buffer?

The kernel then becomes:
```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
}
```
IMHO this makes the code more readable.

For the GPU this is fine. Thanks to CodeReflection we can prove that the array is indeed just a view,
so we simply remove all references to `int[] arr` and replace the array accesses with get/set accessors
on the original `S32Arr`.

But what about Java performance? Won't it suck because we are copying the array in each kernel? ;)

Well, we can use the same trick we used for the GPU: take the transformed model (with the array
references removed), generate bytecode from it, and run that.

Historically we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.

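To make the rewrite concrete, here is a minimal plain-Java sketch. It does not use HAT or CodeReflection: `Ints` is a hypothetical stand-in for `S32Arr` (backed by a direct `ByteBuffer` rather than a `MemorySegment`), and the "transformed" kernel is written by hand to show the shape of code the transformation is meant to produce. The stand-in's `arrayView()` is assumed to be a naive copy, which is exactly the case the rewrite has to avoid.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class ArrayViewSketch {
    /** Hypothetical stand-in for S32Arr: ints stored off-heap, accessed by index. */
    static final class Ints {
        private final ByteBuffer data;
        private final int length;

        Ints(int length) {
            this.length = length;
            this.data = ByteBuffer.allocateDirect(length * Integer.BYTES).order(ByteOrder.nativeOrder());
        }

        int length() { return length; }

        int array(long idx) { return data.getInt((int) idx * Integer.BYTES); }

        void array(long idx, int value) { data.putInt((int) idx * Integer.BYTES, value); }

        /** Assumption: a naive view that copies; writes to it never reach the buffer. */
        int[] arrayView() {
            int[] copy = new int[length];
            data.asIntBuffer().get(copy);
            return copy;
        }
    }

    /** Kernel as the programmer writes it, against the array view. */
    static void kernelAsWritten(int x, Ints ints) {
        int[] arr = ints.arrayView();
        arr[x] *= arr[x];
    }

    /** Hand-written equivalent of the transformed model: array loads/stores become get/set calls. */
    static void kernelTransformed(int x, Ints ints) {
        ints.array(x, ints.array(x) * ints.array(x));
    }

    /** Fills a buffer with 0, 1, 2, ... n-1. */
    static Ints iota(int n) {
        Ints ints = new Ints(n);
        for (int i = 0; i < n; i++) {
            ints.array(i, i);
        }
        return ints;
    }

    public static void main(String[] args) {
        Ints a = iota(8);
        Ints b = iota(8);
        for (int x = 0; x < 8; x++) {
            kernelAsWritten(x, a);
            kernelTransformed(x, b);
        }
        System.out.println("as written : a[3] = " + a.array(3));
        System.out.println("transformed: b[3] = " + b.array(3));
    }
}
```

In this stand-in the kernel as written never updates the buffer because the store lands in the copy, while the accessor form does. Running bytecode generated from the transformed model would sidestep both the copy and that mismatch.
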
This also helps with the Game of Life:
```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function `val(grid, from, w, x, y)` to extract the neighbours:
```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` is a bit verbose:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a `byteView()` on the grid it could become:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` and just write:

```java
byte[] bytes = cellGrid.byteView();
// '+' binds tighter than '&' in Java, so each term needs its own parentheses
int count =
          (bytes[(y - 1) * w + x - 1 + from] & 1)
        + (bytes[(y + 0) * w + x - 1 + from] & 1)
        + (bytes[(y + 1) * w + x - 1 + from] & 1)
        + (bytes[(y - 1) * w + x + 0 + from] & 1)
        + (bytes[(y + 1) * w + x + 0 + from] & 1)
        + (bytes[(y + 0) * w + x + 1 + from] & 1)
        + (bytes[(y - 1) * w + x + 1 + from] & 1)
        + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

BTW my inner Verilog programmer has always wondered whether shifting and OR-ing each bit into a
9-bit value, which we then use to index the live/dead state in a prepopulated 512-entry array (GPU constant memory),
would allow us to sidestep the wave-divergent conditional :)

```java
byte[] bytes = cellGrid.byteView();
int idx = 0;                   // placeholder values so the fragment stands alone
int to = 0;
byte[] lookup = new byte[512]; // placeholder for the prepopulated 512-entry table
// '<<' binds tighter than '&', so mask first, then shift, then OR
int lookupIdx =
          ((bytes[(y - 1) * w + x - 1 + from] & 1) << 0)
        | ((bytes[(y + 0) * w + x - 1 + from] & 1) << 1)
        | ((bytes[(y + 1) * w + x - 1 + from] & 1) << 2)
        | ((bytes[(y - 1) * w + x + 0 + from] & 1) << 3)
        | ((bytes[(y + 0) * w + x + 0 + from] & 1) << 4) // current cell added
        | ((bytes[(y + 1) * w + x + 0 + from] & 1) << 5)
        | ((bytes[(y + 0) * w + x + 1 + from] & 1) << 6)
        | ((bytes[(y - 1) * w + x + 1 + from] & 1) << 7)
        | ((bytes[(y + 1) * w + x + 1 + from] & 1) << 8);
// conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```

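The lookup table left as a placeholder in the fragment above could be prepopulated along these lines. This is a sketch, not HAT code: the class name, the `ALIVE`/`DEAD` encodings (chosen so that `& 1` recovers the state) and the helper `buildLookup()` are assumptions. The bit layout follows the fragment above, with bit 4 holding the current cell and the other eight bits holding its neighbours.

```java
final class LifeLookupSketch {
    static final byte ALIVE = 1;   // assumed encoding: low bit set
    static final byte DEAD = 0;    // assumed encoding: low bit clear
    static final int CENTER_BIT = 4;

    /** Builds the 512-entry B3/S23 table indexed by the 9-bit neighbourhood value. */
    static byte[] buildLookup() {
        byte[] lookup = new byte[512];
        for (int bits = 0; bits < lookup.length; bits++) {
            boolean alive = ((bits >> CENTER_BIT) & 1) == 1;
            // Count the eight neighbour bits, ignoring the centre cell.
            int neighbours = Integer.bitCount(bits & ~(1 << CENTER_BIT));
            lookup[bits] = (neighbours == 3 || (neighbours == 2 && alive)) ? ALIVE : DEAD;
        }
        return lookup;
    }

    public static void main(String[] args) {
        byte[] lookup = buildLookup();
        int birth = 0b000000111;    // dead centre, three live neighbours
        int survive = 0b000010011;  // live centre, two live neighbours
        System.out.println("birth   -> " + lookup[birth]);
        System.out.println("survive -> " + lookup[survive]);
    }
}
```

Because the rule is baked into the table, every work item executes the same instruction sequence regardless of the neighbourhood it sees. Whether that actually beats the conditional on a given GPU would still need measuring.
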
So the task here is to process kernel code models, perform the appropriate analysis
(tracking the origin of each primitive array), and apply transformations that map the array code back to buffer gets/sets.

----------
The `arrayView` trick actually leads us to other possibilities.

Let's look at the current NBody code.

```java
@Reflect
static public void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL code using `float4`:

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface-mapped segments we can approximate a `float4` vector type. In fact the
existing `Universe` interface-mapped segment already embeds a `float6` kind of
object holding the `x,y,z,vx,vy,vz` values for each `body` (so position and velocity).

What if we created generalized `float4` interface-mapped views `(x,y,z,w)` as placeholders for true vector types,
and modified `Universe` to hold `pos` and `vel` `float4` arrays?

So in Java code we would pass a `MemorySegment`-backed `F32x4Arr`.

Our Java code could then start to approximate the OpenCL code.

Except for our lack of operator overloading...

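Without operator overloading, vector arithmetic has to go through static helper methods. A minimal sketch of such a `float4` placeholder as a plain Java record follows. It is purely illustrative: `Float4` and its `of`/`add`/`sub`/`mul` helpers are hypothetical names, and a real HAT version would presumably be an interface-mapped view over a `MemorySegment` rather than a heap object.

```java
final class Float4Sketch {
    /** Hypothetical immutable stand-in for a float4 vector; not a HAT type. */
    record Float4(float x, float y, float z, float w) {
        static Float4 of(float x, float y, float z, float w) {
            return new Float4(x, y, z, w);
        }

        /** Lane-wise a + b. */
        static Float4 add(Float4 a, Float4 b) {
            return new Float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
        }

        /** Lane-wise a - b. */
        static Float4 sub(Float4 a, Float4 b) {
            return new Float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
        }

        /** Every lane of a scaled by the scalar s. */
        static Float4 mul(Float4 a, float s) {
            return new Float4(a.x * s, a.y * s, a.z * s, a.w * s);
        }
    }

    public static void main(String[] args) {
        Float4 pos = Float4.of(1f, 2f, 3f, 0f);
        Float4 vel = Float4.of(0.5f, 0.5f, 0.5f, 0f);
        // pos + vel * dt, spelled out with static helpers
        Float4 next = Float4.add(pos, Float4.mul(vel, 2f));
        System.out.println(next);
    }
}
```

The kernel that follows assumes helpers of this general shape.
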
```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0 / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale delta by s, then accumulate
    }
    acc = float4.mul(acc, delT); // vector * scalar
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
The code is more compact, but it's still weird because we can't overload operators.

Well, we can, sort of.

What if we allowed a `floatView()` call on `float4` ;) which yields a float value to be used as a proxy
for the `float4` we fetched it from...

So we would leverage the anonymity of `var`:

`var myVelF4 = myVel.floatView()` // psst, the var is really a float

From the code model we can track the relationship from float views back to the original vector.

Any operation we perform on the 'float' view is mapped back to a call on its origin and performed on the actual `float4`.

So
`myVelF4 + myVelF4` -> `float4.add(myVel, myVel)` behind the scenes.

Yeah, it's a bit crazy, and perhaps this is a 'bridge too far'. The code would look something like this.

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    var accF4 = acc.floatView(); // var is really float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;) var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); // ... tied to myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4; // look ma operator overloading ;)
        // `delta` stands for the float4 behind deltaF4; how its lanes are reached is left open here
        float invDist = (float) (1.0 / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, then accumulate
    }
    accF4 = accF4 * delT; // vector * scalar
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```