## Leverage CodeReflection to expose array view of buffers in kernels

----

* [Contents](hat-00.md)
* House Keeping
    * [Project Layout](hat-01-01-project-layout.md)
    * [Building Babylon](hat-01-02-building-babylon.md)
    * [Building HAT](hat-01-03-building-hat.md)
    * [Enabling the CUDA Backend](hat-01-05-building-hat-for-cuda.md)
* Programming Model
    * [Programming Model](hat-03-programming-model.md)
* Interface Mapping
    * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
    * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Implementation Detail
    * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
    * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)

----

Here is the canonical HAT example:

```java
import jdk.incubator.code.CodeReflection;

@CodeReflection
public class Square {
    @CodeReflection
    public static void kernel(KernelContext kc, S32Arr s32Arr) {
        s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
    }

    @CodeReflection
    public static void compute(ComputeContext cc, S32Arr s32Arr) {
        cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
    }
}
```

This code in the kernel has always bothered me. One downside of using `MemorySegment`-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
@CodeReflection
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
}
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that exposes the buffer as a plain Java `int[]` view?

The kernel would then become:
```java
@CodeReflection
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
}
```
IMHO this makes the code more readable.

For the GPU this is fine. Thanks to CodeReflection we can prove that the array is indeed just a view,
so we remove all references to `int[] arr` and replace the array accesses with get/set accessor calls
on the original `S32Arr`.

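A toy sketch of that rewrite may help. It assumes nothing about Babylon's real code model API: the op records (`ViewOf`, `ArrayLoad`, `ArrayStore`, `BufferGet`, `BufferSet`) and the string-named values are hypothetical stand-ins invented for illustration. The pass remembers which buffer each array view came from, drops the view op itself, and maps every load or store on a tracked array onto a get or set on the originating buffer.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class ViewRewriteSketch {
    // Toy ops; none of these are Babylon types.
    interface Op {}
    record ViewOf(String array, String buffer) implements Op {}                  // int[] array = buffer.arrayView()
    record ArrayLoad(String result, String array, String index) implements Op {} // result = array[index]
    record ArrayStore(String array, String index, String value) implements Op {} // array[index] = value
    record BufferGet(String result, String buffer, String index) implements Op {} // result = buffer.array(index)
    record BufferSet(String buffer, String index, String value) implements Op {}  // buffer.array(index, value)

    /** Replaces accesses through tracked array views with accessor ops on the origin buffer. */
    static List<Op> rewrite(List<Op> ops) {
        Map<String, String> origin = new HashMap<>(); // array name -> buffer it is a view of
        List<Op> out = new ArrayList<>();
        for (Op op : ops) {
            if (op instanceof ViewOf v) {
                origin.put(v.array(), v.buffer());    // record the origin, emit nothing
            } else if (op instanceof ArrayLoad l && origin.containsKey(l.array())) {
                out.add(new BufferGet(l.result(), origin.get(l.array()), l.index()));
            } else if (op instanceof ArrayStore s && origin.containsKey(s.array())) {
                out.add(new BufferSet(origin.get(s.array()), s.index(), s.value()));
            } else {
                out.add(op);                          // untouched: not an access through a view
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<Op> kernel = List.<Op>of(
                new ViewOf("arr", "s32Arr"),
                new ArrayLoad("t", "arr", "kc.x"),
                new ArrayStore("arr", "kc.x", "t * t"));
        rewrite(kernel).forEach(System.out::println);
    }
}
```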
But what about Java performance? Won't it suck because we are copying the array in each kernel? ;)

Well, we can use the same trick we used for the GPU: we take the transformed model (with the array
references removed), generate bytecode from it, and run that.

Historically, we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.

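To make the copying concern concrete, here is a minimal plain-Java sketch. `MiniS32Arr` and `copyingArrayView()` are hypothetical stand-ins (a `ByteBuffer`-backed buffer and a deliberately naive view), not HAT types or HAT's actual behaviour. Run literally, the view-based kernel pays a full copy per invocation and its write never reaches the buffer, whereas the accessor form that the transformed model corresponds to reads and writes the buffer directly.

```java
import java.nio.ByteBuffer;

final class CopyVsAccessorSketch {
    /** Hypothetical stand-in for an S32Arr-like buffer; not HAT's real type. */
    static final class MiniS32Arr {
        private final ByteBuffer buf;
        private final int length;

        MiniS32Arr(int... values) {
            length = values.length;
            buf = ByteBuffer.allocate(length * Integer.BYTES);
            for (int i = 0; i < length; i++) {
                array(i, values[i]);
            }
        }

        int length() { return length; }
        int array(int idx) { return buf.getInt(idx * Integer.BYTES); }
        void array(int idx, int value) { buf.putInt(idx * Integer.BYTES, value); }

        /** Naive view: copies every element, so writes to the result never reach the buffer. */
        int[] copyingArrayView() {
            int[] copy = new int[length];
            for (int i = 0; i < length; i++) {
                copy[i] = array(i);
            }
            return copy;
        }
    }

    /** The kernel as the programmer writes it, executed literally against the copying view. */
    static void kernelOnCopy(int x, MiniS32Arr a) {
        int[] arr = a.copyingArrayView(); // O(length) copy on every kernel invocation
        arr[x] *= arr[x];                 // updates the copy only
    }

    /** The same kernel after array accesses have been mapped back to accessors. */
    static void kernelOnAccessors(int x, MiniS32Arr a) {
        a.array(x, a.array(x) * a.array(x));
    }

    public static void main(String[] args) {
        MiniS32Arr a = new MiniS32Arr(1, 2, 3);
        kernelOnCopy(1, a);
        System.out.println("after copy-based kernel: " + a.array(1));
        kernelOnAccessors(1, a);
        System.out.println("after accessor kernel:   " + a.array(1));
    }
}
```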
This also helps with the Game of Life example:
```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function, `val(grid, from, w, x, y)`, to extract the neighbours:
```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` is a bit verbose:

```java
@CodeReflection
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a `byteView()` on `CellGrid` it could index a plain `byte[]` instead:

```java
@CodeReflection
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` and just write:

```java
byte[] bytes = cellGrid.byteView();
// Parentheses are required: in Java `+` binds tighter than `&`.
int count =
          (bytes[(y - 1) * w + x - 1 + from] & 1)
        + (bytes[(y + 0) * w + x - 1 + from] & 1)
        + (bytes[(y + 1) * w + x - 1 + from] & 1)
        + (bytes[(y - 1) * w + x + 0 + from] & 1)
        + (bytes[(y + 1) * w + x + 0 + from] & 1)
        + (bytes[(y + 0) * w + x + 1 + from] & 1)
        + (bytes[(y - 1) * w + x + 1 + from] & 1)
        + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

BTW, my inner Verilog programmer has always wondered whether shifting and ORing each bit into a
9-bit value, which we then use to index the live/dead state in a prepopulated 512-entry array (GPU constant memory),
would allow us to sidestep the wave-divergent conditional :)

```java
byte[] bytes = cellGrid.byteView();
int idx = 0;                   // placeholder: index of the current cell
int to = 0;                    // placeholder: offset of the destination generation
byte[] lookup = new byte[512]; // placeholder: prepopulated live/dead table, one entry per 9-bit pattern
// Parentheses are required: `<<` binds tighter than `&`, which binds tighter than `|`.
int lookupIdx =
          (bytes[(y - 1) * w + x - 1 + from] & 1) << 0
        | (bytes[(y + 0) * w + x - 1 + from] & 1) << 1
        | (bytes[(y + 1) * w + x - 1 + from] & 1) << 2
        | (bytes[(y - 1) * w + x + 0 + from] & 1) << 3
        | (bytes[(y + 0) * w + x + 0 + from] & 1) << 4 // current cell added
        | (bytes[(y + 1) * w + x + 0 + from] & 1) << 5
        | (bytes[(y + 0) * w + x + 1 + from] & 1) << 6
        | (bytes[(y - 1) * w + x + 1 + from] & 1) << 7
        | (bytes[(y + 1) * w + x + 1 + from] & 1) << 8;
// conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```

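The snippet above leaves the table itself as a placeholder. A minimal plain-Java sketch of how it could be populated for B3/S23 follows, using the same bit order with the current cell at bit 4. The class and method names (`LifeLookupSketch`, `buildLookup`, `lookupIndex`, `next`) are hypothetical, and `ALIVE = 1` is an assumption chosen so that the `& 1` test works; this runs on a heap `byte[]` and says nothing about how HAT would place the table in GPU constant memory.

```java
final class LifeLookupSketch {
    static final byte ALIVE = 1;     // assumption: low bit set means alive
    static final byte DEAD = 0;
    static final int CENTER_BIT = 4; // the current cell sits at bit 4 of the 9-bit pattern

    /** Builds the 512-entry B3/S23 table: index = packed 3x3 neighbourhood, value = next state. */
    static byte[] buildLookup() {
        byte[] lookup = new byte[512];
        for (int pattern = 0; pattern < 512; pattern++) {
            boolean alive = ((pattern >> CENTER_BIT) & 1) == 1;
            int neighbours = Integer.bitCount(pattern & ~(1 << CENTER_BIT));
            lookup[pattern] = (neighbours == 3 || (neighbours == 2 && alive)) ? ALIVE : DEAD;
        }
        return lookup;
    }

    /** Packs the 3x3 neighbourhood of interior cell (x, y) into a 9-bit index. */
    static int lookupIndex(byte[] bytes, int from, int w, int x, int y) {
        return (bytes[(y - 1) * w + x - 1 + from] & 1) << 0
             | (bytes[(y + 0) * w + x - 1 + from] & 1) << 1
             | (bytes[(y + 1) * w + x - 1 + from] & 1) << 2
             | (bytes[(y - 1) * w + x + 0 + from] & 1) << 3
             | (bytes[(y + 0) * w + x + 0 + from] & 1) << 4 // current cell
             | (bytes[(y + 1) * w + x + 0 + from] & 1) << 5
             | (bytes[(y + 0) * w + x + 1 + from] & 1) << 6
             | (bytes[(y - 1) * w + x + 1 + from] & 1) << 7
             | (bytes[(y + 1) * w + x + 1 + from] & 1) << 8;
    }

    /** Branch-free next state of interior cell (x, y). */
    static byte next(byte[] lookup, byte[] bytes, int from, int w, int x, int y) {
        return lookup[lookupIndex(bytes, from, w, x, y)];
    }

    public static void main(String[] args) {
        byte[] lookup = buildLookup();
        byte[] grid = new byte[5 * 5];          // 5x5 grid holding a vertical blinker
        grid[1 * 5 + 2] = ALIVE;
        grid[2 * 5 + 2] = ALIVE;
        grid[3 * 5 + 2] = ALIVE;
        for (int y = 1; y < 4; y++) {           // interior cells only
            StringBuilder row = new StringBuilder();
            for (int x = 1; x < 4; x++) {
                row.append(next(lookup, grid, 0, 5, x, y) == ALIVE ? '#' : '.');
            }
            System.out.println(row);
        }
    }
}
```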
So the task here is to process kernel code models, perform the appropriate analysis
(tracking each primitive array's origin), and apply the transformations that map the array code back to buffer gets/sets.

----------
The `arrayView` trick actually leads us to other possibilities.

Let's look at the current NBody code.

```java
@CodeReflection
public static void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL code using `float4`:

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface-mapped segments we can approximate a `float4` vector type. In fact, the
existing `Universe` interface-mapped segment already embeds a `float6` kind of
object holding the `x,y,z,vx,vy,vz` values for each `body` (so position and velocity).

What if we created generalized `float4` interface-mapped views `(x,y,z,w)` as placeholders for true vector types,
and modified `Universe` to hold `pos` and `vel` `float4` arrays?

So in Java code we would pass a `MemorySegment`-backed `F32x4Arr`.

Our Java code could then start to approximate the OpenCL code, except for our lack of operator overloading...

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale delta by the scalar s, then accumulate
    }
    acc = float4.mul(acc, delT); // vector * scalar
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
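The `float4` helpers used here are part of the proposal and do not exist yet. As a rough plain-Java illustration of what `of`, `add`, `sub`, and `mul` would compute, the sketch below uses a hypothetical `Float4` record on the heap (not an interface-mapped view and not a HAT type) and applies one position/velocity update for a single body. The names `Float4Sketch` and `step` are invented for this example.

```java
final class Float4Sketch {
    /** Hypothetical heap stand-in for the proposed float4 view; not a HAT type. */
    record Float4(float x, float y, float z, float w) {
        static Float4 of(float x, float y, float z, float w) {
            return new Float4(x, y, z, w);
        }

        static Float4 add(Float4 a, Float4 b) {
            return new Float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
        }

        static Float4 sub(Float4 a, Float4 b) {
            return new Float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
        }

        /** Vector scaled by a scalar. */
        static Float4 mul(Float4 a, float s) {
            return new Float4(a.x * s, a.y * s, a.z * s, a.w * s);
        }
    }

    /** One update of body `me`, following the same arithmetic as the kernel (w lane carried along). */
    static void step(Float4[] pos, Float4[] vel, int me, float mass, float delT, float espSqr) {
        Float4 acc = Float4.of(0.0f, 0.0f, 0.0f, 0.0f);
        Float4 myPos = pos[me];
        Float4 myVel = vel[me];
        for (Float4 other : pos) {
            Float4 delta = Float4.sub(other, myPos);
            float invDist = (float) (1.0 / Math.sqrt(
                    (delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
            float s = mass * invDist * invDist * invDist;
            acc = Float4.add(acc, Float4.mul(delta, s));
        }
        acc = Float4.mul(acc, delT);
        pos[me] = Float4.add(Float4.add(myPos, Float4.mul(myVel, delT)), Float4.mul(acc, delT / 2));
        vel[me] = Float4.add(myVel, acc);
    }

    public static void main(String[] args) {
        Float4[] pos = {Float4.of(0f, 0f, 0f, 0f), Float4.of(1f, 0f, 0f, 0f)};
        Float4[] vel = {Float4.of(0f, 0f, 0f, 0f), Float4.of(0f, 0f, 0f, 0f)};
        step(pos, vel, 0, 1.0f, 0.1f, 0.01f);
        System.out.println(pos[0] + " " + vel[0]);
    }
}
```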
The code is more compact, but it is still weird because we can't overload operators.

Well, we can, sort of.

What if we allowed a `floatView()` call on `float4` ;) which yields a `float` value to be used as a proxy
for the `float4` we fetched it from?

We would leverage the anonymity of `var`:

`var myVelF4 = myVel.floatView()` // psst, the var is really a float

From the code model we can track the relationship from each float view back to its original vector.

Any action we perform on the `float` view is mapped back to a call on its origin and performed on the actual `float4`.

So
`myVelF4 + myVelF4` -> `float4.add(myVel, myVel)` behind the scenes.

Yeah, it's a bit crazy, and perhaps this is a 'bridge too far'. The code would look something like this:

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    var accF4 = acc.floatView(); // var is really float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;) var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); // ... tied to myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4; // look ma, operator overloading ;)
        // `delta` stands for the float4 behind deltaF4; how its lanes are read back is left open in this sketch
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, accumulated via the mapped add
    }
    accF4 = accF4 * delT; // scalar * vector
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```