1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.callgraph;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.KernelContext;
 30 import hat.buffer.Buffer;
 31 import hat.ifacemapper.MappableIface;
 32 import hat.optools.FuncOpParams;
 33 import hat.optools.OpTk;
 34 import hat.util.StreamMutable;
 35 
 36 import java.lang.invoke.MethodHandles;
 37 import java.lang.reflect.Method;
 38 import jdk.incubator.code.Op;
 39 import jdk.incubator.code.dialect.core.CoreOp;
 40 import jdk.incubator.code.dialect.java.JavaOp;
 41 import jdk.incubator.code.dialect.java.JavaType;
 42 import jdk.incubator.code.dialect.java.MethodRef;
 43 
 44 import java.util.*;
 45 
 46 public class ComputeCallGraph extends CallGraph<ComputeEntrypoint> {
 47 
 48     public final Map<MethodRef, MethodCall> bufferAccessToMethodCallMap = new LinkedHashMap<>();
 49 
 50     ComputeContextMethodCall computeContextMethodCall;
 51 
 52     public interface ComputeReachable {
 53     }
 54 
 55     public abstract static class ComputeReachableResolvedMethodCall extends ResolvedMethodCall implements ComputeReachable {
 56         public ComputeReachableResolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method, CoreOp.FuncOp funcOp) {
 57             super(callGraph, targetMethodRef, method, funcOp);
 58         }
 59     }
 60 
 61     public static class ComputeReachableUnresolvedMethodCall extends UnresolvedMethodCall implements ComputeReachable {
 62         ComputeReachableUnresolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
 63             super(callGraph, targetMethodRef, method);
 64         }
 65     }
 66 
 67     public static class ComputeReachableIfaceMappedMethodCall extends ComputeReachableUnresolvedMethodCall {
 68         ComputeReachableIfaceMappedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
 69             super(callGraph, targetMethodRef, method);
 70         }
 71     }
 72 
 73     public static class ComputeReachableAcceleratorMethodCall extends ComputeReachableUnresolvedMethodCall {
 74         ComputeReachableAcceleratorMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
 75             super(callGraph, targetMethodRef, method);
 76         }
 77     }
 78 
 79     public static class ComputeContextMethodCall extends ComputeReachableUnresolvedMethodCall {
 80         ComputeContextMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) {
 81             super(callGraph, targetMethodRef, method);
 82         }
 83     }
 84 
 85     public static class OtherComputeReachableResolvedMethodCall extends ComputeReachableResolvedMethodCall {
 86         OtherComputeReachableResolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method, CoreOp.FuncOp funcOp) {
 87             super(callGraph, targetMethodRef, method, funcOp);
 88         }
 89     }
 90 
 91     static boolean isKernelDispatch(MethodHandles.Lookup lookup,Method calledMethod, CoreOp.FuncOp fow) {
 92         if (fow.body().yieldType().equals(JavaType.VOID)
 93                 && calledMethod.getParameterTypes() instanceof Class<?>[] parameterTypes
 94                 && parameterTypes.length > 1) {
 95                 // We check that the proposed kernel returns void, the first arg is an KernelContext and we have more args
 96                 // We also check that other args are primitive or ifacebuffers  (or atomics?)...
 97                 var firstArgIsKid = StreamMutable.of(false);
 98                 var atLeastOneIfaceBufferParam = StreamMutable.of(false);
 99                 var hasOnlyPrimitiveAndIfaceBufferParams = StreamMutable.of(true);
100                 FuncOpParams paramTable = new FuncOpParams(fow);
101                 paramTable.stream().forEach(paramInfo -> {
102                     if (paramInfo.idx == 0) {
103                         firstArgIsKid.set(parameterTypes[0].isAssignableFrom(KernelContext.class));
104                     } else {
105                         if (paramInfo.isPrimitive()) {
106                             // OK
107                         } else if (OpTk.isAssignable(lookup,paramInfo.javaType, MappableIface.class)){
108                             atLeastOneIfaceBufferParam.set(true);
109                         } else {
110                             hasOnlyPrimitiveAndIfaceBufferParams.set(false);
111                         }
112                     }
113                 });
114                 return true;
115             }
116             return false;
117     }
118 
119     public final Map<MethodRef, KernelCallGraph> kernelCallGraphMap = new HashMap<>();
120 
121 
122     public ComputeCallGraph(ComputeContext computeContext, Method method, CoreOp.FuncOp funcOp) {
123         super(computeContext, new ComputeEntrypoint(null, method, funcOp));
124         entrypoint.callGraph = this;
125         setModuleOp(OpTk.createTransitiveInvokeModule(computeContext.accelerator.lookup, entrypoint.funcOp(), this));
126         //close(entrypoint);
127     }
128 
129     /*
130      * A ResolvedComputeMethodCall (entrypoint or java  method reachable from a compute entrypojnt)  has the following calls
131      * <p>
132      * 1) java calls to compute class static functions
133      *    a) we must have the code model available for these and must be included in the dag
134      * 2) calls to buffer based interface mappings
135      *    a) getters (return non void)
136      *    b) setters (return void)
137      *    c) default helpers with @CodeReflection?
138      * 3) calls to the compute context
139      *    a) kernel dispatches
140      *    b) mapped math functions?
141      *    c) maybe we also handle range creations?
142      * 4) calls through compute context.accelerator;
143      *    a) range creations (maybe compute context should manage ranges?)
144      * 5) References to the dispatched kernels
145      *    a) We must also have the code models for these and must extend the dag to include these.
146      */
147     public void oldUpdateDag(ComputeReachableResolvedMethodCall computeReachableResolvedMethodCall) {
148         MethodHandles.Lookup lookup =  computeReachableResolvedMethodCall.callGraph.computeContext.accelerator.lookup;
149         var here = OpTk.CallSite.of(ComputeCallGraph.class,"updateDag");
150         OpTk.transform(here, computeReachableResolvedMethodCall.funcOp(),(map, op) -> {
151             if (op instanceof JavaOp.InvokeOp invokeOp) {
152                 Class<?> javaRefClass = OpTk.javaRefClassOrThrow(lookup,invokeOp);
153                 Method invokeWrapperCalledMethod = OpTk.methodOrThrow(lookup,invokeOp);
154                 if (Buffer.class.isAssignableFrom(javaRefClass)) {
155                     computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
156                             new ComputeReachableIfaceMappedMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
157                     ));
158                 } else if (Accelerator.class.isAssignableFrom(javaRefClass)) {
159                     computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
160                             new ComputeReachableAcceleratorMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
161                     ));
162 
163                 } else if (ComputeContext.class.isAssignableFrom(javaRefClass)) {
164                     computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
165                             new ComputeContextMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
166                     ));
167                 } else if (entrypoint.method.getDeclaringClass().equals(javaRefClass)) {
168                     Optional<CoreOp.FuncOp> optionalFuncOp = Op.ofMethod(invokeWrapperCalledMethod);
169                     if (optionalFuncOp.isPresent()) {
170                         CoreOp.FuncOp fow = optionalFuncOp.get();//OpWrapper.wrap(computeContext.accelerator.lookup, optionalFuncOp.get());
171                         if (isKernelDispatch(lookup,invokeWrapperCalledMethod, fow)) {
172                             // System.out.println("A kernel reference (not a direct call) to a kernel " + methodRef);
173                             kernelCallGraphMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
174                                     new KernelCallGraph(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod, fow)
175                             );
176                         } else {
177                             // System.out.println("A call to a method on the compute class which we have code model for " + methodRef);
178                             computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
179                                     new OtherComputeReachableResolvedMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod, fow)
180                             ));
181                         }
182                     } else {
183                         //  System.out.println("A call to a method on the compute class which we DO NOT have code model for " + methodRef);
184                         computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
185                                 new ComputeReachableUnresolvedMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
186                         ));
187                     }
188                 } else {
189                     //TODO what about ifacenestings?
190                     // System.out.println("A call to a method on the compute class which we DO NOT have code model for " + methodRef);
191                     computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(invokeOp.invokeDescriptor(), _ ->
192                             new ComputeReachableUnresolvedMethodCall(this, invokeOp.invokeDescriptor(), invokeWrapperCalledMethod)
193                     ));
194                 }
195             }
196             return map;
197         });
198         if (kernelCallGraphMap.isEmpty()) {
199             throw new IllegalStateException("entrypoint compute has no kernel references!");
200         }
201 
202         boolean updated = true;
203         computeReachableResolvedMethodCall.closed = true;
204         while (updated) {
205             updated = false;
206             var unclosed = callStream().filter(m -> !m.closed).findFirst();
207             if (unclosed.isPresent()) {
208                 if (unclosed.get() instanceof ComputeReachableResolvedMethodCall reachableResolvedMethodCall) {
209                     oldUpdateDag(reachableResolvedMethodCall);
210                 } else {
211                     unclosed.get().closed = true;
212                 }
213                 updated = true;
214             }
215         }
216     }
217 
218     @Override
219     public boolean filterCalls(CoreOp.FuncOp f, JavaOp.InvokeOp invokeOp, Method method, MethodRef methodRef, Class<?> javaRefTypeClass) {
220         if (entrypoint.method.getDeclaringClass().equals(OpTk.javaRefClassOrThrow(computeContext.accelerator.lookup,invokeOp))
221                 && isKernelDispatch(computeContext.accelerator.lookup,method, f)) {
222             // TODO this side effect is not good.  we should do this when we construct !
223             kernelCallGraphMap.computeIfAbsent(methodRef, _ ->
224                     new KernelCallGraph(this, methodRef, method, f)
225             );
226         } else if (ComputeContext.class.isAssignableFrom(javaRefTypeClass)) {
227             computeContextMethodCall = new ComputeContextMethodCall(this, methodRef, method);
228         } else if (Buffer.class.isAssignableFrom(javaRefTypeClass)) {
229             bufferAccessToMethodCallMap.computeIfAbsent(methodRef, _ ->
230                     new ComputeReachableIfaceMappedMethodCall(this, methodRef, method)
231             );
232         } else {
233             return false;
234         }
235         return true;
236     }
237 
238 }