1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat;
 26 
 27 import optkl.util.carriers.ArenaAndLookupCarrier;
 28 import optkl.ifacemapper.BufferTracker;
 29 import hat.callgraph.ComputeCallGraph;
 30 import hat.callgraph.KernelCallGraph;
 31 import optkl.ifacemapper.MappableIface;
 32 import jdk.incubator.code.dialect.core.CoreOp.FuncOp;
 33 import jdk.incubator.code.Reflect;
 34 import jdk.incubator.code.Op;
 35 import jdk.incubator.code.Quoted;
 36 import jdk.incubator.code.dialect.java.JavaOp;
 37 import jdk.incubator.code.dialect.java.MethodRef;
 38 
 39 import java.lang.foreign.Arena;
 40 import java.lang.invoke.MethodHandles;
 41 import java.lang.reflect.Method;
 42 import java.util.HashMap;
 43 import java.util.Map;
 44 import java.util.function.Consumer;
 45 import java.util.Optional;
 46 
 47 import static optkl.OpHelper.Invoke.getTargetInvoke;
 48 import static optkl.OpHelper.Lambda.lambda;
 49 
 50 /**
 51  * A ComputeContext is created by an Accelerator to capture and control compute and kernel
 52  * callgraphs for the work to be performed by the backend.
 53  * <p/>
 54  * The Compute closure is created first, by walking the code model of the entrypoint, then transitively
 55  * visiting all conventional code reachable from this entrypoint.
 56  * <p/>
 57  * Generally all user defined methods reachable from the entrypoint (and the entrypoint itself) must be static methods of the same
 58  * enclosing classes.
 59  * <p/>
 60  * We do allow calls on the ComputeContext itself, and on the mapped interface buffers holding non uniform kernel data.
 61  * <p/>
 62  * Each request to dispatch a kernel discovered in the compute graph, results in a new Kernel call graph
 63  * being created with the dispatched kernel as it's entrypoint.
 64  * <p/>
 65  * When the ComputeContext is finalized, it is passed to the backend via <a href="Backend.computeClosureHandoff(ComputeContext)"></a>
 66  *
 67  * @author Gary Frost
 68  */
 69 public class ComputeContext implements ArenaAndLookupCarrier, BufferTracker {
 70 
 71 
 72     @Override
 73     public Arena arena() {
 74         return accelerator.arena();
 75     }
 76 
 77     @Override
 78     public MethodHandles.Lookup lookup() {
 79         return accelerator.lookup();
 80     }
 81 
 82 
 83     public Config config() {
 84         return accelerator().config();
 85     }
 86 
 87     public void invokeWithArgs(Object[] args) {
 88         computeCallGraph.invokeWithArgs(args);
 89 
 90     }
 91 
 92     public enum WRAPPER {
 93         MUTATE("Mutate"), ACCESS("Access");
 94         final public MethodRef pre;
 95         final public MethodRef post;
 96 
 97         WRAPPER(String name) {
 98             this.pre = MethodRef.method(ComputeContext.class, "pre" + name, void.class, MappableIface.class);
 99             this.post = MethodRef.method(ComputeContext.class, "post" + name, void.class, MappableIface.class);
100         }
101     }
102 
103     private  final Accelerator accelerator;
104     final  public  Accelerator accelerator(){
105         return accelerator;
106     }
107 
108     private  final ComputeCallGraph computeCallGraph;
109     final  public  ComputeCallGraph computeCallGraph(){
110         return computeCallGraph;
111     }
112 
113 
114 
115     /**
116      * Called by the Accelerator when the accelerator is passed a compute entrypoint.
117      * <p>
118      * So given a ComputeClass such as..
119      * <pre>
120      *  public class MyComputeClass {
121      *    @ Reflect
122      *    public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
123      *        arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
124      *    }
125      *
126      *    @ Reflect
127      *    static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
128      *        cc.dispatchKernel(KernelContext kc -> addDeltaKernel(kc,arrayOfInt.length(), 5, arrayOfInt);
129      *    }
130      *  }
131      *  </pre>
132      *
133      * @param accelerator
134      * @param computeMethod
135      */
136 
137     protected ComputeContext(Accelerator accelerator, Method computeMethod) {
138         this.accelerator = accelerator;
139         Optional<FuncOp> funcOp =  Op.ofMethod(computeMethod);
140         if (funcOp.isEmpty()) {
141             throw new RuntimeException("Failed to create ComputeCallGraph (did you miss @Reflect annotation?).");
142         }
143         this.computeCallGraph = new ComputeCallGraph(this, computeMethod, funcOp.get());
144         this.accelerator.backend.computeContextHandoff(this);
145     }
146     record KernelCallSite(Quoted<JavaOp.LambdaOp> quoted, JavaOp.LambdaOp lambdaOp, MethodRef methodRef, KernelCallGraph kernelCallGraph) {}
147 
148     private final Map<Op.Location, KernelCallSite> kernelCallSiteCache = new HashMap<>();
149 
150     /** Creating the kernel callsite involves
151          walking the code model of the lambda
152          analysing the callgraph and trsnsforming to HATDielect
153      So we cache the callsite against the location from the lambdaop.
154      */
155     public void dispatchKernel(NDRange ndRange, Kernel kernel) {
156         Quoted<JavaOp.LambdaOp> quoted = Op.ofLambda(kernel).orElseThrow();
157 
158         var location = quoted.op().location();
159 
160         KernelCallSite kernelCallSite;
161         if (kernelCallSiteCache.containsKey(location)) {
162             var oldKernelCallSite = kernelCallSiteCache.get(location);
163             kernelCallSite = new KernelCallSite(quoted, oldKernelCallSite.lambdaOp(), oldKernelCallSite.methodRef(), oldKernelCallSite.kernelCallGraph());
164         } else {
165             kernelCallSite = kernelCallSiteCache.compute(location, (_, _)-> {
166                 JavaOp.LambdaOp lambdaOp = quoted.op();
167                 MethodRef methodRef = getTargetInvoke(this.lookup(), lambdaOp, KernelContext.class).op().invokeReference();
168                 KernelCallGraph kernelCallGraph = computeCallGraph.kernelCallGraphMap.get(methodRef);
169                 if (kernelCallGraph == null) {
170                     throw new RuntimeException("Failed to create KernelCallGraph (did you miss @Reflect annotation?).");
171                 }
172                 return new KernelCallSite(quoted, lambdaOp, methodRef, kernelCallGraph);
173             });
174         }
175         Object[] args = lambda(lookup(),kernelCallSite.lambdaOp).getQuotedCapturedValues(kernelCallSite.quoted, kernelCallSite.kernelCallGraph.callDag.entryPoint.method());
176         KernelContext kernelContext = accelerator.range(ndRange);
177         args[0] = kernelContext;
178         accelerator.backend.dispatchKernel(kernelCallSite.kernelCallGraph, kernelContext, args);
179     }
180 
181 
182     @Override
183     public void preMutate(MappableIface b) {
184         if (accelerator.backend instanceof BufferTracker bufferTracker) {
185             bufferTracker.preMutate(b);
186         }
187     }
188 
189     @Override
190     public void postMutate(MappableIface b) {
191         if (accelerator.backend instanceof BufferTracker bufferTracker) {
192             bufferTracker.postMutate(b);
193         }
194 
195     }
196 
197     @Override
198     public void preAccess(MappableIface b) {
199         if (accelerator.backend instanceof BufferTracker bufferTracker) {
200             bufferTracker.preAccess(b);
201         }
202 
203     }
204 
205     @Override
206     public void postAccess(MappableIface b) {
207         if (accelerator.backend instanceof BufferTracker bufferTracker) {
208             bufferTracker.postAccess(b);
209         }
210     }
211 
212     @Reflect
213     @FunctionalInterface
214     public interface Kernel extends Consumer<KernelContext> { }
215 
216 }