1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat;
26
27 import optkl.util.carriers.ArenaAndLookupCarrier;
28 import optkl.ifacemapper.BufferTracker;
29 import hat.callgraph.ComputeCallGraph;
30 import hat.callgraph.KernelCallGraph;
31 import optkl.ifacemapper.MappableIface;
32 import jdk.incubator.code.dialect.core.CoreOp.FuncOp;
33 import jdk.incubator.code.Reflect;
34 import jdk.incubator.code.Op;
35 import jdk.incubator.code.Quoted;
36 import jdk.incubator.code.dialect.java.JavaOp;
37 import jdk.incubator.code.dialect.java.MethodRef;
38
39 import java.lang.foreign.Arena;
40 import java.lang.invoke.MethodHandles;
41 import java.lang.reflect.Method;
42 import java.util.HashMap;
43 import java.util.Map;
44 import java.util.function.Consumer;
45 import java.util.Optional;
46
47 import static optkl.OpHelper.Invoke.getTargetInvoke;
48 import static optkl.OpHelper.Lambda.lambda;
49
50 /**
51 * A ComputeContext is created by an Accelerator to capture and control compute and kernel
52 * callgraphs for the work to be performed by the backend.
53 * <p/>
54 * The Compute closure is created first, by walking the code model of the entrypoint, then transitively
55 * visiting all conventional code reachable from this entrypoint.
56 * <p/>
57 * Generally all user defined methods reachable from the entrypoint (and the entrypoint itself) must be static methods of the same
58 * enclosing classes.
59 * <p/>
60 * We do allow calls on the ComputeContext itself, and on the mapped interface buffers holding non uniform kernel data.
61 * <p/>
62 * Each request to dispatch a kernel discovered in the compute graph, results in a new Kernel call graph
63 * being created with the dispatched kernel as it's entrypoint.
64 * <p/>
65 * When the ComputeContext is finalized, it is passed to the backend via <a href="Backend.computeClosureHandoff(ComputeContext)"></a>
66 *
67 * @author Gary Frost
68 */
69 public class ComputeContext implements ArenaAndLookupCarrier, BufferTracker {
70
71
72 @Override
73 public Arena arena() {
74 return accelerator.arena();
75 }
76
77 @Override
78 public MethodHandles.Lookup lookup() {
79 return accelerator.lookup();
80 }
81
82
83 public Config config() {
84 return accelerator().config();
85 }
86
87 public void invokeWithArgs(Object[] args) {
88 computeCallGraph.invokeWithArgs(args);
89
90 }
91
92 public enum WRAPPER {
93 MUTATE("Mutate"), ACCESS("Access");
94 final public MethodRef pre;
95 final public MethodRef post;
96
97 WRAPPER(String name) {
98 this.pre = MethodRef.method(ComputeContext.class, "pre" + name, void.class, MappableIface.class);
99 this.post = MethodRef.method(ComputeContext.class, "post" + name, void.class, MappableIface.class);
100 }
101 }
102
103 private final Accelerator accelerator;
104 final public Accelerator accelerator(){
105 return accelerator;
106 }
107
108 private final ComputeCallGraph computeCallGraph;
109 final public ComputeCallGraph computeCallGraph(){
110 return computeCallGraph;
111 }
112
113
114
115 /**
116 * Called by the Accelerator when the accelerator is passed a compute entrypoint.
117 * <p>
118 * So given a ComputeClass such as..
119 * <pre>
120 * public class MyComputeClass {
121 * @ Reflect
122 * public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
123 * arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
124 * }
125 *
126 * @ Reflect
127 * static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
128 * cc.dispatchKernel(KernelContext kc -> addDeltaKernel(kc,arrayOfInt.length(), 5, arrayOfInt);
129 * }
130 * }
131 * </pre>
132 *
133 * @param accelerator
134 * @param computeMethod
135 */
136
137 protected ComputeContext(Accelerator accelerator, Method computeMethod) {
138 this.accelerator = accelerator;
139 Optional<FuncOp> funcOp = Op.ofMethod(computeMethod);
140 if (funcOp.isEmpty()) {
141 throw new RuntimeException("Failed to create ComputeCallGraph (did you miss @Reflect annotation?).");
142 }
143 this.computeCallGraph = new ComputeCallGraph(this, computeMethod, funcOp.get());
144 this.accelerator.backend.computeContextHandoff(this);
145 }
146 record KernelCallSite(Quoted<JavaOp.LambdaOp> quoted, JavaOp.LambdaOp lambdaOp, MethodRef methodRef, KernelCallGraph kernelCallGraph) {}
147
148 private final Map<Op.Location, KernelCallSite> kernelCallSiteCache = new HashMap<>();
149
150 /** Creating the kernel callsite involves
151 walking the code model of the lambda
152 analysing the callgraph and trsnsforming to HATDielect
153 So we cache the callsite against the location from the lambdaop.
154 */
155 public void dispatchKernel(NDRange ndRange, Kernel kernel) {
156 Quoted<JavaOp.LambdaOp> quoted = Op.ofLambda(kernel).orElseThrow();
157
158 var location = quoted.op().location();
159
160 KernelCallSite kernelCallSite;
161 if (kernelCallSiteCache.containsKey(location)) {
162 var oldKernelCallSite = kernelCallSiteCache.get(location);
163 kernelCallSite = new KernelCallSite(quoted, oldKernelCallSite.lambdaOp(), oldKernelCallSite.methodRef(), oldKernelCallSite.kernelCallGraph());
164 } else {
165 kernelCallSite = kernelCallSiteCache.compute(location, (_, _)-> {
166 JavaOp.LambdaOp lambdaOp = quoted.op();
167 MethodRef methodRef = getTargetInvoke(this.lookup(), lambdaOp, KernelContext.class).op().invokeReference();
168 KernelCallGraph kernelCallGraph = computeCallGraph.kernelCallGraphMap.get(methodRef);
169 if (kernelCallGraph == null) {
170 throw new RuntimeException("Failed to create KernelCallGraph (did you miss @Reflect annotation?).");
171 }
172 return new KernelCallSite(quoted, lambdaOp, methodRef, kernelCallGraph);
173 });
174 }
175 Object[] args = lambda(lookup(),kernelCallSite.lambdaOp).getQuotedCapturedValues(kernelCallSite.quoted, kernelCallSite.kernelCallGraph.callDag.entryPoint.method());
176 KernelContext kernelContext = accelerator.range(ndRange);
177 args[0] = kernelContext;
178 accelerator.backend.dispatchKernel(kernelCallSite.kernelCallGraph, kernelContext, args);
179 }
180
181
182 @Override
183 public void preMutate(MappableIface b) {
184 if (accelerator.backend instanceof BufferTracker bufferTracker) {
185 bufferTracker.preMutate(b);
186 }
187 }
188
189 @Override
190 public void postMutate(MappableIface b) {
191 if (accelerator.backend instanceof BufferTracker bufferTracker) {
192 bufferTracker.postMutate(b);
193 }
194
195 }
196
197 @Override
198 public void preAccess(MappableIface b) {
199 if (accelerator.backend instanceof BufferTracker bufferTracker) {
200 bufferTracker.preAccess(b);
201 }
202
203 }
204
205 @Override
206 public void postAccess(MappableIface b) {
207 if (accelerator.backend instanceof BufferTracker bufferTracker) {
208 bufferTracker.postAccess(b);
209 }
210 }
211
212 @Reflect
213 @FunctionalInterface
214 public interface Kernel extends Consumer<KernelContext> { }
215
216 }