## Leverage CodeReflection to expose array view of buffers in kernels

Here is the canonical HAT example:

```java
import jdk.incubator.code.Reflect;

@Reflect
public class Square {
    @Reflect
    public static void kernel(KernelContext kc, S32Arr s32Arr) {
        s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
    }

    @Reflect
    public static void compute(ComputeContext cc, S32Arr s32Arr) {
        cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
    }
}
```

The kernel code has always bothered me. One downside of using `MemorySegment`-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
}
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that exposes the buffer as a plain Java `int[]` view?

The kernel becomes:
```java
@Reflect
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
}
```
IMHO this makes the code far more readable.

For the GPU this is fine. Thanks to CodeReflection we can prove that the array is indeed just a view,
so we remove all references to `int[] arr` and replace the array accesses with get/set accessors
on the original `S32Arr`.

But what about Java performance? Won't it suck, because we are copying the array in each kernel ;)

Well, we can use the same trick we used for the GPU: we take the transformed model (with the array
references removed), generate bytecode from it, and run that.

Historically we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.

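To make the copy concern concrete, here is a minimal, self-contained sketch. It assumes a hypothetical `S32ArrSketch` class (a direct `ByteBuffer` standing in for the `MemorySegment`-backed `S32Arr`) and two hand-written kernels: `kernelNaiveView`, which is what the array view would cost if we really materialized an `int[]`, and `kernelRewritten`, which is the shape the transformed model would take. None of these names are HAT API; they only illustrate that both forms compute the same result while the rewritten one never copies.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical stand-in for S32Arr: ints live off-heap and are reached only through accessors.
final class S32ArrSketch {
    private final ByteBuffer data;
    private final int length;

    S32ArrSketch(int length) {
        this.length = length;
        this.data = ByteBuffer.allocateDirect(length * Integer.BYTES).order(ByteOrder.nativeOrder());
    }

    int length() { return length; }
    int array(int idx) { return data.getInt(idx * Integer.BYTES); }
    void array(int idx, int value) { data.putInt(idx * Integer.BYTES, value); }

    // The kernel as written against an int[] view, executed literally:
    // every invocation copies the whole buffer out before touching one element.
    static void kernelNaiveView(int x, S32ArrSketch buf) {
        int[] arr = new int[buf.length()];
        for (int i = 0; i < arr.length; i++) {
            arr[i] = buf.array(i); // the copy we want to avoid
        }
        arr[x] *= arr[x];
        buf.array(x, arr[x]); // write the modified element back
    }

    // The same kernel after the rewrite: array accesses become buffer get/set calls.
    static void kernelRewritten(int x, S32ArrSketch buf) {
        buf.array(x, buf.array(x) * buf.array(x));
    }

    // Fills a buffer with 0..n-1, runs one kernel flavour over every index, returns the contents.
    static int[] run(boolean rewritten, int n) {
        S32ArrSketch buf = new S32ArrSketch(n);
        for (int i = 0; i < n; i++) {
            buf.array(i, i);
        }
        for (int x = 0; x < n; x++) {
            if (rewritten) {
                kernelRewritten(x, buf);
            } else {
                kernelNaiveView(x, buf);
            }
        }
        int[] out = new int[n];
        for (int i = 0; i < n; i++) {
            out[i] = buf.array(i);
        }
        return out;
    }
}
```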
This also helps with the game of life example:
```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function `val(grid, from, w, x, y)` to read each neighbour:
```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` itself is a bit verbose:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a `byteView()` on the grid it could be written as:

```java
@Reflect
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` and just write:

```java
byte[] bytes = cellGrid.byteView();
// In Java `+` binds tighter than `&`, so each `& 1` term needs its own parentheses.
int count =
        (bytes[(y - 1) * w + x - 1 + from] & 1)
                + (bytes[(y + 0) * w + x - 1 + from] & 1)
                + (bytes[(y + 1) * w + x - 1 + from] & 1)
                + (bytes[(y - 1) * w + x + 0 + from] & 1)
                + (bytes[(y + 1) * w + x + 0 + from] & 1)
                + (bytes[(y + 0) * w + x + 1 + from] & 1)
                + (bytes[(y - 1) * w + x + 1 + from] & 1)
                + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

BTW my inner Verilog programmer has always wondered whether shifting and OR-ing each bit into a
9-bit value, which we then use to index the live/dead state in a prepopulated 512-entry array (in GPU constant memory),
would let us sidestep the wave-divergent conditional :)

```java
byte[] bytes = cellGrid.byteView();
// idx, from, to, w, x and y come from lifePerIdx above.
byte[] lookup = new byte[512]; // placeholder for the prepopulated live/dead table
// `<<` binds tighter than `&`, so mask first (in parentheses) and then shift.
int lookupIdx =
        (bytes[(y - 1) * w + x - 1 + from] & 1) << 0
                | (bytes[(y + 0) * w + x - 1 + from] & 1) << 1
                | (bytes[(y + 1) * w + x - 1 + from] & 1) << 2
                | (bytes[(y - 1) * w + x + 0 + from] & 1) << 3
                | (bytes[(y + 0) * w + x + 0 + from] & 1) << 4 // current cell added
                | (bytes[(y + 1) * w + x + 0 + from] & 1) << 5
                | (bytes[(y + 0) * w + x + 1 + from] & 1) << 6
                | (bytes[(y - 1) * w + x + 1 + from] & 1) << 7
                | (bytes[(y + 1) * w + x + 1 + from] & 1) << 8;
// conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```

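How the 512-entry table might be populated is not shown above. As a rough sketch, assuming the bit layout used in the snippet (bit 4 is the current cell, the other eight bits are its neighbours) and assuming `ALIVE`/`DEAD` byte encodings of `0xff`/`0x00`, the B3/S23 rule could be precomputed like this. `LifeLookupSketch` and `buildLookup` are hypothetical names, not part of HAT:

```java
// Hypothetical helper: precomputes the next state for every 3x3 neighbourhood pattern.
final class LifeLookupSketch {
    static final byte ALIVE = (byte) 0xff; // assumed encoding, bit 0 set so `& 1` yields 1
    static final byte DEAD = 0x00;         // assumed encoding
    static final int CENTER_BIT = 4;       // position of the current cell in lookupIdx

    static byte[] buildLookup() {
        byte[] lookup = new byte[512]; // 2^9 possible neighbourhoods
        for (int bits = 0; bits < lookup.length; bits++) {
            boolean centerAlive = ((bits >> CENTER_BIT) & 1) == 1;
            // Count the eight neighbours, ignoring the centre bit.
            int neighbours = Integer.bitCount(bits & ~(1 << CENTER_BIT));
            boolean alive = neighbours == 3 || (neighbours == 2 && centerAlive); // B3/S23
            lookup[bits] = alive ? ALIVE : DEAD;
        }
        return lookup;
    }
}
```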
So the task here is to process kernel code models, perform the analysis needed to track each
primitive array back to the buffer it came from, and then transform the array accesses back into buffer get/set calls.

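The shape of that analysis can be illustrated without the real code model. The sketch below uses a made-up, flat list of toy operations (it is not the `jdk.incubator.code` API, and every type name in it is hypothetical): it records which buffer each array view came from, drops the `arrayView()` call, and rewrites loads and stores on the view into buffer gets and sets. A real implementation would also have to prove that the array never escapes or gets reassigned, which this toy ignores.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Toy model only: operands are plain variable names, operations are a flat list.
final class ViewRewriteSketch {
    interface Op {}
    record ViewCall(String result, String buffer) implements Op {}                 // int[] result = buffer.arrayView()
    record ArrayLoad(String result, String array, String index) implements Op {}   // result = array[index]
    record ArrayStore(String array, String index, String value) implements Op {}   // array[index] = value
    record BufferGet(String result, String buffer, String index) implements Op {}  // result = buffer.array(index)
    record BufferSet(String buffer, String index, String value) implements Op {}   // buffer.array(index, value)

    static List<Op> rewrite(List<Op> ops) {
        Map<String, String> origin = new HashMap<>(); // array variable -> buffer it is a view of
        List<Op> out = new ArrayList<>();
        for (Op op : ops) {
            if (op instanceof ViewCall v) {
                origin.put(v.result(), v.buffer()); // remember the origin, emit nothing
            } else if (op instanceof ArrayLoad l && origin.containsKey(l.array())) {
                out.add(new BufferGet(l.result(), origin.get(l.array()), l.index()));
            } else if (op instanceof ArrayStore s && origin.containsKey(s.array())) {
                out.add(new BufferSet(origin.get(s.array()), s.index(), s.value()));
            } else {
                out.add(op); // anything unrelated to a view passes through unchanged
            }
        }
        return out;
    }
}
```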
----------
The arrayView trick actually leads us to other possibilities.

Let's look at the current NBody code.

```java
@Reflect
public static void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL code using `float4`:

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface mapped segments we can approximate a `float4` vector type. In fact the
existing `Universe` interface mapped segment already embeds a kind of `float6`
object holding the `x,y,z,vx,vy,vz` values for each `body` (so position and velocity).

What if we created generalized `float4` interface mapped views `(x,y,z,w)` as placeholders for true vector types,
and modified `Universe` to hold `pos` and `vel` `float4` arrays?

In Java code we would then pass a `MemorySegment`-backed `F32x4Arr`.

Our Java code could then start to approximate the OpenCL code.

Except for our lack of operator overloading...

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0 / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale delta by the scalar s, then accumulate
    }
    acc = float4.mul(acc, delT); // vector * scalar
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
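The static helpers this kernel leans on (`float4.of`, `float4.add`, `float4.sub`, `float4.mul`) are not spelled out in this note. As a rough sketch only, assuming a plain record is an acceptable stand-in for what would really be an interface mapped view, they might look like the following. The record name `Float4` and its accessor style (`x()` rather than the `.x` field used above) are assumptions made for illustration:

```java
// Hypothetical stand-in for the float4 placeholder type; not an existing HAT class.
record Float4(float x, float y, float z, float w) {
    static Float4 of(float x, float y, float z, float w) {
        return new Float4(x, y, z, w);
    }

    // Component-wise vector + vector.
    static Float4 add(Float4 a, Float4 b) {
        return new Float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
    }

    // Component-wise vector - vector.
    static Float4 sub(Float4 a, Float4 b) {
        return new Float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
    }

    // Vector scaled by a scalar.
    static Float4 mul(Float4 a, float s) {
        return new Float4(a.x * s, a.y * s, a.z * s, a.w * s);
    }
}
```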
The nbody kernel is more compact, but it still looks weird because we can't overload operators.

Well, we sort of can.

What if we allowed a `floatView()` call on `float4` ;) which yields a `float` value that acts as a proxy
for the `float4` we fetched it from...

We would leverage the anonymity of `var`:

`var myVelF4 = myVel.floatView()` // psst, the var is really a float

From the code model we can track the relationship from each float view back to the original vector...

Any operation we perform on the 'float' view is mapped back to a call on its origin and performed on the actual `float4`.

So
`myVelF4 + myVelF4` -> `float4.add(myVel, myVel)` behind the scenes.

Yeah, it's a bit crazy, and perhaps this is a 'bridge too far'. The code would look something like this:

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    var accF4 = acc.floatView(); // var is really a float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;) var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); // ... and this one to myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4; // look ma, operator overloading ;)
        // `delta` below stands for the float4 behind deltaF4; how to read its components is left open
        float invDist = (float) (1.0 / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, accumulated into acc
    }
    accF4 = accF4 * delT; // scalar * vector
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```