1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.callgraph;
26
27 import hat.ComputeContext;
28 import hat.Config;
29 import hat.KernelContext;
30 import jdk.incubator.code.CodeTransformer;
31 import jdk.incubator.code.bytecode.BytecodeGenerator;
32 import optkl.OpHelper;
33 import optkl.ifacemapper.MappableIface;
34 import optkl.FuncOpParams;
35
36
37 import java.lang.invoke.MethodHandle;
38 import java.lang.invoke.MethodHandles;
39 import java.lang.reflect.Method;
40 import java.util.HashMap;
41 import java.util.Map;
42 import java.util.stream.Stream;
43
44 import jdk.incubator.code.dialect.core.CoreOp;
45 import jdk.incubator.code.dialect.java.JavaType;
46 import jdk.incubator.code.dialect.java.MethodRef;
47 import optkl.util.carriers.LookupCarrier;
48
49
50 public class ComputeCallGraph implements LookupCarrier {
51 @Override public MethodHandles.Lookup lookup(){
52 return computeContext.lookup();
53 }
54
55 public static final boolean showComputeCallDag =Boolean.getBoolean("showComputeCallDag");
56 public final ComputeContext computeContext;
57 public final MethodCallDag callDag;
58 private CoreOp.FuncOp lowered;
59 private MethodHandle bytecodeGeneratedMethodHandle;
60
61 static boolean isValidKernelDispatch(MethodHandles.Lookup lookup, Method calledMethod, CoreOp.FuncOp funcOp) {
62 // We check that the proposed kernel returns void, the first arg is an KernelContext and we have more args
63 // We also check that other args are primitive or ifacebuffers (or atomics?)...
64 class Traits{
65 boolean firstArgKernelContext = false;
66 boolean atLeastOneIfaceBufferParam=false;
67 boolean hasOnlyPrimitiveAndIfaceBufferParams=true;
68 boolean ok(){
69 return firstArgKernelContext &&atLeastOneIfaceBufferParam&&hasOnlyPrimitiveAndIfaceBufferParams;
70 }
71 }
72 var traits = new Traits();
73 if (funcOp.body().yieldType().equals(JavaType.VOID)
74 && calledMethod.getParameterTypes() instanceof Class<?>[] parameterTypes
75 && parameterTypes.length > 1) {
76 FuncOpParams paramTable = new FuncOpParams(funcOp);
77 paramTable.stream().forEach(paramInfo -> {
78 if (paramInfo.idx == 0) {
79 traits.firstArgKernelContext = parameterTypes[0].isAssignableFrom(KernelContext.class);
80 } else {
81 if (paramInfo.isPrimitive()) {
82 // OK
83 } else if (OpHelper.isAssignable(lookup,paramInfo.javaType, MappableIface.class)){
84 traits.atLeastOneIfaceBufferParam= true;
85 } else {
86 traits.hasOnlyPrimitiveAndIfaceBufferParams=false;
87 }
88 }
89 });
90 }
91 return traits.ok();
92 }
93
94 public final Map<MethodRef, KernelCallGraph> kernelCallGraphMap = new HashMap<>();
95
96 public ComputeCallGraph(ComputeContext computeContext, Method method, CoreOp.FuncOp entry) {
97 this.computeContext = computeContext;
98 this.callDag = new MethodCallDag(lookup(), method,entry,null);
99 if (showComputeCallDag){
100 this.callDag.view("computeCallDag", n -> n.funcOp().funcName());
101 }
102
103 callDag.rankOrdered.stream()
104 .filter(m->m instanceof MethodCallDag.OtherMethodCall &&
105 this.callDag.entryPoint.method().getDeclaringClass().equals(m.method().getDeclaringClass())
106 && isValidKernelDispatch(computeContext.lookup(),m.method(),m.funcOp()))
107 .forEach(m-> kernelCallGraphMap.computeIfAbsent( m.methodRef(), _ ->
108 new KernelCallGraph(this, m.method(), m.funcOp())
109 )
110 );
111
112 }
113
114 public CoreOp.FuncOp lazyLower(){
115 if (lowered == null) {
116 lowered =callDag.entryPoint.funcOp().transform(CodeTransformer.LOWERING_TRANSFORMER);
117 }
118 return lowered;
119 }
120
121 public void invokeWithArgs(Object[] args) {
122 try {
123 if (bytecodeGeneratedMethodHandle == null) {
124 bytecodeGeneratedMethodHandle = BytecodeGenerator.generate(lookup(),lazyLower());
125 }
126 bytecodeGeneratedMethodHandle.invokeWithArguments(args);
127 }catch (Throwable e) {
128 throw new RuntimeException(e);
129 }
130 }
131 }