1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat;
 26 
 27 import hat.buffer.Buffer;
 28 import hat.buffer.BufferAllocator;
 29 import hat.buffer.BufferTracker;
 30 import hat.callgraph.ComputeCallGraph;
 31 import hat.callgraph.KernelCallGraph;
 32 import hat.ifacemapper.BoundSchema;
 33 import hat.ifacemapper.SegmentMapper;
 34 import hat.optools.OpTk;
 35 import jdk.incubator.code.Op;
 36 import jdk.incubator.code.Quotable;
 37 import jdk.incubator.code.Quoted;
 38 import jdk.incubator.code.dialect.java.JavaOp;
 39 import jdk.incubator.code.dialect.java.MethodRef;
 40 
 41 import java.lang.reflect.Method;
 42 import java.util.function.Consumer;
 43 
 44 /**
 45  * A ComputeContext is created by an Accelerator to capture and control compute and kernel
 46  * callgraphs for the work to be performed by the backend.
 47  * <p/>
 48  * The Compute closure is created first, by walking the code model of the entrypoint, then transitively
 49  * visiting all conventional code reachable from this entrypoint.
 50  * <p/>
 51  * Generally all user defined methods reachable from the entrypoint (and the entrypoint intself) must be static methods of the same
 52  * enclosing classes.
 53  * <p/>
 54  * We do allow calls on the ComputeContext itself, and on the mapped interface buffers holding non uniform kernel data.
 55  * <p/>
 56  * Each request to dispatch a kernel discovered in the compute graph, results in a new Kernel call graph
 57  * being created with the dispatched kernel as it's entrypoint.
 58  * <p/>
 59  * When the ComputeContext is finalized, it is passed to the backend via <a href="Backend.computeClosureHandoff(ComputeContext)"></a>
 60  *
 61  * @author Gary Frost
 62  */
 63 public class ComputeContext implements BufferAllocator, BufferTracker {
 64 
 65 
 66     public enum WRAPPER {
 67         MUTATE("Mutate"), ACCESS("Access");//, ESCAPE("Escape");
 68         final public MethodRef pre;
 69         final public MethodRef post;
 70 
 71         WRAPPER(String name) {
 72             this.pre = MethodRef.method(ComputeContext.class, "pre" + name, void.class, Buffer.class);
 73             this.post = MethodRef.method(ComputeContext.class, "post" + name, void.class, Buffer.class);
 74         }
 75     }
 76 
 77     public final Accelerator accelerator;
 78 
 79 
 80     public final ComputeCallGraph computeCallGraph;
 81 
 82     /**
 83      * Called by the Accelerator when the accelerator is passed a compute entrypoint.
 84      * <p>
 85      * So given a ComputeClass such as..
 86      * <pre>
 87      *  public class MyComputeClass {
 88      *    @ CodeReflection
 89      *    public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
 90      *        arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
 91      *    }
 92      *
 93      *    @ CodeReflection
 94      *    static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
 95      *        cc.dispatchKernel(KernelContext kc -> addDeltaKernel(kc,arrayOfInt.length(), 5, arrayOfInt);
 96      *    }
 97      *  }
 98      *  </pre>
 99      *
100      * @param accelerator
101      * @param computeMethod
102      */
103 
104     protected ComputeContext(Accelerator accelerator, Method computeMethod) {
105         this.accelerator = accelerator;
106         this.computeCallGraph = new ComputeCallGraph(this, computeMethod, Op.ofMethod(computeMethod).orElseThrow());
107         this.accelerator.backend.computeContextHandoff(this);
108     }
109 
110     public void dispatchKernel(NDRange ndRange, QuotableKernelContextConsumer quotableKernelContextConsumer) {
111         dispatchKernelWithComputeRange(ndRange, quotableKernelContextConsumer);
112     }
113 
114     record CallGraph(Quoted quoted, JavaOp.LambdaOp lambdaOp, MethodRef methodRef, KernelCallGraph kernelCallGraph) {}
115 
116     private CallGraph getKernelCallGraph(QuotableKernelContextConsumer quotableKernelContextConsumer) {
117         Quoted quoted = Op.ofQuotable(quotableKernelContextConsumer).orElseThrow();
118         JavaOp.LambdaOp lambdaOp = (JavaOp.LambdaOp) quoted.op();
119         MethodRef methodRef = OpTk.getQuotableTargetInvokeOpWrapper( lambdaOp).invokeDescriptor();
120         KernelCallGraph kernelCallGraph = computeCallGraph.kernelCallGraphMap.get(methodRef);
121         if (kernelCallGraph == null){
122             throw new RuntimeException("Failed to create KernelCallGraph (did you miss @CodeReflection annotation?) ");
123         }
124         return new CallGraph(quoted, lambdaOp, methodRef, kernelCallGraph);
125     }
126 
127     private void dispatchKernelWithComputeRange(NDRange ndRange, QuotableKernelContextConsumer quotableKernelContextConsumer) {
128         CallGraph cg = getKernelCallGraph(quotableKernelContextConsumer);
129         try {
130             Object[] args = OpTk.getQuotableCapturedValues(cg.lambdaOp,cg.quoted, cg.kernelCallGraph.entrypoint.method);
131             KernelContext kernelContext = accelerator.range(ndRange);
132             args[0] = kernelContext;
133             accelerator.backend.dispatchKernel(cg.kernelCallGraph, kernelContext, args);
134         } catch (Throwable t) {
135             System.out.print("what?" + cg.methodRef + " " + t);
136             t.printStackTrace();
137             throw t;
138         }
139     }
140 
141     @Override
142     public void preMutate(Buffer b) {
143         if (accelerator.backend instanceof BufferTracker bufferTracker) {
144             bufferTracker.preMutate(b);
145         }
146     }
147 
148     @Override
149     public void postMutate(Buffer b) {
150         if (accelerator.backend instanceof BufferTracker bufferTracker) {
151             bufferTracker.postMutate(b);
152         }
153 
154     }
155 
156     @Override
157     public void preAccess(Buffer b) {
158         if (accelerator.backend instanceof BufferTracker bufferTracker) {
159             bufferTracker.preAccess(b);
160         }
161 
162     }
163 
164     @Override
165     public void postAccess(Buffer b) {
166         if (accelerator.backend instanceof BufferTracker bufferTracker) {
167             bufferTracker.postAccess(b);
168         }
169     }
170 
171     @Override
172     public <T extends Buffer> T allocate(SegmentMapper<T> segmentMapper, BoundSchema<T> boundSchema) {
173         return accelerator.allocate(segmentMapper, boundSchema);
174     }
175 
176     public interface QuotableKernelContextConsumer extends Quotable, Consumer<KernelContext> { }
177 
178 }