1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat;
 26 
 27 import hat.callgraph.ComputeEntrypoint;
 28 import jdk.incubator.code.bytecode.BytecodeGenerator;
 29 import jdk.incubator.code.interpreter.Interpreter;
 30 import optkl.util.carriers.ArenaAndLookupCarrier;
 31 import optkl.util.carriers.ArenaCarrier;
 32 import optkl.util.carriers.LookupCarrier;
 33 import optkl.ifacemapper.BufferTracker;
 34 import hat.callgraph.ComputeCallGraph;
 35 import hat.callgraph.KernelCallGraph;
 36 import optkl.ifacemapper.MappableIface;
 37 import jdk.incubator.code.dialect.core.CoreOp.FuncOp;
 38 import jdk.incubator.code.Reflect;
 39 import jdk.incubator.code.Op;
 40 import jdk.incubator.code.Quoted;
 41 import jdk.incubator.code.dialect.java.JavaOp;
 42 import jdk.incubator.code.dialect.java.MethodRef;
 43 
 44 import java.lang.foreign.Arena;
 45 import java.lang.invoke.MethodHandles;
 46 import java.lang.reflect.Method;
 47 import java.util.HashMap;
 48 import java.util.Map;
 49 import java.util.function.Consumer;
 50 import java.util.Optional;
 51 
 52 import static optkl.OpHelper.Invoke.getTargetInvoke;
 53 import static optkl.OpHelper.Lambda.lambda;
 54 
/**
 * A ComputeContext is created by an Accelerator to capture and control the compute and kernel
 * call graphs for the work to be performed by the backend.
 * <p>
 * The compute closure is created first, by walking the code model of the entrypoint and then transitively
 * visiting all conventional code reachable from that entrypoint.
 * <p>
 * Generally, all user-defined methods reachable from the entrypoint (and the entrypoint itself) must be static
 * methods of the same enclosing class.
 * <p>
 * Calls on the ComputeContext itself, and on the mapped interface buffers holding non-uniform kernel data, are
 * also allowed.
 * <p>
 * Each kernel dispatched from the compute graph has its own kernel call graph, with the dispatched kernel as its
 * entrypoint.
 * <p>
 * Once the ComputeContext has been constructed, it is handed to the backend via
 * {@code Backend.computeContextHandoff(ComputeContext)}.
 *
 * @author Gary Frost
 */
 74 public class ComputeContext implements ArenaAndLookupCarrier, BufferTracker {
 75 
 76 
 77     @Override
 78     public Arena arena() {
 79         return accelerator.arena();
 80     }
 81 
 82     @Override
 83     public MethodHandles.Lookup lookup() {
 84         return accelerator.lookup();
 85     }
 86 
 87     public ComputeEntrypoint computeEntrypoint() {
 88         return computeCallGraph.entrypoint;
 89     }
 90 
 91     public Config config() {
 92         return accelerator().config();
 93     }
 94 
 95     public void invokeWithArgs(Object[] args) {
 96         computeEntrypoint().invokeWithArgs(args);
 97 
 98     }
 99 
100     public void interpretWithArgs(Object[] args) {
101         computeEntrypoint().interpretWithArgs( args);
102     }
103 
104     public enum WRAPPER {
105         MUTATE("Mutate"), ACCESS("Access");
106         final public MethodRef pre;
107         final public MethodRef post;
108 
109         WRAPPER(String name) {
110             this.pre = MethodRef.method(ComputeContext.class, "pre" + name, void.class, MappableIface.class);
111             this.post = MethodRef.method(ComputeContext.class, "post" + name, void.class, MappableIface.class);
112         }
113     }
114 
115     private  final Accelerator accelerator;
116     final  public  Accelerator accelerator(){
117         return accelerator;
118     }
119 
120     private  final ComputeCallGraph computeCallGraph;
121     final  public  ComputeCallGraph computeCallGraph(){
122         return computeCallGraph;
123     }
124 
125 
126 
127     /**
128      * Called by the Accelerator when the accelerator is passed a compute entrypoint.
129      * <p>
130      * So given a ComputeClass such as..
131      * <pre>
132      *  public class MyComputeClass {
133      *    @ Reflect
134      *    public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
135      *        arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
136      *    }
137      *
138      *    @ Reflect
139      *    static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
140      *        cc.dispatchKernel(KernelContext kc -> addDeltaKernel(kc,arrayOfInt.length(), 5, arrayOfInt);
141      *    }
142      *  }
143      *  </pre>
144      *
145      * @param accelerator
146      * @param computeMethod
147      */
148 
149     protected ComputeContext(Accelerator accelerator, Method computeMethod) {
150         this.accelerator = accelerator;
151         Optional<FuncOp> funcOp =  Op.ofMethod(computeMethod);
152         if (funcOp.isEmpty()) {
153             throw new RuntimeException("Failed to create ComputeCallGraph (did you miss @Reflect annotation?).");
154         }
155         this.computeCallGraph = new ComputeCallGraph(this, computeMethod, funcOp.get());
156         this.accelerator.backend.computeContextHandoff(this);
157     }
158     record KernelCallSite(Quoted<JavaOp.LambdaOp> quoted, JavaOp.LambdaOp lambdaOp, MethodRef methodRef, KernelCallGraph kernelCallGraph) {}
159 
160     private Map<Op.Location, KernelCallSite> kernelCallSiteCache = new HashMap<>();
161 
162     /** Creating the kernel callsite involves
163          walking the code model of the lambda
164          analysing the callgraph and trsnsforming to HATDielect
165      So we cache the callsite against the location from the lambdaop.
166      */
167     public void dispatchKernel(NDRange<?, ?> ndRange, Kernel kernel) {
168         Quoted<JavaOp.LambdaOp> quoted = Op.ofLambda(kernel).orElseThrow();
169 
170         var location = quoted.op().location();
171 
172         KernelCallSite kernelCallSite;
173         if (kernelCallSiteCache.containsKey(location)) {
174             var oldKernelCallSite = kernelCallSiteCache.get(location);
175             kernelCallSite = new KernelCallSite(quoted, oldKernelCallSite.lambdaOp(), oldKernelCallSite.methodRef(), oldKernelCallSite.kernelCallGraph());
176         } else {
177             kernelCallSite = kernelCallSiteCache.compute(location, (_, _)-> {
178                 JavaOp.LambdaOp lambdaOp = quoted.op();
179                 MethodRef methodRef = getTargetInvoke(this.lookup(), lambdaOp, KernelContext.class).op().invokeDescriptor();
180                 KernelCallGraph kernelCallGraph = computeCallGraph.kernelCallGraphMap.get(methodRef);
181                 if (kernelCallGraph == null) {
182                     throw new RuntimeException("Failed to create KernelCallGraph (did you miss @Reflect annotation?).");
183                 }
184                 return new KernelCallSite(quoted, lambdaOp, methodRef, kernelCallGraph);
185             });
186         }
187         Object[] args = lambda(lookup(),kernelCallSite.lambdaOp).getQuotedCapturedValues(kernelCallSite.quoted, kernelCallSite.kernelCallGraph.entrypoint.method());
188         KernelContext kernelContext = accelerator.range(ndRange);
189         args[0] = kernelContext;
190         accelerator.backend.dispatchKernel(kernelCallSite.kernelCallGraph, kernelContext, args);
191     }
192 
193 
194     @Override
195     public void preMutate(MappableIface b) {
196         if (accelerator.backend instanceof BufferTracker bufferTracker) {
197             bufferTracker.preMutate(b);
198         }
199     }
200 
201     @Override
202     public void postMutate(MappableIface b) {
203         if (accelerator.backend instanceof BufferTracker bufferTracker) {
204             bufferTracker.postMutate(b);
205         }
206 
207     }
208 
209     @Override
210     public void preAccess(MappableIface b) {
211         if (accelerator.backend instanceof BufferTracker bufferTracker) {
212             bufferTracker.preAccess(b);
213         }
214 
215     }
216 
217     @Override
218     public void postAccess(MappableIface b) {
219         if (accelerator.backend instanceof BufferTracker bufferTracker) {
220             bufferTracker.postAccess(b);
221         }
222     }
223 
224     @Reflect
225     @FunctionalInterface
226     public interface Kernel extends Consumer<KernelContext> { }
227 
228 }