# What happens when we call accelerator.compute(lambda)

----
* [Contents](hat-00.md)
* Build Babylon and HAT
    * [Quick Install](hat-01-quick-install.md)
    * [Building Babylon with jtreg](hat-01-02-building-babylon.md)
    * [Building HAT with jtreg](hat-01-03-building-hat.md)
        * [Enabling the NVIDIA CUDA Backend](hat-01-05-building-hat-for-cuda.md)
* [Testing Framework](hat-02-testing-framework.md)
* [Running Examples](hat-03-examples.md)
* [HAT Programming Model](hat-03-programming-model.md)
* Interface Mapping
    * [Interface Mapping Overview](hat-04-01-interface-mapping.md)
    * [Cascade Interface Mapping](hat-04-02-cascade-interface-mapping.md)
* Development
    * [Project Layout](hat-01-01-project-layout.md)
    * [IntelliJ Code Formatter](hat-development.md)
* Implementation Details
    * [Walkthrough Of Accelerator.compute()](hat-accelerator-compute.md)
    * [How we minimize buffer transfers](hat-minimizing-buffer-transfers.md)
* [Running HAT with Docker on NVIDIA GPUs](hat-07-docker-build-nvidia.md)
---

## Back to our Squares example

So what is going on here?

```java
accelerator.compute(
    cc -> SquareCompute.square(cc, s32Array)
);
```

Recall that we have two kinds of code in our compute class: kernels (plus kernel-reachable methods) and compute entrypoints (plus compute-reachable methods).

```java
public class SquareCompute {
    @Reflect public static int square(int v) {
        return v * v;
    }

    @Reflect public static void squareKernel(KernelContext kc, S32Array s32Array) {
        int value = s32Array.array(kc.x);     // arr[kc.x]
        s32Array.array(kc.x, square(value));  // arr[kc.x] = value * value
    }

    @Reflect public static void square(ComputeContext cc, S32Array s32Array) {
        cc.dispatchKernel(s32Array.length(),
                kc -> squareKernel(kc, s32Array)
        );
    }
}
```
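
For reference, here is the plain sequential Java that this kernel dispatch is equivalent to; the loop index simply plays the role of `kc.x`:

```java
import java.util.Arrays;

// Sequential Java equivalent of dispatching squareKernel over
// indices 0..length-1; x plays the role of kc.x.
public class SquareSequential {
    public static void squareInPlace(int[] arr) {
        for (int x = 0; x < arr.length; x++) {
            arr[x] = arr[x] * arr[x];
        }
    }

    public static void main(String[] args) {
        int[] data = {1, 2, 3, 4};
        squareInPlace(data);
        System.out.println(Arrays.toString(data)); // [1, 4, 9, 16]
    }
}
```

On an accelerator, of course, the iterations may execute in parallel rather than in loop order.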

Again, note that we cannot just call the compute entrypoint or the kernel directly.

```java
  SquareCompute.square(????, s32Array);  // We can't do this!
```

We purposely make it inconvenient (ComputeContext and KernelContext construction is embedded in the framework) to mistakenly call the compute entrypoint directly. Doing so is akin to calling `Thread.run()` directly, rather than calling `Thread.start()` on a class extending `Thread` and providing an implementation of `Thread.run()`.

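The analogy holds up in plain Java: `run()` executes the body immediately on the caller's thread, whereas `start()` hands it to the JVM scheduler, just as `accelerator.compute()` hands the lambda to HAT rather than invoking it itself.

```java
// run() vs start(): run() executes on the calling thread; start()
// hands the body to the JVM, which runs it on a new thread.
public class RunVsStart {
    public static void main(String[] args) throws InterruptedException {
        Thread direct = new Thread(() ->
                System.out.println("run()   on " + Thread.currentThread().getName()));
        direct.run();      // executes here, on "main" -- the mistake the analogy warns about

        Thread scheduled = new Thread(() ->
                System.out.println("start() on " + Thread.currentThread().getName()));
        scheduled.start(); // executed by the JVM on a new thread
        scheduled.join();
    }
}
```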
Instead we use this pattern:

```java
accelerator.compute(
    cc -> SquareCompute.square(cc, s32Array)
);
```

We pass a lambda to `accelerator.compute()`, which is used to determine which compute method to invoke.

```
 User  |  Accelerator  |  Compute  |  Babylon  |        Backend            |
                          Context                 Java     C++     Vendor
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+
|    |   |           |   |       |   |       |   |    |   |   |   |      |
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+
    +--------> accelerator.compute(lambda)

```

Incidentally, this lambda is never executed by the JVM. Instead, the accelerator uses Babylon's code reflection capabilities to extract the code model of the lambda, from which it determines the compute entrypoint and its captured args.
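
Babylon's code models give a full symbolic view of the lambda's body. As a rough plain-Java analogy only (this is not what HAT actually does), a *serializable* lambda can already be introspected for its implementation method and the values it captured from the call site:

```java
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;

// Rough analogy only: a serializable lambda exposes its implementation
// method name and its captured call-site values via writeReplace().
// Babylon's code reflection provides a far richer model of the body itself.
public class CaptureIntrospect {
    interface SerializableRunnable extends Runnable, Serializable {}

    static SerializedLambda serialized(Serializable lambda) throws Exception {
        var m = lambda.getClass().getDeclaredMethod("writeReplace");
        m.setAccessible(true);
        return (SerializedLambda) m.invoke(lambda);
    }

    public static void main(String[] args) throws Exception {
        String s32Array = "captured buffer";          // stands in for the S32Array
        SerializableRunnable r = () -> System.out.println(s32Array);
        SerializedLambda sl = serialized(r);
        System.out.println(sl.getCapturedArgCount()); // 1
        System.out.println(sl.getCapturedArg(0));     // captured buffer
    }
}
```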

```
 User  |  Accelerator  |  Compute  |  Babylon  |        Backend            |
                          Context                 Java     C++     Vendor
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+
|    |   |           |   |       |   |       |   |    |   |   |   |      |
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+
    +--------> accelerator.compute( cc -> SquareCompute.square(cc, s32Array) )
                ------------------------->
                    getModelOf(lambda)
                <------------------------
```

This model describes the call that we want the accelerator to execute or interpret (`SquareCompute.square()`) and the args that were captured from the call site (the `s32Array` buffer).

The accelerator uses Babylon again to get the code model of `SquareCompute.square()` and builds a ComputeReachableGraph with this method at the root: it walks the code model and collects the methods (and code models) of all methods reachable from the entrypoint.

In our trivial case, the ComputeReachableGraph has a single root node representing `SquareCompute.square()`.
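
The closure computation can be sketched as a worklist algorithm. The names and the call table below are illustrative; the real implementation discovers callees by walking the invoke operations in each method's Babylon code model.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Illustrative worklist closure: start at an entrypoint and keep adding
// callees until no new methods appear. In HAT the "call table" comes from
// walking invoke ops in each method's Babylon code model.
public class ReachableSketch {
    static Set<String> reachableFrom(String entry, Map<String, List<String>> callees) {
        Set<String> seen = new LinkedHashSet<>();
        Deque<String> work = new ArrayDeque<>(List.of(entry));
        while (!work.isEmpty()) {
            String m = work.pop();
            if (seen.add(m)) {                        // first visit: enqueue its callees
                work.addAll(callees.getOrDefault(m, List.of()));
            }
        }
        return seen;
    }

    public static void main(String[] args) {
        Map<String, List<String>> callees = Map.of(
                "SquareCompute.square(cc, s32Array)", List.of("cc.dispatchKernel"));
        System.out.println(reachableFrom("SquareCompute.square(cc, s32Array)", callees));
    }
}
```

Because visited methods are never re-enqueued, the walk terminates even when methods call each other recursively.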

```
 User  |  Accelerator  |  Compute  |  Babylon  |        Backend            |
                          Context                 Java     C++     Vendor
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+
|    |   |           |   |       |   |       |   |    |   |   |   |      |
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+
    +--------> accelerator.compute( cc -> SquareCompute.square(cc, s32Array) )
                ------------------------->
                     getModelOf(lambda)
                <------------------------
                ------------------------->
                     getModelOf(SquareCompute.square())
                <-------------------------
          forEachReachable method in SquareCompute.square() {
                ------------------------->
                     getModelOf(method)
                <------------------------
                add to ComputeReachableGraph
          }
```

The accelerator then walks the ComputeReachableGraph to determine which kernels are referenced.

For each kernel we extract the kernel's entrypoint (again as a Babylon code model) and create a KernelReachableGraph, again starting at the kernel entrypoint and closing over all reachable methods (and code models).

We combine the compute and kernel reachable graphs and place them in a `ComputeContext`.

This is the first arg that is 'seemingly' passed to the compute method. Remember, the compute entrypoint is just a model of the code we expect to execute; it may never be executed by the JVM.
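
Schematically, the ComputeContext closes over one compute-reachable graph plus one kernel-reachable graph per dispatched kernel. The type and field names below are illustrative, not the actual HAT API:

```java
import java.util.List;
import java.util.Set;

// Illustrative shape only -- the real ComputeContext holds Babylon code
// models, backend handles, and more, not plain strings.
record ComputeReachableGraphSketch(String entrypoint, Set<String> reachable) {}

record KernelReachableGraphSketch(String kernelEntrypoint, Set<String> reachable) {}

record ComputeContextSketch(ComputeReachableGraphSketch compute,
                            List<KernelReachableGraphSketch> kernels) {}
```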

```
 User  |  Accelerator  |  Compute  |  Babylon  |        Backend            |
                          Context                 Java     C++     Vendor
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+
|    |   |           |   |       |   |       |   |    |   |   |   |      |
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+

          forEachReachable kernel in ComputeReachableGraph {
                ------------------------->
                      getModelOf(kernel)
                <------------------------
                add to KernelReachableGraph
          }
          ComputeContext = {ComputeReachableGraph + KernelReachableGraph}

```

The accelerator passes the ComputeContext to the backend (`computeContextHandoff()`), which will typically take the opportunity to inspect or mutate the compute and kernel models, and possibly build backend-specific representations of the kernels and compile them.

The ComputeContext and the captured args are then passed to the backend for execution.
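
The two backend entry points in this flow can be sketched as an interface. This is a simplified stand-in, not the actual HAT Backend API; a recording implementation makes the one-time handoff vs per-call dispatch split visible.

```java
import java.util.ArrayList;
import java.util.List;

// Simplified sketch of the backend contract implied by the sequence
// diagrams: a one-time handoff of the ComputeContext, then per-call
// dispatch with the captured args. The real HAT Backend API is richer,
// and the ComputeContext is a real object, not a String.
interface BackendSketch {
    void computeContextHandoff(String computeContext); // inspect/mutate models, compile kernels
    void dispatchCompute(String computeContext, Object... args);
}

// A backend that just records the order of calls, for illustration.
class RecordingBackend implements BackendSketch {
    final List<String> calls = new ArrayList<>();
    public void computeContextHandoff(String cc) { calls.add("handoff:" + cc); }
    public void dispatchCompute(String cc, Object... args) { calls.add("dispatch:" + cc); }
}
```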

```
 User  |  Accelerator  |  Compute  |  Babylon  |        Backend            |
                          Context                 Java     C++     Vendor
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+
|    |   |           |   |       |   |       |   |    |   |   |   |      |
+----+   +-----------+   +-------+   +-------+   +----+   +---+   +------+


                ----------------------------------->
                    computeContextHandoff(CLWrapComputeContext)
                                                    ------->
                                                             ------->
                                                         compileKernels()
                                                             <------
                                                      mutateComputeModels
                                                    <-------
                    dispatchCompute(CLWrapComputeContext, args)
                                                    ------->
                                                        dispatchCompute(...)
                                                            --------->
                                                               {
                                                               dispatchKernel()
                                                               ...
                                                               }
                                                            <--------
                                                    <------
                <----------------------------------

```

----
### Notes

In reality, the accelerator receives a `Compute`:

```java
public interface Compute extends Consumer<ComputeContext> {
}
```

Here is how we extract the 'target' from such a lambda:

```java
public void compute(Compute compute) {
    Quoted<JavaOp.LambdaOp> quoted = Op.ofLambda(compute).orElseThrow();
    JavaOp.LambdaOp lambda = quoted.op();
    Method method = getTargetInvoke(this.lookup, lambda, ComputeContext.class).resolveMethodOrThrow();
    // Create (or get a cached) compute context, which closes over the compute
    // entrypoint and its reachable kernels.
    // The models of all compute and kernel methods are passed to the backend during creation.
    // The backend may well mutate the models.
    // It will also use this opportunity to generate ISA-specific code for the kernels.
    ComputeContext computeContext = cache.computeIfAbsent(method, (_) -> new ComputeContext(this, method));
    // Here we get the captured values from the lambda.
    Object[] args = lambda(lookup, lambda).getQuotedCapturedValues(quoted, method);
    args[0] = computeContext;
    // Now ask the backend to execute.
    backend.dispatchCompute(computeContext, args);
}
```