1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 package hat.callgraph; 26 27 import hat.Accelerator; 28 import hat.ComputeContext; 29 import hat.KernelContext; 30 import hat.buffer.Buffer; 31 import hat.optools.FuncOpWrapper; 32 import hat.optools.InvokeOpWrapper; 33 import hat.optools.OpWrapper; 34 import hat.util.Result; 35 36 import java.lang.reflect.Method; 37 import jdk.incubator.code.Op; 38 import jdk.incubator.code.op.CoreOp; 39 import jdk.incubator.code.type.JavaType; 40 import jdk.incubator.code.type.MethodRef; 41 import java.util.HashMap; 42 import java.util.Map; 43 import java.util.Optional; 44 import java.util.stream.Stream; 45 46 public class ComputeCallGraph extends CallGraph<ComputeEntrypoint> { 47 48 49 public interface ComputeReachable { 50 } 51 52 public abstract static class ComputeReachableResolvedMethodCall extends ResolvedMethodCall implements ComputeReachable { 53 public ComputeReachableResolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method, FuncOpWrapper funcOpWrapper) { 54 super(callGraph, targetMethodRef, method, funcOpWrapper); 55 } 56 } 57 58 public static class ComputeReachableUnresolvedMethodCall extends UnresolvedMethodCall implements ComputeReachable { 59 ComputeReachableUnresolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) { 60 super(callGraph, targetMethodRef, method); 61 } 62 } 63 64 public static class ComputeReachableIfaceMappedMethodCall extends ComputeReachableUnresolvedMethodCall { 65 ComputeReachableIfaceMappedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) { 66 super(callGraph, targetMethodRef, method); 67 } 68 } 69 70 public static class ComputeReachableAcceleratorMethodCall extends ComputeReachableUnresolvedMethodCall { 71 ComputeReachableAcceleratorMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) { 72 super(callGraph, targetMethodRef, method); 73 } 74 } 75 76 public static class ComputeContextMethodCall extends ComputeReachableUnresolvedMethodCall { 77 ComputeContextMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method) { 78 super(callGraph, targetMethodRef, method); 79 } 80 } 81 82 public static class OtherComputeReachableResolvedMethodCall extends ComputeReachableResolvedMethodCall { 83 OtherComputeReachableResolvedMethodCall(CallGraph<ComputeEntrypoint> callGraph, MethodRef targetMethodRef, Method method, FuncOpWrapper funcOpWrapper) { 84 super(callGraph, targetMethodRef, method, funcOpWrapper); 85 } 86 } 87 88 static boolean isKernelDispatch(Method calledMethod, FuncOpWrapper fow) { 89 if (fow.getReturnType().equals(JavaType.VOID)) { 90 if (calledMethod.getParameterTypes() instanceof Class<?>[] parameterTypes && parameterTypes.length > 1) { 91 // We check that the proposed kernel first arg is an KernelContext and 92 // the only other args are primitive or ifacebuffers 93 var firstArgIsKid = new Result<>(false); 94 var atLeastOneIfaceBufferParam = new Result<>(false); 95 var hasOnlyPrimitiveAndIfaceBufferParams = new Result<Boolean>(true); 96 fow.paramTable().stream().forEach(paramInfo -> { 97 if (paramInfo.idx == 0) { 98 firstArgIsKid.of(parameterTypes[0].isAssignableFrom(KernelContext.class)); 99 } else { 100 if (paramInfo.isPrimitive()) { 101 // OK 102 } else if (InvokeOpWrapper.isIface(paramInfo.javaType)) { 103 atLeastOneIfaceBufferParam.of(true); 104 } else { 105 hasOnlyPrimitiveAndIfaceBufferParams.of(false); 106 } 107 } 108 }); 109 return true; 110 } 111 return false; 112 } else { 113 return false; 114 } 115 } 116 117 public final Map<MethodRef, KernelCallGraph> kernelCallGraphMap = new HashMap<>(); 118 119 public Stream<KernelCallGraph> kernelCallGraphStream() { 120 return kernelCallGraphMap.values().stream(); 121 } 122 123 public ComputeCallGraph(ComputeContext computeContext, Method method, FuncOpWrapper funcOpWrapper) { 124 super(computeContext, new ComputeEntrypoint(null, method, funcOpWrapper)); 125 entrypoint.callGraph = this; 126 } 127 128 public void updateDag(ComputeReachableResolvedMethodCall computeReachableResolvedMethodCall) { 129 /* 130 * A ResolvedComputeMethodCall (entrypoint or java method reachable from a compute entrypojnt) has the following calls 131 * <p> 132 * 1) java calls to compute class static functions 133 * a) we must have the code model available for these and must be included in the dag 134 * 2) calls to buffer based interface mappings 135 * a) getters (return non void) 136 * b) setters (return void) 137 * c) default helpers with @CodeReflection? 138 * 3) calls to the compute context 139 * a) kernel dispatches 140 * b) mapped math functions? 141 * c) maybe we also handle range creations? 142 * 4) calls through compute context.accelerator; 143 * a) range creations (maybe compute context should manage ranges?) 144 * 5) References to the dispatched kernels 145 * a) We must also have the code models for these and must extend the dag to include these. 146 */ 147 148 computeReachableResolvedMethodCall.funcOpWrapper().selectCalls((invokeWrapper) -> { 149 MethodRef methodRef = invokeWrapper.methodRef(); 150 Class<?> javaRefClass = invokeWrapper.javaRefClass().orElseThrow(); 151 Method invokeWrapperCalledMethod = invokeWrapper.method(this.computeContext.accelerator.lookup); 152 if (Buffer.class.isAssignableFrom(javaRefClass)) { 153 // System.out.println("iface mapped buffer call -> " + methodRef); 154 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ -> 155 new ComputeReachableIfaceMappedMethodCall(this, methodRef, invokeWrapperCalledMethod) 156 )); 157 } else if (Accelerator.class.isAssignableFrom(javaRefClass)) { 158 // System.out.println("call on the accelerator (must be through the computeContext) -> " + methodRef); 159 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ -> 160 new ComputeReachableAcceleratorMethodCall(this, methodRef, invokeWrapperCalledMethod) 161 )); 162 163 } else if (ComputeContext.class.isAssignableFrom(javaRefClass)) { 164 // System.out.println("call on the computecontext -> " + methodRef); 165 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ -> 166 new ComputeContextMethodCall(this, methodRef, invokeWrapperCalledMethod) 167 )); 168 } else if (entrypoint.method.getDeclaringClass().equals(javaRefClass)) { 169 Optional<CoreOp.FuncOp> optionalFuncOp = Op.ofMethod(invokeWrapperCalledMethod); 170 if (optionalFuncOp.isPresent()) { 171 FuncOpWrapper fow = OpWrapper.wrap(optionalFuncOp.get()); 172 if (isKernelDispatch(invokeWrapperCalledMethod, fow)) { 173 // System.out.println("A kernel reference (not a direct call) to a kernel " + methodRef); 174 kernelCallGraphMap.computeIfAbsent(methodRef, _ -> 175 new KernelCallGraph(this, methodRef, invokeWrapperCalledMethod, fow).close() 176 ); 177 } else { 178 // System.out.println("A call to a method on the compute class which we have code model for " + methodRef); 179 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ -> 180 new OtherComputeReachableResolvedMethodCall(this, methodRef, invokeWrapperCalledMethod, fow) 181 )); 182 } 183 } else { 184 // System.out.println("A call to a method on the compute class which we DO NOT have code model for " + methodRef); 185 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ -> 186 new ComputeReachableUnresolvedMethodCall(this, methodRef, invokeWrapperCalledMethod) 187 )); 188 } 189 } else { 190 //TODO what about ifacenestings? 191 // System.out.println("A call to a method on the compute class which we DO NOT have code model for " + methodRef); 192 computeReachableResolvedMethodCall.addCall(methodRefToMethodCallMap.computeIfAbsent(methodRef, _ -> 193 new ComputeReachableUnresolvedMethodCall(this, methodRef, invokeWrapperCalledMethod) 194 )); 195 } 196 }); 197 198 if (kernelCallGraphMap.isEmpty()) { 199 throw new IllegalStateException("entrypoint compute has no kernel references!"); 200 } 201 202 boolean updated = true; 203 computeReachableResolvedMethodCall.closed = true; 204 while (updated) { 205 updated = false; 206 var unclosed = callStream().filter(m -> !m.closed).findFirst(); 207 if (unclosed.isPresent()) { 208 if (unclosed.get() instanceof ComputeReachableResolvedMethodCall reachableResolvedMethodCall) { 209 updateDag(reachableResolvedMethodCall); 210 } else { 211 unclosed.get().closed = true; 212 } 213 updated = true; 214 } 215 } 216 } 217 218 public void close() { 219 updateDag(entrypoint); 220 } 221 }