# What happens when we call accelerator.compute(lambda)

----

* [Contents](hat-00.md)
* House Keeping
    * [Project Layout](hat-01-01-project-layout.md)
    * [Building Babylon](hat-01-02-building-babylon.md)
    * [Building HAT](hat-01-03-building-hat.md)
    * [Enabling the CUDA Backend](hat-01-05-building-hat-for-cuda.md)
* Programming Model
    * [Programming Model](hat-03-programming-model.md)
* Interface Mapping
    * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
    * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Implementation Detail
    * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
    * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)

----

# Back to our Squares example

So what is going on here?

```java
accelerator.compute(
   cc -> SquareCompute.square(cc, s32Array)
);
```

Recall that our compute class contains two kinds of code: kernels (and kernel-reachable methods) and compute
entrypoints (and compute-reachable methods).

```java
public class SquareCompute {
    @CodeReflection public static int square(int v) {
        return v * v;
    }

    @CodeReflection public static void squareKernel(KernelContext kc, S32Array s32Array) {
        int value = s32Array.array(kc.x);     // arr[kc.x]
        s32Array.array(kc.x, square(value));  // arr[kc.x] = value * value
    }

    @CodeReflection public static void square(ComputeContext cc, S32Array s32Array) {
        cc.dispatchKernel(s32Array.length(),
                kc -> squareKernel(kc, s32Array)
        );
    }
}
```

Again, NOTE that we cannot just call the compute entrypoint or the kernel directly.

```java
SquareCompute.square(????, s32Array); // We can't do this!!!!
```

We purposely make it inconvenient to call the compute entrypoint directly by mistake (`ComputeContext` and
`KernelContext` construction is embedded in the framework). Doing so would be akin to calling `Thread.run()` directly,
rather than calling `Thread.start()` on a class that extends `Thread` and provides an implementation of `Thread.run()`.

Instead, we use this pattern:

```java
accelerator.compute(
   cc -> SquareCompute.square(cc, s32Array)
);
```

We pass a lambda to `accelerator.compute()`, which uses it to determine which compute method to invoke.

```
 User    Accelerator    Compute    Babylon           Backend
                        Context               Java    C++    Vendor
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+
|    |  |           |  |       |  |       |  |    |  |   |  |      |
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+
   +--------> accelerator.compute(lambda)
```

Incidentally, this lambda is never executed by the JVM ;) Instead, the accelerator uses Babylon's Code Reflection
capabilities to extract the model of this lambda and, from that model, determine the compute entrypoint and its
captured args.

```
 User    Accelerator    Compute    Babylon           Backend
                        Context               Java    C++    Vendor
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+
|    |  |           |  |       |  |       |  |    |  |   |  |      |
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+
   +--------> accelerator.compute( cc -> SquareCompute.square(cc, s32Array) )
              ----------------------->
                 getModelOf(lambda)
              <-----------------------
```

This model describes the call that we want the accelerator to execute or interpret (`SquareCompute.square()`) and the
args that were captured from the call site (the `s32Array` buffer).

The accelerator then uses Babylon again to get the code model of `SquareCompute.square()` and builds a
ComputeReachableGraph with this method at the root.
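
Babylon's code reflection is what lets the accelerator look inside the lambda without running it. Plain Java only offers a much weaker analogue, which helps show why a code model is needed: when a lambda's target type is `Serializable`, the runtime generates a private `writeReplace()` method on the lambda class that returns a `java.lang.invoke.SerializedLambda`, exposing the captured args and the name of the synthetic method that holds the lambda body. The sketch below uses that standard JDK mechanism (reflective access to `writeReplace()` is an implementation detail some libraries rely on), not HAT's or Babylon's API; `CapturedArgsDemo`, `SerializableConsumer`, and the stand-in `square()` are hypothetical names.

```java
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.function.Consumer;

public class CapturedArgsDemo {
    // Hypothetical serializable consumer: a Serializable target type makes the
    // runtime generate a private writeReplace() on the lambda class.
    public interface SerializableConsumer<T> extends Consumer<T>, Serializable {}

    // Stand-in for the compute entrypoint; this demo never calls it.
    static void square(Object cc, int[] s32Array) {
        throw new IllegalStateException("the lambda body is never executed");
    }

    // Reads the SerializedLambda description without invoking the lambda itself.
    static SerializedLambda describe(Serializable lambda) {
        try {
            Method writeReplace = lambda.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            return (SerializedLambda) writeReplace.invoke(lambda);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("not a serializable lambda", e);
        }
    }

    // Same shape as accelerator.compute(cc -> SquareCompute.square(cc, s32Array)).
    static SerializedLambda describeSquareCall(int[] s32Array) {
        SerializableConsumer<Object> lambda = cc -> square(cc, s32Array);
        return describe(lambda);
    }

    public static void main(String[] args) {
        int[] s32Array = {1, 2, 3};
        SerializedLambda info = describeSquareCall(s32Array);
        // The body lives in a synthetic lambda$... method, not in square() itself.
        System.out.println(info.getImplClass() + "." + info.getImplMethodName());
        System.out.println("captured args: " + info.getCapturedArgCount());
        System.out.println("captured s32Array: " + (info.getCapturedArg(0) == s32Array));
    }
}
```

Running `main` reports a synthetic `lambda$...` method, one captured arg, and `true` for the identity check on the array. Note what is missing: nothing here reveals that the body calls `square()`. Recovering that requires looking inside the lambda body, which is exactly what the Babylon code model provides.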

To build the ComputeReachableGraph, the accelerator walks the entrypoint's code model and collects all methods (and
their code models) reachable from it.

In our trivial case, the ComputeReachableGraph has a single root node representing `SquareCompute.square()`.

```
 User    Accelerator    Compute    Babylon           Backend
                        Context               Java    C++    Vendor
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+
|    |  |           |  |       |  |       |  |    |  |   |  |      |
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+
   +--------> accelerator.compute( cc -> SquareCompute.square(cc, s32Array) )
              ----------------------->
                 getModelOf(lambda)
              <-----------------------
              ----------------------->
                 getModelOf(SquareCompute.square())
              <-----------------------
              forEachReachable method in SquareCompute.square() {
              ----------------------->
                 getModelOf(method)
              <-----------------------
                 add to ComputeReachableGraph
              }
```

The accelerator then walks the ComputeReachableGraph to determine which kernels are referenced.

For each referenced kernel, we extract the kernel's entrypoint (again as a Babylon code model) and create a
KernelReachableGraph, again by starting at the kernel entrypoint and closing over all reachable methods (and their
code models).

We combine the compute and kernel reachable graphs and place them in a `ComputeContext`.

This `ComputeContext` is the first arg that is 'seemingly' passed to the compute entrypoint
(`SquareCompute.square(cc, s32Array)`). Remember, the compute entrypoint is just a model of the code we expect to
execute; it may never be executed by the JVM.
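
Both closure steps above are ordinary graph reachability. As a minimal sketch, assume each method's code model has already been reduced to the list of methods it calls plus the kernels it dispatches; in HAT that information comes from walking Babylon code models, whereas here it is hard-coded in a hypothetical `MODELS` map keyed by signature strings. `ReachableGraphSketch`, `MethodInfo`, `closeOver()`, and `kernelGraphs()` are illustrative names, not HAT's API.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class ReachableGraphSketch {
    // Hypothetical summary of one code model: direct calls and dispatched kernels.
    public record MethodInfo(List<String> calls, List<String> dispatchedKernels) {}

    // Hard-coded stand-ins for the code models of the SquareCompute example.
    static final Map<String, MethodInfo> MODELS = Map.of(
            "square(ComputeContext,S32Array)",
            new MethodInfo(List.of(), List.of("squareKernel(KernelContext,S32Array)")),
            "squareKernel(KernelContext,S32Array)",
            new MethodInfo(List.of("square(int)"), List.of()),
            "square(int)",
            new MethodInfo(List.of(), List.of()));

    // Worklist closure: every method reachable from root through direct calls.
    // The reachable.add() check also stops recursion and call cycles.
    static Set<String> closeOver(String root) {
        Set<String> reachable = new LinkedHashSet<>();
        Deque<String> work = new ArrayDeque<>();
        work.add(root);
        while (!work.isEmpty()) {
            String method = work.remove();
            if (reachable.add(method)) {
                work.addAll(MODELS.get(method).calls());
            }
        }
        return reachable;
    }

    // One KernelReachableGraph per kernel referenced from the compute-reachable methods.
    static Map<String, Set<String>> kernelGraphs(Set<String> computeReachable) {
        Map<String, Set<String>> graphs = new LinkedHashMap<>();
        for (String method : computeReachable) {
            for (String kernel : MODELS.get(method).dispatchedKernels()) {
                graphs.computeIfAbsent(kernel, ReachableGraphSketch::closeOver);
            }
        }
        return graphs;
    }

    public static void main(String[] args) {
        Set<String> compute = closeOver("square(ComputeContext,S32Array)");
        System.out.println("ComputeReachableGraph: " + compute);
        System.out.println("KernelReachableGraphs: " + kernelGraphs(compute));
    }
}
```

For the `SquareCompute` example this yields a compute graph containing only the entrypoint, and a single kernel graph containing `squareKernel` and `square(int)`, matching the description above.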

```
 User    Accelerator    Compute    Babylon           Backend
                        Context               Java    C++    Vendor
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+
|    |  |           |  |       |  |       |  |    |  |   |  |      |
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+

              forEachReachable kernel in ComputeReachableGraph {
              ----------------------->
                 getModelOf(kernel)
              <-----------------------
                 add to KernelReachableGraph
              }
              ComputeContext = {ComputeReachableGraph + KernelReachableGraph}
```

The accelerator passes the ComputeContext to the backend (`computeContextHandoff()`), which will typically take this
opportunity to inspect or mutate the compute and kernel models, and possibly to build backend-specific representations
of the kernels and compile them.

The ComputeContext and the captured args are then passed to the backend for execution.

```
 User    Accelerator    Compute    Babylon           Backend
                        Context               Java    C++    Vendor
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+
|    |  |           |  |       |  |       |  |    |  |   |  |      |
+----+  +-----------+  +-------+  +-------+  +----+  +---+  +------+

              -------------------------------->
                 computeContextHandoff(CLWrapComputeContext)
                                               ------->
                                                       ------->
                                                          compileKernels()
                                                       <-------
                                               <-------
                                               mutateComputeModels
              <--------------------------------
              -------------------------------->
                 dispatchCompute(CLWrapComputeContext, args)
                                               ------->
                                                  dispatchCompute(...)
                                                       ------->
                                                          {
                                                             dispatchKernel()
                                                             ...
                                                          }
                                                       <-------
                                               <-------
              <--------------------------------
```

----
### Notes

In reality,
the Accelerator receives a `QuotableComputeContextConsumer`:

```java
public interface QuotableComputeContextConsumer
        extends Quotable,
        Consumer<ComputeContext> {
}
```

Here is how we extract the 'target' from such a lambda:

```java
public void compute(QuotableComputeContextConsumer qccc) {
    Quoted quoted = Op.ofQuotable(qccc).orElseThrow();
    LambdaOpWrapper lambda = OpTools.wrap((CoreOps.LambdaOp) quoted.op());

    Method method = lambda.getQuotableComputeContextTargetMethod();

    // Get from the cache, or create, a compute context which closes over the compute entrypoint
    // and reachable kernels.
    // The models of all compute and kernel methods are passed to the backend during creation.
    // The backend may well mutate the models.
    // It will also use this opportunity to generate ISA-specific code for the kernels.
    ComputeContext computeContext = this.cache.computeIfAbsent(method, (_) ->
            new ComputeContext(this /* Accelerator */, method)
    );

    // Here we get the captured args from the Quotable and 'jam' the compute context into slot[0]
    Object[] args = lambda.getQuotableComputeContextArgs(quoted, method, computeContext);
    this.compute(computeContext, args);
}
```
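
The `computeIfAbsent()` call above means the expensive work (closing over reachable methods and handing the models to the backend) happens once per compute entrypoint, while every call still gets a fresh argument array with the context in slot[0]. Below is a minimal sketch of just that caching and arg-assembly pattern using plain JDK collections. `ComputeCacheSketch`, `Context`, `contextFor()`, and `argsWithContext()` are hypothetical names, `Context` is a trivial stand-in for HAT's `ComputeContext`, and keying by a `String` in a `ConcurrentHashMap` is our assumption; the source does not say what type `this.cache` is.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class ComputeCacheSketch {
    // Hypothetical stand-in for ComputeContext; in HAT, building one is expensive.
    public record Context(String entrypoint) {}

    static final AtomicInteger BUILDS = new AtomicInteger();
    static final Map<String, Context> CACHE = new ConcurrentHashMap<>();

    // Build the context at most once per entrypoint, then reuse it.
    static Context contextFor(String entrypoint) {
        return CACHE.computeIfAbsent(entrypoint, key -> {
            BUILDS.incrementAndGet();
            return new Context(key);
        });
    }

    // The captured args do not include the context, so it is placed in slot[0].
    static Object[] argsWithContext(Context context, Object... captured) {
        Object[] args = new Object[captured.length + 1];
        args[0] = context;
        System.arraycopy(captured, 0, args, 1, captured.length);
        return args;
    }

    public static void main(String[] unused) {
        int[] s32Array = {1, 2, 3};
        for (int call = 0; call < 3; call++) {
            Context cc = contextFor("SquareCompute.square");
            Object[] args = argsWithContext(cc, (Object) s32Array);
            System.out.println("call " + call + ": slot[0] = " + args[0]
                    + ", " + (args.length - 1) + " captured arg(s)");
        }
        System.out.println("contexts built: " + BUILDS.get());
    }
}
```

Running `main` dispatches three calls but prints `contexts built: 1`, which is the point of caching per entrypoint.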