1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.callgraph;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.KernelContext;
 30 import hat.buffer.Buffer;
 31 import hat.optools.FuncOpWrapper;
 32 import hat.optools.InvokeOpWrapper;
 33 import hat.optools.OpWrapper;
 34 import hat.util.Result;
 35 
 36 import java.lang.reflect.Method;
 37 import jdk.incubator.code.Op;
 38 import jdk.incubator.code.op.CoreOp;
 39 import jdk.incubator.code.type.JavaType;
 40 import jdk.incubator.code.type.MethodRef;
 41 import java.util.HashMap;
 42 import java.util.Map;
 43 import java.util.Optional;
 44 import java.util.stream.Stream;
 45 
 46 public class ComputeCallGraph extends CallGraph<ComputeEntrypoint> {
 47 
 48 
 49     public interface ComputeReachable {
 50     }
 51 
 52     public abstract static class ComputeReachableResolvedMethodCall extends ResolvedMethodCall implements ComputeReachable {
 53         public ComputeReachableResolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method, FuncOpWrapper funcOpWrapper) {
 54             super(callGraph, targetMethodRef, method, funcOpWrapper);
 55         }
 56     }
 57 
 58     public static class ComputeReachableUnresolvedMethodCall extends UnresolvedMethodCall implements ComputeReachable {
 59         ComputeReachableUnresolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
 60             super(callGraph, targetMethodRef, method);
 61         }
 62     }
 63 
 64     public static class ComputeReachableIfaceMappedMethodCall extends ComputeReachableUnresolvedMethodCall {
 65         ComputeReachableIfaceMappedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
 66             super(callGraph, targetMethodRef, method);
 67         }
 68     }
 69 
 70     public static class ComputeReachableAcceleratorMethodCall extends ComputeReachableUnresolvedMethodCall {
 71         ComputeReachableAcceleratorMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
 72             super(callGraph, targetMethodRef, method);
 73         }
 74     }
 75 
 76     public static class ComputeContextMethodCall extends ComputeReachableUnresolvedMethodCall {
 77         ComputeContextMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
 78             super(callGraph, targetMethodRef, method);
 79         }
 80     }
 81 
 82     public static class OtherComputeReachableResolvedMethodCall extends ComputeReachableResolvedMethodCall {
 83         OtherComputeReachableResolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method, FuncOpWrapper funcOpWrapper) {
 84             super(callGraph, targetMethodRef, method, funcOpWrapper);
 85         }
 86     }
 87 
 88     static boolean isKernelDispatch(Method calledMethod, FuncOpWrapper fow) {
 89         if (fow.getReturnType().equals(JavaType.VOID)) {
 90             if (calledMethod.getParameterTypes() instanceof Class<?>[] parameterTypes && parameterTypes.length > 1) {
 91                 // We check that the proposed kernel first arg is an KernelContext and
 92                 // the only other args are primitive or ifacebuffers
 93                 var firstArgIsKid = new Result<>(false);
 94                 var atLeastOneIfaceBufferParam = new Result<>(false);
 95                 var hasOnlyPrimitiveAndIfaceBufferParams = new Result<Boolean>(true);
 96                 fow.paramTable().stream().forEach(paramInfo -> {
 97                     if (paramInfo.idx == 0) {
 98                         firstArgIsKid.of(parameterTypes[0].isAssignableFrom(KernelContext.class));
 99                     } else {
100                         if (paramInfo.isPrimitive()) {
101                             // OK
102                         } else if (InvokeOpWrapper.isIface(paramInfo.javaType)) {
103                             atLeastOneIfaceBufferParam.of(true);
104                         } else {
105                             hasOnlyPrimitiveAndIfaceBufferParams.of(false);
106                         }
107                     }
108                 });
109                 return true;
110             }
111             return false;
112         } else {
113             return false;
114         }
115     }
116 
117     public final Map<MethodRef, KernelCallGraph> kernelCallGraphMap = new HashMap<>();
118 
119     public Stream<KernelCallGraph> kernelCallGraphStream() {
120         return kernelCallGraphMap.values().stream();
121     }
122 
123     public ComputeCallGraph(ComputeContext computeContext, Method method, FuncOpWrapper funcOpWrapper) {
124         super(computeContext, new ComputeEntrypoint(null, method, funcOpWrapper));
125         entrypoint.callGraph = this;
126     }
127 
128     public void updateDag(ComputeReachableResolvedMethodCall computeReachableResolvedMethodCall) {
129         /*
130          * A ResolvedComputeMethodCall (entrypoint or java  method reachable from a compute entrypojnt)  has the following calls
131          * <p>
132          * 1) java calls to compute class static functions
133          *    a) we must have the code model available for these and must be included in the dag
134          * 2) calls to buffer based interface mappings
135          *    a) getters (return non void)
136          *    b) setters (return void)
137          *    c) default helpers with @CodeReflection?
138          * 3) calls to the compute context
139          *    a) kernel dispatches
140          *    b) mapped math functions?
141          *    c) maybe we also handle range creations?
142          * 4) calls through compute context.accelerator;
143          *    a) range creations (maybe compute context should manage ranges?)
144          * 5) References to the dispatched kernels
145          *    a) We must also have the code models for these and must extend the dag to include these.
146          */
147 
148         computeReachableResolvedMethodCall.funcOpWrapper().selectCalls((invokeWrapper) -> {
149             MethodRef methodRef = invokeWrapper.methodRef();
150             Class<?> javaRefClass = invokeWrapper.javaRefClass().orElseThrow();
151             Method invokeWrapperCalledMethod = invokeWrapper.method(this.computeContext.accelerator.lookup);
152             if (Buffer.class.isAssignableFrom(javaRefClass)) {
153                 // System.out.println("iface mapped buffer call  -> " + methodRef);
154                 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ ->
155                         new ComputeReachableIfaceMappedMethodCall(this, methodRef, invokeWrapperCalledMethod)
156                 ));
157             } else if (Accelerator.class.isAssignableFrom(javaRefClass)) {
158                 // System.out.println("call on the accelerator (must be through the computeContext) -> " + methodRef);
159                 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ ->
160                         new ComputeReachableAcceleratorMethodCall(this, methodRef, invokeWrapperCalledMethod)
161                 ));
162 
163             } else if (ComputeContext.class.isAssignableFrom(javaRefClass)) {
164                 // System.out.println("call on the computecontext -> " + methodRef);
165                 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ ->
166                         new ComputeContextMethodCall(this, methodRef, invokeWrapperCalledMethod)
167                 ));
168             } else if (entrypoint.method.getDeclaringClass().equals(javaRefClass)) {
169                 Optional<CoreOp.FuncOp> optionalFuncOp = Op.ofMethod(invokeWrapperCalledMethod);
170                 if (optionalFuncOp.isPresent()) {
171                     FuncOpWrapper fow = OpWrapper.wrap(optionalFuncOp.get());
172                     if (isKernelDispatch(invokeWrapperCalledMethod, fow)) {
173                         // System.out.println("A kernel reference (not a direct call) to a kernel " + methodRef);
174                         kernelCallGraphMap.computeIfAbsent(methodRef, _ ->
175                                 new KernelCallGraph(this, methodRef, invokeWrapperCalledMethod, fow).close()
176                         );
177                     } else {
178                         // System.out.println("A call to a method on the compute class which we have code model for " + methodRef);
179                         computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ ->
180                                 new OtherComputeReachableResolvedMethodCall(this, methodRef, invokeWrapperCalledMethod, fow)
181                         ));
182                     }
183                 } else {
184                     //  System.out.println("A call to a method on the compute class which we DO NOT have code model for " + methodRef);
185                     computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ ->
186                             new ComputeReachableUnresolvedMethodCall(this, methodRef, invokeWrapperCalledMethod)
187                     ));
188                 }
189             } else {
190                 //TODO what about ifacenestings?
191                 // System.out.println("A call to a method on the compute class which we DO NOT have code model for " + methodRef);
192                 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ ->
193                         new ComputeReachableUnresolvedMethodCall(this, methodRef, invokeWrapperCalledMethod)
194                 ));
195             }
196         });
197 
198         if (kernelCallGraphMap.isEmpty()) {
199             throw new IllegalStateException("entrypoint compute has no kernel references!");
200         }
201 
202         boolean updated = true;
203         computeReachableResolvedMethodCall.closed = true;
204         while (updated) {
205             updated = false;
206             var unclosed = callStream().filter(m -> !m.closed).findFirst();
207             if (unclosed.isPresent()) {
208                 if (unclosed.get() instanceof ComputeReachableResolvedMethodCall reachableResolvedMethodCall) {
209                     updateDag(reachableResolvedMethodCall);
210                 } else {
211                     unclosed.get().closed = true;
212                 }
213                 updated = true;
214             }
215         }
216     }
217 
218     public void close() {
219         updateDag(entrypoint);
220     }
221 }