1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat;
26
27 import hat.buffer.Buffer;
28 import hat.buffer.BufferAllocator;
29 import hat.buffer.BufferTracker;
30 import hat.callgraph.ComputeCallGraph;
31 import hat.callgraph.KernelCallGraph;
32 import hat.ifacemapper.BoundSchema;
33 import hat.ifacemapper.SegmentMapper;
34 import hat.optools.OpTk;
35 import jdk.incubator.code.Op;
36 import jdk.incubator.code.Quotable;
37 import jdk.incubator.code.Quoted;
38 import jdk.incubator.code.dialect.java.JavaOp;
39 import jdk.incubator.code.dialect.java.MethodRef;
40
41 import java.lang.reflect.Method;
42 import java.util.function.Consumer;
43
44 /**
45 * A ComputeContext is created by an Accelerator to capture and control compute and kernel
46 * callgraphs for the work to be performed by the backend.
47 * <p/>
48 * The Compute closure is created first, by walking the code model of the entrypoint, then transitively
49 * visiting all conventional code reachable from this entrypoint.
50 * <p/>
 * Generally all user defined methods reachable from the entrypoint (and the entrypoint itself) must be static methods of the same
 * enclosing class.
53 * <p/>
54 * We do allow calls on the ComputeContext itself, and on the mapped interface buffers holding non uniform kernel data.
55 * <p/>
 * Each request to dispatch a kernel discovered in the compute graph results in a new Kernel call graph
 * being created with the dispatched kernel as its entrypoint.
58 * <p/>
 * Once the ComputeContext has been constructed, it is passed to the backend via {@code computeContextHandoff(ComputeContext)}.
60 *
61 * @author Gary Frost
62 */
63 public class ComputeContext implements BufferAllocator, BufferTracker {
64
65
66 public enum WRAPPER {
67 MUTATE("Mutate"), ACCESS("Access");//, ESCAPE("Escape");
68 final public MethodRef pre;
69 final public MethodRef post;
70
71 WRAPPER(String name) {
72 this.pre = MethodRef.method(ComputeContext.class, "pre" + name, void.class, Buffer.class);
73 this.post = MethodRef.method(ComputeContext.class, "post" + name, void.class, Buffer.class);
74 }
75 }
76
77 public final Accelerator accelerator;
78
79
80 public final ComputeCallGraph computeCallGraph;
81
82 /**
83 * Called by the Accelerator when the accelerator is passed a compute entrypoint.
84 * <p>
85 * So given a ComputeClass such as..
86 * <pre>
87 * public class MyComputeClass {
88 * @ CodeReflection
89 * public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
90 * arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
91 * }
92 *
93 * @ CodeReflection
94 * static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
95 * cc.dispatchKernel(KernelContext kc -> addDeltaKernel(kc,arrayOfInt.length(), 5, arrayOfInt);
96 * }
97 * }
98 * </pre>
99 *
100 * @param accelerator
101 * @param computeMethod
102 */
103
104 protected ComputeContext(Accelerator accelerator, Method computeMethod) {
105 this.accelerator = accelerator;
106 this.computeCallGraph = new ComputeCallGraph(this, computeMethod, Op.ofMethod(computeMethod).orElseThrow());
107 this.accelerator.backend.computeContextHandoff(this);
108 }
109
110 public void dispatchKernel(NDRange ndRange, QuotableKernelContextConsumer quotableKernelContextConsumer) {
111 dispatchKernelWithComputeRange(ndRange, quotableKernelContextConsumer);
112 }
113
114 record CallGraph(Quoted quoted, JavaOp.LambdaOp lambdaOp, MethodRef methodRef, KernelCallGraph kernelCallGraph) {}
115
116 private CallGraph getKernelCallGraph(QuotableKernelContextConsumer quotableKernelContextConsumer) {
117 Quoted quoted = Op.ofQuotable(quotableKernelContextConsumer).orElseThrow();
118 JavaOp.LambdaOp lambdaOp = (JavaOp.LambdaOp) quoted.op();
119 MethodRef methodRef = OpTk.getQuotableTargetInvokeOpWrapper( lambdaOp).invokeDescriptor();
120 KernelCallGraph kernelCallGraph = computeCallGraph.kernelCallGraphMap.get(methodRef);
121 if (kernelCallGraph == null){
122 throw new RuntimeException("Failed to create KernelCallGraph (did you miss @CodeReflection annotation?) ");
123 }
124 return new CallGraph(quoted, lambdaOp, methodRef, kernelCallGraph);
125 }
126
127 private void dispatchKernelWithComputeRange(NDRange ndRange, QuotableKernelContextConsumer quotableKernelContextConsumer) {
128 CallGraph cg = getKernelCallGraph(quotableKernelContextConsumer);
129 try {
130 Object[] args = OpTk.getQuotableCapturedValues(cg.lambdaOp,cg.quoted, cg.kernelCallGraph.entrypoint.method);
131 KernelContext kernelContext = accelerator.range(ndRange);
132 args[0] = kernelContext;
133 accelerator.backend.dispatchKernel(cg.kernelCallGraph, kernelContext, args);
134 } catch (Throwable t) {
135 System.out.print("what?" + cg.methodRef + " " + t);
136 t.printStackTrace();
137 throw t;
138 }
139 }
140
141 @Override
142 public void preMutate(Buffer b) {
143 if (accelerator.backend instanceof BufferTracker bufferTracker) {
144 bufferTracker.preMutate(b);
145 }
146 }
147
148 @Override
149 public void postMutate(Buffer b) {
150 if (accelerator.backend instanceof BufferTracker bufferTracker) {
151 bufferTracker.postMutate(b);
152 }
153
154 }
155
156 @Override
157 public void preAccess(Buffer b) {
158 if (accelerator.backend instanceof BufferTracker bufferTracker) {
159 bufferTracker.preAccess(b);
160 }
161
162 }
163
164 @Override
165 public void postAccess(Buffer b) {
166 if (accelerator.backend instanceof BufferTracker bufferTracker) {
167 bufferTracker.postAccess(b);
168 }
169 }
170
171 @Override
172 public <T extends Buffer> T allocate(SegmentMapper<T> segmentMapper, BoundSchema<T> boundSchema) {
173 return accelerator.allocate(segmentMapper, boundSchema);
174 }
175
176 public interface QuotableKernelContextConsumer extends Quotable, Consumer<KernelContext> { }
177
178 }