# What happens when we call accelerator.compute(lambda)

----

* [Contents](hat-00.md)
* House Keeping
    * [Project Layout](hat-01-01-project-layout.md)
    * [Building Babylon](hat-01-02-building-babylon.md)
    * [Building HAT](hat-01-03-building-hat.md)
* Programming Model
    * [Programming Model](hat-03-programming-model.md)
* Interface Mapping
    * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
    * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Implementation Detail
    * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
    * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)

----

# Back to our Squares example.

So what is going on here?

```java
accelerator.compute(
    cc -> SquareCompute.square(cc, s32Array)
);
```

Recall we have two types of code in our compute class. We have kernels (and kernel reachable methods) and we have
compute entrypoints (and compute reachable methods).

```java
public class SquareCompute{
    @CodeReflection public static int square(int v) {
        return v * v;
    }

    @CodeReflection public static void squareKernel(KernelContext kc, S32Array s32Array) {
        int value = s32Array.array(kc.x);     // arr[kc.x]
        s32Array.array(kc.x, square(value));  // arr[kc.x]=value*value
    }

    @CodeReflection public static void square(ComputeContext cc, S32Array s32Array) {
        cc.dispatchKernel(s32Array.length(),
            kc -> squareKernel(kc, s32Array)
        );
    }
}
```

Again, note that we cannot just call the compute entrypoint or the kernel directly.

```java
SquareCompute.square(????, s32Array);  // We can't do this!!!!
```

We purposely make it inconvenient (ComputeContext and KernelContext construction is embedded in the framework) to
mistakenly call the compute entrypoint directly.
Doing so is akin to calling `Thread.run()` directly, rather than
calling `Thread.start()` on a class extending `Thread` and providing an implementation of `Thread.run()`.

Instead we use this pattern:

```java
accelerator.compute(
    cc -> SquareCompute.square(cc, s32Array)
);
```

We pass a lambda to `accelerator.compute()`, which is used to determine which compute method to invoke.

```
 User   |  Accelerator   |  Compute   |  Babylon   |          Backend          |
                            Context                   Java     C++     Vendor
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+
|    |    |           |    |       |    |       |    |    |   |   |   |      |
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+
    +--------> accelerator.compute(lambda)

```

Incidentally, this lambda is never executed by the JVM ;) Instead, the accelerator uses Babylon's Code Reflection
capabilities to extract the model of this lambda to determine the compute entrypoint and its captured args.

```
 User   |  Accelerator   |  Compute   |  Babylon   |          Backend          |
                            Context                   Java     C++     Vendor
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+
|    |    |           |    |       |    |       |    |    |   |   |   |      |
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+
    +--------> accelerator.compute( cc -> SquareCompute.square(cc, s32Array) )
                ------------------------->
                    getModelOf(lambda)
                <------------------------
```

This model describes the call that we want the accelerator to
execute or interpret (`SquareCompute.square()`) and the args that were captured from the call site (the `s32Array` buffer).

The accelerator uses Babylon again to get the
code model of `SquareCompute.square()` and builds a ComputeReachableGraph with this method at the root.
To do so, the accelerator walks the code model and collects the methods (and code models) of all methods
reachable from the entrypoint.

In our trivial case, the ComputeReachableGraph has a single root node representing `SquareCompute.square()`.
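
The "collect everything reachable from the entrypoint" step described above can be sketched as a plain worklist traversal. This is only an illustration under stated assumptions: `ReachableSketch`, `reachableFrom` and the string-keyed call table are hypothetical names invented here, and the real accelerator discovers callees by walking Babylon code models rather than by consulting a map.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-in for the accelerator's reachability walk.
// Methods are plain strings here; in HAT each node would carry a Babylon code model.
final class ReachableSketch {

    // Returns the entrypoint plus every method transitively called from it,
    // in the order in which the traversal first visits them.
    static Set<String> reachableFrom(String entrypoint, Map<String, List<String>> calls) {
        Set<String> reachable = new LinkedHashSet<>();
        Deque<String> work = new ArrayDeque<>();
        work.push(entrypoint);
        while (!work.isEmpty()) {
            String method = work.pop();
            if (reachable.add(method)) { // expand each method once, so cycles terminate
                for (String callee : calls.getOrDefault(method, List.of())) {
                    work.push(callee);
                }
            }
        }
        return reachable;
    }

    public static void main(String[] args) {
        // Invented call table for the Squares example: the kernel calls square(int).
        Map<String, List<String>> calls = Map.of("squareKernel", List.of("square(int)"));
        // The compute entrypoint calls no other compute method: a single root node.
        System.out.println(reachableFrom("square(ComputeContext)", calls));
        // The kernel entrypoint closes over square(int).
        System.out.println(reachableFrom("squareKernel", calls));
    }
}
```

The same traversal shape applies to both the compute entrypoint and each kernel entrypoint; only the root changes.
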

```
 User   |  Accelerator   |  Compute   |  Babylon   |          Backend          |
                            Context                   Java     C++     Vendor
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+
|    |    |           |    |       |    |       |    |    |   |   |   |      |
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+
    +--------> accelerator.compute( cc -> SquareCompute.square(cc, s32Array) )
                ------------------------->
                    getModelOf(lambda)
                <------------------------
                ------------------------->
                    getModelOf(SquareCompute.square())
                <-------------------------
                forEachReachable method in SquareCompute.square() {
                    ------------------------->
                        getModelOf(method)
                    <------------------------
                    add to ComputeReachableGraph
                }
```

The Accelerator then walks through the ComputeReachableGraph to determine which kernels are referenced.

For each kernel we extract the kernel's entrypoint (again as a Babylon
Code Model) and create a KernelReachableGraph, again by starting
at the kernel entrypoint and closing over all reachable methods (and Code Models).

We combine the compute and kernel reachable graphs and place them in a `ComputeContext`.

This is the first arg that is 'seemingly' passed to the Compute class. Remember the compute
entrypoint is just a model of the code we expect to
execute. It may never be executed by the JVM.

```
 User   |  Accelerator   |  Compute   |  Babylon   |          Backend          |
                            Context                   Java     C++     Vendor
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+
|    |    |           |    |       |    |       |    |    |   |   |   |      |
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+

                forEachReachable kernel in ComputeReachableGraph {
                    ------------------------->
                        getModelOf(kernel)
                    <------------------------
                    add to KernelReachableGraph
                }
                ComputeContext = {ComputeReachableGraph + KernelReachableGraph}

```

The accelerator passes the ComputeContext to the backend (`computeContextHandoff()`), which will typically take
the opportunity to inspect/mutate the compute and kernel models and possibly build backend-specific representations of
kernels and compile them.

The ComputeContext and the captured args are then passed to the backend for execution.

```
 User   |  Accelerator   |  Compute   |  Babylon   |          Backend          |
                            Context                   Java     C++     Vendor
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+
|    |    |           |    |       |    |       |    |    |   |   |   |      |
+----+    +-----------+    +-------+    +-------+    +----+   +---+   +------+


                ----------------------------------->
                    computeContextHandoff(CLWrapComputeContext)
                                                       ------->
                                                                ------->
                                                                compileKernels()
                                                                <------
                                                       mutateComputeModels
                                                       <-------
                    dispatchCompute(CLWrapComputeContext, args)
                                                       ------->
                                                       dispatchCompute(...)
                                                                --------->
                                                                {
                                                                    dispatchKernel()
                                                                    ...
                                                                }
                                                                <--------
                                                       <------
                <----------------------------------

```

----
### Notes

In reality.
The Accelerator receives a `QuotableComputeContextConsumer`:

```java
public interface QuotableComputeContextConsumer
        extends Quotable,
        Consumer<ComputeContext> {
}
```
Here is how we extract the 'target' from such a lambda:

```java
public void compute(QuotableComputeContextConsumer qccc) {
    Quoted quoted = Op.ofQuotable(qccc).orElseThrow();
    LambdaOpWrapper lambda = OpTools.wrap((CoreOps.LambdaOp)quoted.op());

    Method method = lambda.getQuotableComputeContextTargetMethod();

    // Get from the cache, or create, a compute context which closes over the compute entrypoint
    // and reachable kernels.
    // The models of all compute and kernel methods are passed to the backend during creation.
    // The backend may well mutate the models.
    // It will also use this opportunity to generate ISA specific code for the kernels.

    ComputeContext CLWrapComputeContext = this.cache.computeIfAbsent(method, (_) ->
        new ComputeContext(this/*Accelerator*/, method)
    );

    // Here we get the captured args from the Quotable and 'jam' in the CLWrapComputeContext in slot[0]
    Object[] args = lambda.getQuotableComputeContextArgs(quoted, method, CLWrapComputeContext);
    this.compute(CLWrapComputeContext, args);
}
```
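
The two mechanical steps in `compute()` above, caching one context per entrypoint `Method` and placing that context in slot 0 ahead of the captured args, can be pictured with plain JDK classes. This is a minimal sketch under stated assumptions: `ComputeCacheSketch`, its nested `Context` record, `contextFor` and `withContextInSlot0` are hypothetical names invented for illustration and are not HAT's classes or signatures.

```java
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-ins; none of these types are HAT's real classes.
final class ComputeCacheSketch {

    // Placeholder for the expensive per-entrypoint state (reachable graphs, compiled kernels).
    record Context(Method entrypoint) {}

    private final Map<Method, Context> cache = new HashMap<>();
    int builds = 0; // counts how many times a Context was actually constructed

    // Build the context on first use of an entrypoint and reuse it afterwards.
    Context contextFor(Method entrypoint) {
        return cache.computeIfAbsent(entrypoint, m -> {
            builds++;
            return new Context(m);
        });
    }

    // Slot 0 receives the context; the captured args follow in their original order.
    static Object[] withContextInSlot0(Context context, Object... captured) {
        Object[] args = new Object[captured.length + 1];
        args[0] = context;
        System.arraycopy(captured, 0, args, 1, captured.length);
        return args;
    }

    public static void main(String[] unused) throws NoSuchMethodException {
        ComputeCacheSketch sketch = new ComputeCacheSketch();
        Method entrypoint = String.class.getMethod("length"); // any Method works as a key
        Context first = sketch.contextFor(entrypoint);
        Context second = sketch.contextFor(entrypoint);       // served from the cache
        Object[] args = withContextInSlot0(first, "s32Array");
        System.out.println((first == second) + " " + sketch.builds + " " + args.length);
    }
}
```

Keying the cache by `Method` means repeated `accelerator.compute(...)` calls that target the same entrypoint pay the graph-building and kernel-compilation cost only once, while the captured args are re-read on every call.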