## Leverage CodeReflection to expose array view of buffers in kernels

----

* [Contents](hat-00.md)
* House Keeping
    * [Project Layout](hat-01-01-project-layout.md)
    * [Building Babylon](hat-01-02-building-babylon.md)
    * [Building HAT](hat-01-03-building-hat.md)
    * [Enabling the CUDA Backend](hat-01-05-building-hat-for-cuda.md)
* Programming Model
    * [Programming Model](hat-03-programming-model.md)
* Interface Mapping
    * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
    * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Implementation Detail
    * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
    * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)

----

Here is the canonical HAT example

```java
import jdk.incubator.code.CodeReflection;

@CodeReflection
public class Square {
    @CodeReflection
    public static void kernel(KernelContext kc, S32Arr s32Arr) {
        s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
    }

    @CodeReflection
    public static void compute(ComputeContext cc, S32Arr s32Arr) {
        cc.dispatchKernel(s32Arr.length(), kc -> kernel(kc, s32Arr));
    }
}
```

The kernel code has always bothered me. One downside of using MemorySegment-backed buffers
in HAT is that code which would be simple in array form ends up looking verbose.

```java
@CodeReflection
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    s32Arr.array(kc.x, s32Arr.array(kc.x) * s32Arr.array(kc.x));
}
```

But what if we added a method (`int[] arrayView()`) to `S32Arr` that exposes the buffer as a Java `int[]` view?

The kernel then becomes:
```java
@CodeReflection
public static void kernel(KernelContext kc, S32Arr s32Arr) {
    int[] arr = s32Arr.arrayView();
    arr[kc.x] *= arr[kc.x];
}
```
IMHO this makes the code much more readable.

For the GPU this is fine. We can (thanks to CodeReflection) prove that the array is indeed just a view,
so we simply remove all references to `int[] arr` and replace the array accesses with get/set accessors
on the original `S32Arr`.

But what about Java performance? Won't it suck because we are copying the array in each kernel? ;)

Well, we can use the same trick we used for the GPU: we take the transformed model (with array
references removed), generate bytecode from it, and run that.

Historically, we have just run the original bytecode in the Java MT/Seq backends, but we don't have to.

This also helps with the Game of Life example
```java
public static void lifePerIdx(int idx, @RO Control control, @RW CellGrid cellGrid) {
    int w = cellGrid.width();
    int h = cellGrid.height();
    int from = control.from();
    int to = control.to();
    int x = idx % w;
    int y = idx / w;
    byte cell = cellGrid.cell(idx + from);
    if (x > 0 && x < (w - 1) && y > 0 && y < (h - 1)) { // passports please
        int count =
                val(cellGrid, from, w, x - 1, y - 1)
                        + val(cellGrid, from, w, x - 1, y + 0)
                        + val(cellGrid, from, w, x - 1, y + 1)
                        + val(cellGrid, from, w, x + 0, y - 1)
                        + val(cellGrid, from, w, x + 0, y + 1)
                        + val(cellGrid, from, w, x + 1, y + 0)
                        + val(cellGrid, from, w, x + 1, y - 1)
                        + val(cellGrid, from, w, x + 1, y + 1);
        cell = ((count == 3) || ((count == 2) && (cell == ALIVE))) ? ALIVE : DEAD; // B3/S23.
    }
    cellGrid.cell(idx + to, cell);
}
```

This code uses a helper function `val(grid, offset, w, dx, dy)` to extract the neighbours.
```java
int count =
        val(cellGrid, from, w, x - 1, y - 1)
                + val(cellGrid, from, w, x - 1, y + 0)
                + val(cellGrid, from, w, x - 1, y + 1)
                + val(cellGrid, from, w, x + 0, y - 1)
                + val(cellGrid, from, w, x + 0, y + 1)
                + val(cellGrid, from, w, x + 1, y + 0)
                + val(cellGrid, from, w, x + 1, y - 1)
                + val(cellGrid, from, w, x + 1, y + 1);
```

`val()` itself is a bit verbose

```java
@CodeReflection
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    return grid.cell(((long) y * w) + x + from) & 1;
}
```

With a byte array view it could be written as

```java
@CodeReflection
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
    byte[] bytes = grid.byteView(); // bit view would be nice ;)
    return bytes[y * w + x + from] & 1;
}
```

We could now dispense with `val()` and just write

```java
byte[] bytes = cellGrid.byteView();
// Parentheses matter: in Java '+' binds tighter than '&'.
int count =
        (bytes[(y - 1) * w + x - 1 + from] & 1)
      + (bytes[(y + 0) * w + x - 1 + from] & 1)
      + (bytes[(y + 1) * w + x - 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 1) * w + x + 0 + from] & 1)
      + (bytes[(y + 0) * w + x + 1 + from] & 1)
      + (bytes[(y - 1) * w + x + 1 + from] & 1)
      + (bytes[(y + 1) * w + x + 1 + from] & 1);
```

BTW, my inner Verilog programmer has always wondered whether shifting and ORing each bit into a
9-bit value, which we then use to index the live/dead state in a prepopulated 512-entry array (GPU constant memory),
would allow us to sidestep the wave-divergent conditional :)

```java
byte[] bytes = cellGrid.byteView();
int idx = 0;                   // placeholder: the cell index for this kernel instance
int to = 0;                    // placeholder: the offset of the destination grid
byte[] lookup = new byte[512]; // placeholder: prepopulated with the next state for each 9-bit neighbourhood
// Parentheses matter: in Java '<<' binds tighter than '&'.
int lookupIdx =
        ((bytes[(y - 1) * w + x - 1 + from] & 1) << 0)
      | ((bytes[(y + 0) * w + x - 1 + from] & 1) << 1)
      | ((bytes[(y + 1) * w + x - 1 + from] & 1) << 2)
      | ((bytes[(y - 1) * w + x + 0 + from] & 1) << 3)
      | ((bytes[(y + 0) * w + x + 0 + from] & 1) << 4) // current cell added
      | ((bytes[(y + 1) * w + x + 0 + from] & 1) << 5)
      | ((bytes[(y + 0) * w + x + 1 + from] & 1) << 6)
      | ((bytes[(y - 1) * w + x + 1 + from] & 1) << 7)
      | ((bytes[(y + 1) * w + x + 1 + from] & 1) << 8);
// conditional removed!

bytes[idx + to] = lookup[lookupIdx];
```

So the task here is to process kernel code models and perform the appropriate analysis
(tracking each primitive array's origin) and transformations to map the array code back to buffer gets/sets.

----------
The arrayView trick actually leads us to other possibilities.

Let's look at the current NBody code.

```java
@CodeReflection
static public void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
    float accx = 0.0f;
    float accy = 0.0f;
    float accz = 0.0f;
    Universe.Body me = universe.body(kc.x);

    for (int i = 0; i < kc.maxX; i++) {
        Universe.Body otherBody = universe.body(i);
        float dx = otherBody.x() - me.x();
        float dy = otherBody.y() - me.y();
        float dz = otherBody.z() - me.z();
        float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
        float s = mass * invDist * invDist * invDist;
        accx = accx + (s * dx);
        accy = accy + (s * dy);
        accz = accz + (s * dz);
    }
    accx = accx * delT;
    accy = accy * delT;
    accz = accz * delT;
    me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
    me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
    me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
    me.vx(me.vx() + accx);
    me.vy(me.vy() + accy);
    me.vz(me.vz() + accz);
}
```

Contrast the above code with the OpenCL code using `float4`

```c
__kernel void nbody(__global float4 *xyzPos, __global float4 *xyzVel, float mass, float delT, float espSqr) {
    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
    float4 myPos = xyzPos[get_global_id(0)];
    float4 myVel = xyzVel[get_global_id(0)];
    for (int i = 0; i < get_global_size(0); i++) {
        float4 delta = xyzPos[i] - myPos;
        float invDist = 1.0f / sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr);
        float s = mass * invDist * invDist * invDist;
        acc = acc + (s * delta);
    }
    acc = acc * delT;
    myPos = myPos + (myVel * delT) + (acc * delT) / 2;
    myVel = myVel + acc;
    xyzPos[get_global_id(0)] = myPos;
    xyzVel[get_global_id(0)] = myVel;
}
```

Thanks to interface-mapped segments we can approximate a `float4` vector type. In fact, the
existing `Universe` interface-mapped segment effectively embeds a `float6`-like
object holding the `x,y,z,vx,vy,vz` values for each `body` (so position and velocity).

What if we created generalized `float4` interface-mapped views `(x,y,z,w)` as placeholders for true vector types?

And modified `Universe` to hold `pos` and `vel` `float4` arrays.

So in Java code we would pass a MemorySegment version of an `F32x4Arr`.

Our Java code could then start to approximate the OpenCL code.

Except for our lack of operator overloading...

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    for (int i = 0; i < kc.maxX; i++) {
        float4 delta = float4.sub(xyzPosArr[i], myPos); // yucky but ok
        float invDist = (float) (1.0f / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        acc = float4.add(acc, float4.mul(delta, s)); // scale the vector by a scalar, then accumulate
    }
    acc = float4.mul(acc, delT); // scalar * vector
    myPos = float4.add(float4.add(myPos, float4.mul(myVel, delT)), float4.mul(acc, delT / 2));
    myVel = float4.add(myVel, acc);
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
The code is more compact, but it's still weird because we can't overload operators.

Well, we can, sort of.

What if we allowed a `floatView()` call on `float4` ;) which yields a float value to be used as a proxy
for the `float4` we fetched it from...

So we would leverage the anonymity of `var`

`var myVelF4 = myVel.floatView()` // psst, var is really a float

From the code model we can track the relationship from float views back to the original vector...

Any action we perform on the 'float' view will be mapped back to calls on the origin, and performed on the actual origin `float4`.

So
`myVelF4 + myVelF4` -> `float4.add(myVel,myVel)` behind the scenes.

Yeah, it's a bit crazy. The code would look something like this. Perhaps this is a 'bridge too far'.

```java
void nbody(KernelContext kc, F32x4Arr xyzPos, F32x4Arr xyzVel, float mass, float delT, float espSqr) {
    float4 acc = float4.of(0.0f, 0.0f, 0.0f, 0.0f);
    var accF4 = acc.floatView(); // var is really float
    float4[] xyzPosArr = xyzPos.float4View();
    float4[] xyzVelArr = xyzVel.float4View();

    float4 myPos = xyzPosArr[kc.x];
    float4 myVel = xyzVelArr[kc.x];
    var myPosF4 = myPos.floatView(); // ;) var is actually a float tied to myPos
    var myVelF4 = myVel.floatView(); // ... and this one to myVel
    for (int i = 0; i < kc.maxX; i++) {
        var bodyF4 = xyzPosArr[i].floatView(); // bodyF4 is a float
        var deltaF4 = bodyF4 - myPosF4; // look ma operator overloading ;)
        // delta stands for the float4 behind deltaF4; how we reach its lanes from the view is still open
        float invDist = (float) (1.0f / Math.sqrt((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
        float s = mass * invDist * invDist * invDist;
        accF4 += s * deltaF4; // scalar * vector, then accumulate
    }
    accF4 = accF4 * delT; // scalar * vector
    myPosF4 = myPosF4 + (myVelF4 * delT) + (accF4 * delT) / 2;
    myVelF4 = myVelF4 + accF4;
    xyzPosArr[kc.x] = myPos;
    xyzVelArr[kc.x] = myVel;
}
```
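
The `float4` helpers used in the two sketches above (`of`, `add`, `sub`, `mul`) do not exist in HAT today. Purely as an illustration, here is a plain-Java stand-in that runs on the heap. The names `Float4Sketch`, `Float4` and `step` are assumptions made up for this example, and a real HAT version would presumably be an interface-mapped view over a MemorySegment rather than a record.

```java
class Float4Sketch {
    /** Hypothetical stand-in for the float4 type discussed above (not a HAT API). */
    record Float4(float x, float y, float z, float w) {
        static Float4 of(float x, float y, float z, float w) {
            return new Float4(x, y, z, w);
        }

        /** Lane-wise a + b. */
        static Float4 add(Float4 a, Float4 b) {
            return new Float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
        }

        /** Lane-wise a - b. */
        static Float4 sub(Float4 a, Float4 b) {
            return new Float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
        }

        /** Every lane of a multiplied by the scalar s. */
        static Float4 mul(Float4 a, float s) {
            return new Float4(a.x * s, a.y * s, a.z * s, a.w * s);
        }
    }

    /** One n-body update for body 'me', with the same structure as the kernels above. */
    static void step(Float4[] pos, Float4[] vel, int me, float mass, float delT, float espSqr) {
        Float4 acc = Float4.of(0f, 0f, 0f, 0f);
        Float4 myPos = pos[me];
        Float4 myVel = vel[me];
        for (int i = 0; i < pos.length; i++) {
            Float4 delta = Float4.sub(pos[i], myPos);
            float invDist = (float) (1.0 / Math.sqrt(
                    (delta.x() * delta.x()) + (delta.y() * delta.y()) + (delta.z() * delta.z()) + espSqr));
            float s = mass * invDist * invDist * invDist;
            acc = Float4.add(acc, Float4.mul(delta, s));
        }
        acc = Float4.mul(acc, delT);
        pos[me] = Float4.add(Float4.add(myPos, Float4.mul(myVel, delT)), Float4.mul(acc, delT / 2));
        vel[me] = Float4.add(myVel, acc);
    }

    public static void main(String[] args) {
        // Two bodies at rest, one unit apart on the x axis.
        Float4[] pos = {Float4.of(0f, 0f, 0f, 0f), Float4.of(1f, 0f, 0f, 0f)};
        Float4[] vel = {Float4.of(0f, 0f, 0f, 0f), Float4.of(0f, 0f, 0f, 0f)};
        step(pos, vel, 0, 1f, 1f, 3f);
        System.out.println("pos[0] = " + pos[0]);
        System.out.println("vel[0] = " + vel[0]);
    }
}
```

This is the "yucky but ok" static-method style. The `floatView()` idea would keep these same operations underneath and only change how the kernel source spells them.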