## Leverage CodeReflection to expose array view of buffers in kernels

Here is the canonical HAT example:

```java
import jdk.incubator.code.Reflect;

@Reflect
public class Square {
    @Reflect
    public static void kernel(KernelContext kc, S32Arr s32Arr) {
        s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
    }

    @Reflect
    public static void compute(ComputeContext cc, S32Arr s32Arr) {
        cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
    }
}
```

The code in this kernel has always bothered me. One downside of using `MemorySegment`-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
}
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that returns a Java `int[]` view of the buffer?

The kernel would then look like this:
```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
}
```
IMHO this makes the code more readable.

For the GPU this is fine. We can (thanks to CodeReflection) prove that the array is indeed just a view,
so we simply remove all references to `int[] arr` and replace the array accesses with get/set accessors
on the original `S32Arr`.

But what about Java performance? Won't it suck, because we are copying the array in each kernel ;)

Well, we can use the same trick we used for the GPU: we take the transformed model (with the array
references removed), generate bytecode from it, and run that.

Historically, we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.

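To make the equivalence concrete, here is a minimal sketch that runs without HAT. `SketchS32Arr` is a hypothetical stand-in for `S32Arr` (it uses a `ByteBuffer` rather than a `MemorySegment`, purely so the example is self-contained), and `ArrayViewSketch` is an illustrative name, not part of the project. `kernelOnView` is the array code we would like to write; `kernelOnBuffer` is roughly what the transformed model would amount to, with no array and no copy.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical stand-in for S32Arr (not the HAT API): ints stored in a raw byte buffer.
final class SketchS32Arr {
    private final ByteBuffer buf;
    private final int length;

    SketchS32Arr(int length) {
        this.length = length;
        this.buf = ByteBuffer.allocate(length * Integer.BYTES).order(ByteOrder.nativeOrder());
    }

    int length() { return length; }

    int array(int idx) { return buf.getInt(idx * Integer.BYTES); }

    void array(int idx, int value) { buf.putInt(idx * Integer.BYTES, value); }
}

final class ArrayViewSketch {
    // What the programmer writes: plain array code against the "view".
    static void kernelOnView(int x, int[] arr) {
        arr[x] *= arr[x];
    }

    // What the transformed model amounts to: the same work via buffer accessors.
    static void kernelOnBuffer(int x, SketchS32Arr s32Arr) {
        s32Arr.array(x, s32Arr.array(x) * s32Arr.array(x));
    }

    // Runs the array-form kernel over a copy of the input.
    static int[] runOnView(int[] input) {
        int[] arr = input.clone();
        for (int x = 0; x < arr.length; x++) kernelOnView(x, arr);
        return arr;
    }

    // Loads the input into the buffer, runs the accessor-form kernel, reads the result back.
    static int[] runOnBuffer(int[] input) {
        SketchS32Arr s32Arr = new SketchS32Arr(input.length);
        for (int i = 0; i < input.length; i++) s32Arr.array(i, input[i]);
        for (int x = 0; x < s32Arr.length(); x++) kernelOnBuffer(x, s32Arr);
        int[] out = new int[input.length];
        for (int i = 0; i < out.length; i++) out[i] = s32Arr.array(i);
        return out;
    }
}
```

Both paths should produce the same squared values for any input; the point of the transformation is that only the accessor form ever needs to execute.
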
This also helps with the Game of Life example:
```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function `val(grid, offset, w, dx, dy)` to extract the neighbours:
```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` is a bit verbose:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a `byteView()` on `CellGrid` it could instead be written as:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` and just write:

```java
byte[] bytes = grid.byteView();
// Parentheses are required: in Java `+` binds tighter than `&`.
int count =
          (bytes[(y - 1) * w + x - 1 + from] & 1)
        + (bytes[(y + 0) * w + x - 1 + from] & 1)
        + (bytes[(y + 1) * w + x - 1 + from] & 1)
        + (bytes[(y - 1) * w + x + 0 + from] & 1)
        + (bytes[(y + 1) * w + x + 0 + from] & 1)
        + (bytes[(y + 0) * w + x + 1 + from] & 1)
        + (bytes[(y - 1) * w + x + 1 + from] & 1)
        + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

BTW my inner Verilog programmer has always wondered whether shifting and ORing each bit into a
9-bit value, which we then use to index the live/dead state in a prepopulated 512-entry array (GPU constant memory),
would allow us to sidestep the wave-divergent conditional :)

```java
byte[] bytes = grid.byteView();
int idx = 0;                   // placeholder for this work item's cell index
int to = 0;                    // placeholder for the destination offset
byte[] lookup = new byte[512]; // assumed prepopulated with the next state per 9-bit pattern
// Parentheses are required: in Java `<<` binds tighter than `&`.
int lookupIdx =
          (bytes[(y - 1) * w + x - 1 + from] & 1) << 0
        | (bytes[(y + 0) * w + x - 1 + from] & 1) << 1
        | (bytes[(y + 1) * w + x - 1 + from] & 1) << 2
        | (bytes[(y - 1) * w + x + 0 + from] & 1) << 3
        | (bytes[(y + 0) * w + x + 0 + from] & 1) << 4 // current cell added
        | (bytes[(y + 1) * w + x + 0 + from] & 1) << 5
        | (bytes[(y + 0) * w + x + 1 + from] & 1) << 6
        | (bytes[(y - 1) * w + x + 1 + from] & 1) << 7
        | (bytes[(y + 1) * w + x + 1 + from] & 1) << 8;
// conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```

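The prepopulated table itself is easy to build on the host. The sketch below is one possible way to do it, assuming the bit layout used above (bit 4 is the current cell, the other eight bits are its neighbours) and assuming `ALIVE` is `1` and `DEAD` is `0`, which may not match the constants the real example uses. `LifeLookupSketch` and `buildLookup` are illustrative names.

```java
// Sketch: build the 512-entry next-state table for Conway's rules (B3/S23).
final class LifeLookupSketch {
    static final byte ALIVE = 1;     // assumption: only the low bit matters
    static final byte DEAD = 0;
    static final int CENTRE_BIT = 4; // matches the "<< 4" used for the current cell

    static byte[] buildLookup() {
        byte[] lookup = new byte[512];
        for (int pattern = 0; pattern < 512; pattern++) {
            boolean alive = ((pattern >> CENTRE_BIT) & 1) == 1;
            // Count the eight neighbour bits, ignoring the centre bit.
            int neighbours = Integer.bitCount(pattern & ~(1 << CENTRE_BIT));
            lookup[pattern] = (neighbours == 3 || (neighbours == 2 && alive)) ? ALIVE : DEAD;
        }
        return lookup;
    }
}
```

Whether the table lookup actually beats the conditional on a given GPU is still the open question; this only shows that the table is cheap to produce.
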
So the task here is to process kernel code models, perform the appropriate analysis
(tracking each primitive array's origin), and apply transformations that map the array code back to buffer gets/sets.

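As a rough illustration of that analysis and rewrite, here is a toy model. It deliberately does not use the real code reflection API: the `Op` records below are hypothetical, string-named stand-ins for code model operations, and `rewrite` is an illustrative name. The pass remembers which buffer each array view came from, drops the view operation, and replaces array loads and stores with buffer gets and sets.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Toy model only; these ops are not the jdk.incubator.code operations.
final class ViewRewriteSketch {
    interface Op {}
    record ViewOf(String array, String buffer) implements Op {}                 // array = buffer.arrayView()
    record ArrayLoad(String result, String array, String index) implements Op {} // result = array[index]
    record ArrayStore(String array, String index, String value) implements Op {} // array[index] = value
    record BufferGet(String result, String buffer, String index) implements Op {} // result = buffer.array(index)
    record BufferSet(String buffer, String index, String value) implements Op {}  // buffer.array(index, value)

    static List<Op> rewrite(List<Op> ops) {
        Map<String, String> origin = new HashMap<>(); // array name -> buffer it is a view of
        List<Op> out = new ArrayList<>();
        for (Op op : ops) {
            if (op instanceof ViewOf v) {
                origin.put(v.array(), v.buffer()); // record the origin and emit nothing
            } else if (op instanceof ArrayLoad l && origin.containsKey(l.array())) {
                out.add(new BufferGet(l.result(), origin.get(l.array()), l.index()));
            } else if (op instanceof ArrayStore s && origin.containsKey(s.array())) {
                out.add(new BufferSet(origin.get(s.array()), s.index(), s.value()));
            } else {
                out.add(op); // anything not tied to a view passes through unchanged
            }
        }
        return out;
    }
}
```
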
----------
The `arrayView` trick actually leads us to other possibilities.

Let's look at the current NBody code.

```java
@Reflect
static public void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL code using `float4`:

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2.0f;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface-mapped segments we can approximate a `float4` vector type. In fact, the
existing `Universe` interface-mapped segment already embeds a `float6` kind of
object holding the `x, y, z, vx, vy, vz` values for each `body` (so position and velocity).

What if we created generalized `float4` interface-mapped views `(x, y, z, w)` as placeholders for true vector types,
and modified `Universe` to hold `pos` and `vel` `float4` arrays?

So in Java code we would pass a `MemorySegment`-backed `F32x4Arr`.

Our Java code could then start to approximate the OpenCL code.

Except for our lack of operator overloading...

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale the float4 by a scalar, then add
    }
    acc = float4.mul(acc, delT); // scalar * vector
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
The code is more compact, but it's still weird, because we can't overload operators.

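For reference, the static helpers this style relies on are small. The following is a minimal sketch of such a placeholder type as a plain Java record; `Float4` and its methods are hypothetical names chosen for illustration, and a real HAT version would presumably be an interface-mapped view rather than a heap object.

```java
// Hypothetical float4 placeholder; a plain record, not an interface-mapped segment.
record Float4(float x, float y, float z, float w) {
    static Float4 of(float x, float y, float z, float w) {
        return new Float4(x, y, z, w);
    }

    // Component-wise addition.
    static Float4 add(Float4 a, Float4 b) {
        return new Float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
    }

    // Component-wise subtraction.
    static Float4 sub(Float4 a, Float4 b) {
        return new Float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
    }

    // Scales every component by a scalar.
    static Float4 mul(Float4 a, float s) {
        return new Float4(a.x * s, a.y * s, a.z * s, a.w * s);
    }
}
```

With these, an expression such as `Float4.add(pos, Float4.mul(vel, delT))` plays the role of OpenCL's `pos + vel * delT`.
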
Well, we can, sort of.

What if we allowed a `floatView()` call on `float4` ;) which yields a `float` value to be used as a proxy
for the `float4` we fetched it from...

So we would leverage the anonymity of `var`:

`var myVelF4 = myVel.floatView()` // psst, the var is really a float

From the code model we can track the relationship from each float view back to the original vector...

Any action we perform on the 'float' view is mapped back to a call on its origin, and performed on the actual origin `float4`.

So
`myVelF4 + myVelF4` -> `float4.add(myVel, myVel)` behind the scenes.

Yeah, it's a bit crazy, and perhaps this is a 'bridge too far'. The code would look something like this:

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    var accF4 = acc.floatView(); // var is really float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;) var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); // ... tied to myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4; // look ma, operator overloading ;)
        // `delta` stands for the float4 behind deltaF4; how to reach its components is still open
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, then vector add
    }
    accF4 = accF4 * delT; // scalar * vector
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```