1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat;
26
27 import hat.callgraph.ComputeEntrypoint;
28 import jdk.incubator.code.bytecode.BytecodeGenerator;
29 import jdk.incubator.code.interpreter.Interpreter;
30 import optkl.util.carriers.ArenaAndLookupCarrier;
31 import optkl.util.carriers.ArenaCarrier;
32 import optkl.util.carriers.LookupCarrier;
33 import optkl.ifacemapper.BufferTracker;
34 import hat.callgraph.ComputeCallGraph;
35 import hat.callgraph.KernelCallGraph;
36 import optkl.ifacemapper.MappableIface;
37 import jdk.incubator.code.dialect.core.CoreOp.FuncOp;
38 import jdk.incubator.code.Reflect;
39 import jdk.incubator.code.Op;
40 import jdk.incubator.code.Quoted;
41 import jdk.incubator.code.dialect.java.JavaOp;
42 import jdk.incubator.code.dialect.java.MethodRef;
43
44 import java.lang.foreign.Arena;
45 import java.lang.invoke.MethodHandles;
46 import java.lang.reflect.Method;
47 import java.util.HashMap;
48 import java.util.Map;
49 import java.util.function.Consumer;
50 import java.util.Optional;
51
52 import static optkl.OpHelper.Invoke.getTargetInvoke;
53 import static optkl.OpHelper.Lambda.lambda;
54
55 /**
56 * A ComputeContext is created by an Accelerator to capture and control compute and kernel
57 * callgraphs for the work to be performed by the backend.
 * <p>
 * The compute closure is created first, by walking the code model of the entrypoint and then transitively
 * visiting all conventional code reachable from this entrypoint.
 * <p>
 * Generally, all user-defined methods reachable from the entrypoint (and the entrypoint itself) must be static
 * methods of the same enclosing class.
 * <p>
 * We do allow calls on the ComputeContext itself, and on the mapped interface buffers holding non-uniform kernel data.
 * <p>
 * Each request to dispatch a kernel discovered in the compute graph results in a new kernel call graph
 * being created, with the dispatched kernel as its entrypoint.
 * <p>
 * When the ComputeContext is finalized, it is passed to the backend via
 * {@code Backend.computeContextHandoff(ComputeContext)}.
71 *
72 * @author Gary Frost
73 */
74 public class ComputeContext implements ArenaAndLookupCarrier, BufferTracker {
75
76
77 @Override
78 public Arena arena() {
79 return accelerator.arena();
80 }
81
82 @Override
83 public MethodHandles.Lookup lookup() {
84 return accelerator.lookup();
85 }
86
87 public ComputeEntrypoint computeEntrypoint() {
88 return computeCallGraph.entrypoint;
89 }
90
91 public Config config() {
92 return accelerator().config();
93 }
94
95 public void invokeWithArgs(Object[] args) {
96 computeEntrypoint().invokeWithArgs(args);
97
98 }
99
100 public void interpretWithArgs(Object[] args) {
101 computeEntrypoint().interpretWithArgs( args);
102 }
103
104 public enum WRAPPER {
105 MUTATE("Mutate"), ACCESS("Access");
106 final public MethodRef pre;
107 final public MethodRef post;
108
109 WRAPPER(String name) {
110 this.pre = MethodRef.method(ComputeContext.class, "pre" + name, void.class, MappableIface.class);
111 this.post = MethodRef.method(ComputeContext.class, "post" + name, void.class, MappableIface.class);
112 }
113 }
114
115 private final Accelerator accelerator;
116 final public Accelerator accelerator(){
117 return accelerator;
118 }
119
120 private final ComputeCallGraph computeCallGraph;
121 final public ComputeCallGraph computeCallGraph(){
122 return computeCallGraph;
123 }
124
125
126
127 /**
128 * Called by the Accelerator when the accelerator is passed a compute entrypoint.
129 * <p>
130 * So given a ComputeClass such as..
131 * <pre>
132 * public class MyComputeClass {
133 * @ Reflect
134 * public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
135 * arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
136 * }
137 *
138 * @ Reflect
139 * static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
140 * cc.dispatchKernel(KernelContext kc -> addDeltaKernel(kc,arrayOfInt.length(), 5, arrayOfInt);
141 * }
142 * }
143 * </pre>
144 *
145 * @param accelerator
146 * @param computeMethod
147 */
148
149 protected ComputeContext(Accelerator accelerator, Method computeMethod) {
150 this.accelerator = accelerator;
151 Optional<FuncOp> funcOp = Op.ofMethod(computeMethod);
152 if (funcOp.isEmpty()) {
153 throw new RuntimeException("Failed to create ComputeCallGraph (did you miss @Reflect annotation?).");
154 }
155 this.computeCallGraph = new ComputeCallGraph(this, computeMethod, funcOp.get());
156 this.accelerator.backend.computeContextHandoff(this);
157 }
158 record KernelCallSite(Quoted<JavaOp.LambdaOp> quoted, JavaOp.LambdaOp lambdaOp, MethodRef methodRef, KernelCallGraph kernelCallGraph) {}
159
160 private Map<Op.Location, KernelCallSite> kernelCallSiteCache = new HashMap<>();
161
162 /** Creating the kernel callsite involves
163 walking the code model of the lambda
164 analysing the callgraph and trsnsforming to HATDielect
165 So we cache the callsite against the location from the lambdaop.
166 */
167 public void dispatchKernel(NDRange<?, ?> ndRange, Kernel kernel) {
168 Quoted<JavaOp.LambdaOp> quoted = Op.ofLambda(kernel).orElseThrow();
169
170 var location = quoted.op().location();
171
172 KernelCallSite kernelCallSite;
173 if (kernelCallSiteCache.containsKey(location)) {
174 var oldKernelCallSite = kernelCallSiteCache.get(location);
175 kernelCallSite = new KernelCallSite(quoted, oldKernelCallSite.lambdaOp(), oldKernelCallSite.methodRef(), oldKernelCallSite.kernelCallGraph());
176 } else {
177 kernelCallSite = kernelCallSiteCache.compute(location, (_, _)-> {
178 JavaOp.LambdaOp lambdaOp = quoted.op();
179 MethodRef methodRef = getTargetInvoke(this.lookup(), lambdaOp, KernelContext.class).op().invokeDescriptor();
180 KernelCallGraph kernelCallGraph = computeCallGraph.kernelCallGraphMap.get(methodRef);
181 if (kernelCallGraph == null) {
182 throw new RuntimeException("Failed to create KernelCallGraph (did you miss @Reflect annotation?).");
183 }
184 return new KernelCallSite(quoted, lambdaOp, methodRef, kernelCallGraph);
185 });
186 }
187 Object[] args = lambda(lookup(),kernelCallSite.lambdaOp).getQuotedCapturedValues(kernelCallSite.quoted, kernelCallSite.kernelCallGraph.entrypoint.method());
188 KernelContext kernelContext = accelerator.range(ndRange);
189 args[0] = kernelContext;
190 accelerator.backend.dispatchKernel(kernelCallSite.kernelCallGraph, kernelContext, args);
191 }
192
193
194 @Override
195 public void preMutate(MappableIface b) {
196 if (accelerator.backend instanceof BufferTracker bufferTracker) {
197 bufferTracker.preMutate(b);
198 }
199 }
200
201 @Override
202 public void postMutate(MappableIface b) {
203 if (accelerator.backend instanceof BufferTracker bufferTracker) {
204 bufferTracker.postMutate(b);
205 }
206
207 }
208
209 @Override
210 public void preAccess(MappableIface b) {
211 if (accelerator.backend instanceof BufferTracker bufferTracker) {
212 bufferTracker.preAccess(b);
213 }
214
215 }
216
217 @Override
218 public void postAccess(MappableIface b) {
219 if (accelerator.backend instanceof BufferTracker bufferTracker) {
220 bufferTracker.postAccess(b);
221 }
222 }
223
224 @Reflect
225 @FunctionalInterface
226 public interface Kernel extends Consumer<KernelContext> { }
227
228 }