1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.callgraph;
 26 
 27 import hat.ComputeContext;
 28 import hat.Config;
 29 import hat.KernelContext;
 30 import jdk.incubator.code.CodeTransformer;
 31 import jdk.incubator.code.bytecode.BytecodeGenerator;
 32 import optkl.OpHelper;
 33 import optkl.ifacemapper.MappableIface;
 34 import optkl.FuncOpParams;
 35 
 36 
 37 import java.lang.invoke.MethodHandle;
 38 import java.lang.invoke.MethodHandles;
 39 import java.lang.reflect.Method;
 40 import java.util.HashMap;
 41 import java.util.Map;
 42 import java.util.stream.Stream;
 43 
 44 import jdk.incubator.code.dialect.core.CoreOp;
 45 import jdk.incubator.code.dialect.java.JavaType;
 46 import jdk.incubator.code.dialect.java.MethodRef;
 47 import optkl.util.carriers.LookupCarrier;
 48 
 49 
 50 public class ComputeCallGraph implements LookupCarrier {
 51     @Override public MethodHandles.Lookup lookup(){
 52         return computeContext.lookup();
 53     }
 54 
 55     public static final boolean  showComputeCallDag =Boolean.getBoolean("showComputeCallDag");
 56     public final ComputeContext computeContext;
 57     public final MethodCallDag callDag;
 58     private CoreOp.FuncOp lowered;
 59     private MethodHandle bytecodeGeneratedMethodHandle;
 60 
 61     static boolean isValidKernelDispatch(MethodHandles.Lookup lookup, Method calledMethod, CoreOp.FuncOp funcOp) {
 62         // We check that the proposed kernel returns void, the first arg is an KernelContext and we have more args
 63         // We also check that other args are primitive or ifacebuffers  (or atomics?)...
 64         class Traits{
 65             boolean firstArgKernelContext = false;
 66             boolean atLeastOneIfaceBufferParam=false;
 67             boolean hasOnlyPrimitiveAndIfaceBufferParams=true;
 68             boolean ok(){
 69                 return firstArgKernelContext &&atLeastOneIfaceBufferParam&&hasOnlyPrimitiveAndIfaceBufferParams;
 70             }
 71         }
 72         var traits = new Traits();
 73         if (funcOp.body().yieldType().equals(JavaType.VOID)
 74                 && calledMethod.getParameterTypes() instanceof Class<?>[] parameterTypes
 75                 && parameterTypes.length > 1) {
 76                 FuncOpParams paramTable = new FuncOpParams(funcOp);
 77                 paramTable.stream().forEach(paramInfo -> {
 78                     if (paramInfo.idx == 0) {
 79                         traits.firstArgKernelContext = parameterTypes[0].isAssignableFrom(KernelContext.class);
 80                     } else {
 81                         if (paramInfo.isPrimitive()) {
 82                             // OK
 83                         } else if (OpHelper.isAssignable(lookup,paramInfo.javaType, MappableIface.class)){
 84                             traits.atLeastOneIfaceBufferParam= true;
 85                         } else {
 86                             traits.hasOnlyPrimitiveAndIfaceBufferParams=false;
 87                         }
 88                     }
 89                 });
 90             }
 91             return traits.ok();
 92     }
 93 
 94     public final Map<MethodRef, KernelCallGraph> kernelCallGraphMap = new HashMap<>();
 95 
 96     public ComputeCallGraph(ComputeContext computeContext, Method method, CoreOp.FuncOp entry) {
 97         this.computeContext = computeContext;
 98         this.callDag = new MethodCallDag(lookup(), method,entry,null);
 99         if (showComputeCallDag){
100             this.callDag.view("computeCallDag", n -> n.funcOp().funcName());
101         }
102 
103             callDag.rankOrdered.stream()
104                     .filter(m->m instanceof MethodCallDag.OtherMethodCall &&
105                         this.callDag.entryPoint.method().getDeclaringClass().equals(m.method().getDeclaringClass())
106                                 && isValidKernelDispatch(computeContext.lookup(),m.method(),m.funcOp()))
107                 .forEach(m-> kernelCallGraphMap.computeIfAbsent( m.methodRef(), _ ->
108                     new KernelCallGraph(this, m.method(), m.funcOp())
109             )
110         );
111 
112     }
113 
114     public CoreOp.FuncOp lazyLower(){
115         if (lowered == null) {
116             lowered =callDag.entryPoint.funcOp().transform(CodeTransformer.LOWERING_TRANSFORMER);
117         }
118         return lowered;
119     }
120 
121     public void invokeWithArgs(Object[] args) {
122         try {
123             if (bytecodeGeneratedMethodHandle == null) {
124                 bytecodeGeneratedMethodHandle = BytecodeGenerator.generate(lookup(),lazyLower());
125             }
126             bytecodeGeneratedMethodHandle.invokeWithArguments(args);
127         }catch (Throwable e) {
128             throw new RuntimeException(e);
129         }
130     }
131 }