1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.callgraph;
26
27 import hat.Accelerator;
28 import hat.ComputeContext;
29 import hat.KernelContext;
30 import hat.buffer.Buffer;
31 import hat.ifacemapper.MappableIface;
32 import hat.optools.FuncOpParams;
33 import hat.optools.OpTk;
34 import hat.util.StreamMutable;
35
36 import java.lang.invoke.MethodHandles;
37 import java.lang.reflect.Method;
38 import jdk.incubator.code.Op;
39 import jdk.incubator.code.dialect.core.CoreOp;
40 import jdk.incubator.code.dialect.java.JavaOp;
41 import jdk.incubator.code.dialect.java.JavaType;
42 import jdk.incubator.code.dialect.java.MethodRef;
43
44 import java.util.*;
45
46 public class ComputeCallGraph extends CallGraph<ComputeEntrypoint> {
47
48 public final Map<MethodRef, MethodCall> bufferAccessToMethodCallMap = new LinkedHashMap<>();
49
50 ComputeContextMethodCall computeContextMethodCall;
51
52 public interface ComputeReachable {
53 }
54
55 public abstract static class ComputeReachableResolvedMethodCall extends ResolvedMethodCall implements ComputeReachable {
56 public ComputeReachableResolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method, CoreOp.FuncOp funcOp) {
57 super(callGraph, targetMethodRef, method, funcOp);
58 }
59 }
60
61 public static class ComputeReachableUnresolvedMethodCall extends UnresolvedMethodCall implements ComputeReachable {
62 ComputeReachableUnresolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
63 super(callGraph, targetMethodRef, method);
64 }
65 }
66
67 public static class ComputeReachableIfaceMappedMethodCall extends ComputeReachableUnresolvedMethodCall {
68 ComputeReachableIfaceMappedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
69 super(callGraph, targetMethodRef, method);
70 }
71 }
72
73 public static class ComputeReachableAcceleratorMethodCall extends ComputeReachableUnresolvedMethodCall {
74 ComputeReachableAcceleratorMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
75 super(callGraph, targetMethodRef, method);
76 }
77 }
78
79 public static class ComputeContextMethodCall extends ComputeReachableUnresolvedMethodCall {
80 ComputeContextMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
81 super(callGraph, targetMethodRef, method);
82 }
83 }
84
85 public static class OtherComputeReachableResolvedMethodCall extends ComputeReachableResolvedMethodCall {
86 OtherComputeReachableResolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method, CoreOp.FuncOp funcOp) {
87 super(callGraph, targetMethodRef, method, funcOp);
88 }
89 }
90
91 static boolean isKernelDispatch(MethodHandles.Lookup lookup,Method calledMethod, CoreOp.FuncOp fow) {
92 if (fow.body().yieldType().equals(JavaType.VOID)
93 && calledMethod.getParameterTypes() instanceof Class<?>[] parameterTypes
94 && parameterTypes.length > 1) {
95 // We check that the proposed kernel returns void, the first arg is an KernelContext and we have more args
96 // We also check that other args are primitive or ifacebuffers (or atomics?)...
97 var firstArgIsKid = StreamMutable.of(false);
98 var atLeastOneIfaceBufferParam = StreamMutable.of(false);
99 var hasOnlyPrimitiveAndIfaceBufferParams = StreamMutable.of(true);
100 FuncOpParams paramTable = new FuncOpParams(fow);
101 paramTable.stream().forEach(paramInfo -> {
102 if (paramInfo.idx == 0) {
103 firstArgIsKid.set(parameterTypes[0].isAssignableFrom(KernelContext.class));
104 } else {
105 if (paramInfo.isPrimitive()) {
106 // OK
107 } else if (OpTk.isAssignable(lookup,paramInfo.javaType, MappableIface.class)){
108 atLeastOneIfaceBufferParam.set(true);
109 } else {
110 hasOnlyPrimitiveAndIfaceBufferParams.set(false);
111 }
112 }
113 });
114 return true;
115 }
116 return false;
117 }
118
119 public final Map<MethodRef, KernelCallGraph> kernelCallGraphMap = new HashMap<>();
120
121
122 public ComputeCallGraph(ComputeContext computeContext, Method method, CoreOp.FuncOp funcOp) {
123 super(computeContext, new ComputeEntrypoint(null, method, funcOp));
124 entrypoint.callGraph = this;
125 setModuleOp(OpTk.createTransitiveInvokeModule(computeContext.accelerator.lookup, entrypoint.funcOp(), this));
126 //close(entrypoint);
127 }
128
129 /*
130 * A ResolvedComputeMethodCall (entrypoint or java method reachable from a compute entrypojnt) has the following calls
131 * <p>
132 * 1) java calls to compute class static functions
133 * a) we must have the code model available for these and must be included in the dag
134 * 2) calls to buffer based interface mappings
135 * a) getters (return non void)
136 * b) setters (return void)
137 * c) default helpers with @CodeReflection?
138 * 3) calls to the compute context
139 * a) kernel dispatches
140 * b) mapped math functions?
141 * c) maybe we also handle range creations?
142 * 4) calls through compute context.accelerator;
143 * a) range creations (maybe compute context should manage ranges?)
144 * 5) References to the dispatched kernels
145 * a) We must also have the code models for these and must extend the dag to include these.
146 */
147 public void oldUpdateDag(ComputeReachableResolvedMethodCall computeReachableResolvedMethodCall) {
148 MethodHandles.Lookup lookup = computeReachableResolvedMethodCall.callGraph.computeContext.accelerator.lookup;
149 var here = OpTk.CallSite.of(ComputeCallGraph.class,"updateDag");
150 OpTk.transform(here, computeReachableResolvedMethodCall.funcOp(),(map, op) -> {
151 if (op instanceof JavaOp.InvokeOp invokeOp) {
152 Class<?> javaRefClass = OpTk.javaRefClassOrThrow(lookup,invokeOp);
153 Method invokeWrapperCalledMethod = OpTk.methodOrThrow(lookup,invokeOp);
154 if (Buffer.class.isAssignableFrom(javaRefClass)) {
155 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
156 new ComputeReachableIfaceMappedMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
157 ));
158 } else if (Accelerator.class.isAssignableFrom(javaRefClass)) {
159 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
160 new ComputeReachableAcceleratorMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
161 ));
162
163 } else if (ComputeContext.class.isAssignableFrom(javaRefClass)) {
164 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
165 new ComputeContextMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
166 ));
167 } else if (entrypoint.method.getDeclaringClass().equals(javaRefClass)) {
168 Optional<CoreOp.FuncOp> optionalFuncOp = Op.ofMethod(invokeWrapperCalledMethod);
169 if (optionalFuncOp.isPresent()) {
170 CoreOp.FuncOp fow = optionalFuncOp.get();//OpWrapper.wrap(computeContext.accelerator.lookup, optionalFuncOp.get());
171 if (isKernelDispatch(lookup,invokeWrapperCalledMethod, fow)) {
172 // System.out.println("A kernel reference (not a direct call) to a kernel " + methodRef);
173 kernelCallGraphMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
174 new KernelCallGraph(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod, fow)
175 );
176 } else {
177 // System.out.println("A call to a method on the compute class which we have code model for " + methodRef);
178 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
179 new OtherComputeReachableResolvedMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod, fow)
180 ));
181 }
182 } else {
183 // System.out.println("A call to a method on the compute class which we DO NOT have code model for " + methodRef);
184 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
185 new ComputeReachableUnresolvedMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
186 ));
187 }
188 } else {
189 //TODO what about ifacenestings?
190 // System.out.println("A call to a method on the compute class which we DO NOT have code model for " + methodRef);
191 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
192 new ComputeReachableUnresolvedMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
193 ));
194 }
195 }
196 return map;
197 });
198 if (kernelCallGraphMap.isEmpty()) {
199 throw new IllegalStateException("entrypoint compute has no kernel references!");
200 }
201
202 boolean updated = true;
203 computeReachableResolvedMethodCall.closed = true;
204 while (updated) {
205 updated = false;
206 var unclosed = callStream().filter(m -> !m.closed).findFirst();
207 if (unclosed.isPresent()) {
208 if (unclosed.get() instanceof ComputeReachableResolvedMethodCall reachableResolvedMethodCall) {
209 oldUpdateDag(reachableResolvedMethodCall);
210 } else {
211 unclosed.get().closed = true;
212 }
213 updated = true;
214 }
215 }
216 }
217
218 @Override
219 public boolean filterCalls(CoreOp.FuncOp f, JavaOp.InvokeOp invokeOp, Method method, MethodRef methodRef, Class<?> javaRefTypeClass) {
220 if (entrypoint.method.getDeclaringClass().equals(OpTk.javaRefClassOrThrow(computeContext.accelerator.lookup,invokeOp))
221 && isKernelDispatch(computeContext.accelerator.lookup,method, f)) {
222 // TODO this side effect is not good. we should do this when we construct !
223 kernelCallGraphMap.computeIfAbsent(methodRef, _ ->
224 new KernelCallGraph(this, methodRef, method, f)
225 );
226 } else if (ComputeContext.class.isAssignableFrom(javaRefTypeClass)) {
227 computeContextMethodCall = new ComputeContextMethodCall(this, methodRef, method);
228 } else if (Buffer.class.isAssignableFrom(javaRefTypeClass)) {
229 bufferAccessToMethodCallMap.computeIfAbsent(methodRef, _ ->
230 new ComputeReachableIfaceMappedMethodCall(this, methodRef, method)
231 );
232 } else {
233 return false;
234 }
235 return true;
236 }
237
238 }