#### Leverage CodeReflection to expose array view of buffers in kernels
----

* [Contents](hat-00.md)
* House Keeping
    * [Project Layout](hat-01-01-project-layout.md)
    * [Building Babylon](hat-01-02-building-babylon.md)
    * [Building HAT](hat-01-03-building-hat.md)
* Programming Model
    * [Programming Model](hat-03-programming-model.md)
* Interface Mapping
    * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
    * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Implementation Detail
    * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
    * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)

----

Here is the canonical HAT example

```java
import jdk.incubator.code.CodeReflection;

@CodeReflection
public class Square {
    @CodeReflection
    public static void kernel(KernelContext kc, S32Arr s32Arr) {
        s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
    }

    @CodeReflection
    public static void compute(ComputeContext cc, S32Arr s32Arr) {
        cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
    }
}
```

The code in this kernel has always bothered me. One downside of using MemorySegment-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
@CodeReflection
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
}
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that exposes the buffer as a Java `int[]` view?

IMHO the kernel becomes way more readable.

```java
@CodeReflection
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
}
```

For the GPU this is fine. We can (thanks to CodeReflection) prove that the array is indeed just a view,
so we just remove all references to `int[] arr` and replace the array accesses with get/set accessors
on the original `S32Arr`.

But what about Java performance? Won't it suck because we are copying the array in each kernel ;)

Well, we can use the same trick we used for the GPU: we take the transformed model (with the array
references removed), generate bytecode from it, and run that instead.

Historically, we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.

This also helps with the Game of Life

```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function `val(grid, from, w, x, y)` to extract the neighbours

```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` is a bit verbose

```java
@CodeReflection
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a `byteView()` on the grid it could become

```java
@CodeReflection
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` and just write

```java
byte[] bytes = cellGrid.byteView();
// Parentheses matter: in Java `+` binds tighter than `&`.
int count =
        (bytes[(y - 1) * w + x - 1 + from] & 1)
      + (bytes[(y + 0) * w + x - 1 + from] & 1)
      + (bytes[(y + 1) * w + x - 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 0) * w + x + 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 1 + from] & 1)
      + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

BTW, my inner Verilog programmer has always wondered whether shifting and OR-ing each bit into a
9-bit value, which we then use to index the live/dead state in a prepopulated 512-entry array (GPU constant memory),
would allow us to sidestep the wave-divergent conditional :)

```java
byte[] bytes = cellGrid.byteView();
int idx = 0; // placeholder: the kernel's cell index
int to = 0;  // placeholder: offset of the destination generation
byte[] lookup = new byte[512]; // assumed prepopulated with ALIVE/DEAD for each 9-bit neighbourhood
// Parentheses matter: in Java `<<` binds tighter than `&`.
int lookupIdx =
        (bytes[(y - 1) * w + x - 1 + from] & 1) << 0
      | (bytes[(y + 0) * w + x - 1 + from] & 1) << 1
      | (bytes[(y + 1) * w + x - 1 + from] & 1) << 2
      | (bytes[(y - 1) * w + x + 0 + from] & 1) << 3
      | (bytes[(y + 0) * w + x + 0 + from] & 1) << 4 // current cell added
      | (bytes[(y + 1) * w + x + 0 + from] & 1) << 5
      | (bytes[(y + 0) * w + x + 1 + from] & 1) << 6
      | (bytes[(y - 1) * w + x + 1 + from] & 1) << 7
      | (bytes[(y + 1) * w + x + 1 + from] & 1) << 8;
// conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```

So the task here is to process kernel code models and perform the appropriate analysis
(tracking the origin of each primitive array) and transformations to map the array code back to buffer gets/sets.

----------
The arrayView trick actually leads us to other possibilities.

Let's look at the current NBody code.

```java
@CodeReflection
static public void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL code using `float4`

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = (float) 1.0/sqrt((float)((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT)/2;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface mapped segments we can approximate a `float4` vector type. In fact the
existing `Universe` interface mapped segment actually embeds a `float6` kind of
object holding the `x,y,z,vx,vy,vz` values for each `body` (so position and velocity).

What if we created generalized `float4` interface mapped views `(x,y,z,w)` as placeholders for true vector types?

And modified `Universe` to hold `pos` and `vel` `float4` arrays.

So in Java code we would pass a MemorySegment-backed `F32x4Arr`.

Our Java code could then start to approximate the OpenCL code.

Except for our lack of operator overloading...

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale delta by s, then accumulate
    }
    acc = float4.mul(acc, delT); // vector * scalar
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```

The code is more compact, but it's still weird, because we can't overload operators.

Well, we can, sort of.

What if we allowed a `floatView()` call on `float4` ;) which yields a float value to be used as a proxy
for the `float4` we fetched it from...

So we would leverage the anonymity of `var`

`var myVelF4 = myVel.floatView()` // psst, var is really a float

From the code model we can track the relationship from float views to the original vector...

Any action we perform on the 'float' view will be mapped back to a call on its origin, and performed on the actual origin `float4`.

So
`myVelF4 + myVelF4` -> `float4.add(myVel,myVel)` behind the scenes.

Yeah, it's a bit crazy. Perhaps this is a 'bridge too far'. The code would look something like this.

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    var accF4 = acc.floatView(); // var is really float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;) var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); //... myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4; // look ma operator overloading ;)
        // `delta` stands for the float4 behind deltaF4; how a view gets back to its components is left open
        float invDist = (float) (1.0f / Math.sqrt((delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, then accumulate, via the float view
    }
    accF4 = accF4 * delT; // scalar * vector
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
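To make the float-view idea a little more concrete, here is a minimal sketch of what the position update `myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2` might lower to once each operator on a view has been mapped back to a call on its origin. Everything here is an assumption for illustration: `Float4` is a plain Java record standing in for the `float4` type discussed above (not an interface mapped segment), and `of`, `add`, `sub`, `mul` and `advance` are hypothetical names, not existing HAT API.

```java
// Hypothetical stand-in for the float4 type discussed above.
// A plain record, not an interface mapped segment.
record Float4(float x, float y, float z, float w) {
    static Float4 of(float x, float y, float z, float w) {
        return new Float4(x, y, z, w);
    }

    // Element-wise a + b.
    static Float4 add(Float4 a, Float4 b) {
        return new Float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
    }

    // Element-wise a - b.
    static Float4 sub(Float4 a, Float4 b) {
        return new Float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
    }

    // Vector scaled by a scalar.
    static Float4 mul(Float4 a, float s) {
        return new Float4(a.x * s, a.y * s, a.z * s, a.w * s);
    }
}

final class Float4Demo {
    // One possible lowering of
    //   myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2
    // after every operator on a float view is replaced by a Float4 call.
    static Float4 advance(Float4 pos, Float4 vel, Float4 acc, float delT) {
        return Float4.add(
                Float4.add(pos, Float4.mul(vel, delT)),
                Float4.mul(Float4.mul(acc, delT), 0.5f));
    }

    public static void main(String[] args) {
        Float4 pos = Float4.of(1f, 2f, 3f, 0f);
        Float4 vel = Float4.of(2f, 0f, -2f, 0f);
        Float4 acc = Float4.of(4f, 4f, 4f, 0f);
        System.out.println(advance(pos, vel, acc, 0.5f));
    }
}
```

The interesting part is not the arithmetic but the shape of `advance`: it is the nested-call form that a code model transformation would have to produce from the operator form written against the views.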