1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend;
 26 
 27 import hat.ComputeContext;
 28 import hat.Config;
 29 import hat.KernelContext;
 30 import hat.callgraph.KernelCallGraph;
 31 import hat.callgraph.KernelEntrypoint;
 32 
 33 import java.lang.reflect.InvocationTargetException;
 34 
 35 import hat.optools.OpTk;
 36 import jdk.incubator.code.Op;
 37 import jdk.incubator.code.bytecode.BytecodeGenerator;
 38 import jdk.incubator.code.dialect.core.CoreOp;
 39 import jdk.incubator.code.interpreter.Interpreter;
 40 
 41 public class DebugBackend extends BackendAdaptor {
 42     public enum HowToRunCompute{REFLECT, BABYLON_INTERPRETER, BABYLON_CLASSFILE}
 43     public HowToRunCompute howToRunCompute=HowToRunCompute.REFLECT;
 44     public enum HowToRunKernel{REFLECT, BABYLON_INTERPRETER, BABYLON_CLASSFILE, LOWER_TO_SSA,LOWER_TO_SSA_AND_MAP_PTRS}
 45     HowToRunKernel howToRunKernel = HowToRunKernel.LOWER_TO_SSA_AND_MAP_PTRS;
 46 
 47     public DebugBackend(){
 48        this(HowToRunCompute.REFLECT, HowToRunKernel.REFLECT);
 49     }
 50 
 51     public DebugBackend(HowToRunCompute howToRunCompute, HowToRunKernel howToRunKernel){
 52         super(Config.fromEnvOrProperty());
 53         this.howToRunCompute = howToRunCompute;
 54         this.howToRunKernel = howToRunKernel;
 55     }
 56 
 57     @Override
 58     public void dispatchCompute(ComputeContext computeContext, Object... args) {
 59         var here = OpTk.CallSite.of(DebugBackend.class,"dispatchCompute");
 60         switch (howToRunCompute){
 61 
 62             case REFLECT: {
 63                 try {
 64                     computeContext.computeCallGraph.entrypoint.method.invoke(null, args);
 65                 } catch (IllegalAccessException | InvocationTargetException e) {
 66                     throw new RuntimeException(e);
 67                 }
 68                 break;
 69             }
 70             case BABYLON_INTERPRETER:{
 71                 if (computeContext.computeCallGraph.entrypoint.lowered == null) {
 72                     computeContext.computeCallGraph.entrypoint.lowered = OpTk.lower(here, computeContext.computeCallGraph.entrypoint.funcOp());
 73                 }
 74                 Interpreter.invoke(computeContext.accelerator.lookup, computeContext.computeCallGraph.entrypoint.lowered, args);
 75                 break;
 76             }
 77             case BABYLON_CLASSFILE:{
 78                 if (computeContext.computeCallGraph.entrypoint.lowered == null) {
 79                     computeContext.computeCallGraph.entrypoint.lowered = OpTk.lower(here, computeContext.computeCallGraph.entrypoint.funcOp());
 80                 }
 81                 try {
 82                     if (computeContext.computeCallGraph.entrypoint.mh == null) {
 83                         computeContext.computeCallGraph.entrypoint.mh = BytecodeGenerator.generate(computeContext.accelerator.lookup, computeContext.computeCallGraph.entrypoint.lowered);
 84                     }
 85                     computeContext.computeCallGraph.entrypoint.mh.invokeWithArguments(args);
 86                 } catch (Throwable e) {
 87                     System.out.println(computeContext.computeCallGraph.entrypoint.lowered.toText());
 88                     throw new RuntimeException(e);
 89                 }
 90                 break;
 91             }
 92         }
 93     }
 94 
 95     @Override
 96     public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernelContext, Object... args) {
 97 
 98         var here = OpTk.CallSite.of(DebugBackend.class, "dispatchKernel");
 99         switch (howToRunKernel){
100             case REFLECT: {
101                 KernelEntrypoint kernelEntrypoint = kernelCallGraph.entrypoint;
102                 for (kernelContext.gix = 0; kernelContext.gix < kernelContext.gsx; kernelContext.gix++) {
103                     try {
104                         args[0] = kernelContext;
105                         kernelEntrypoint.method.invoke(null, args);
106                     } catch (IllegalAccessException e) {
107                         throw new RuntimeException(e);
108                     } catch (InvocationTargetException e) {
109                         throw new RuntimeException(e);
110                     }
111                 }
112                 break;
113             }
114             case BABYLON_INTERPRETER:{
115                 var lowered = OpTk.lower(here, kernelCallGraph.entrypoint.funcOp());
116                 Interpreter.invoke(kernelCallGraph.computeContext.accelerator.lookup, lowered, args);
117                 break;
118             }
119             case BABYLON_CLASSFILE:{
120                 var lowered = OpTk.lower(here, kernelCallGraph.entrypoint.funcOp());
121                 var mh = BytecodeGenerator.generate(kernelCallGraph.computeContext.accelerator.lookup, lowered);
122                 try {
123                     mh.invokeWithArguments(args);
124                 } catch (Throwable e) {
125                     throw new RuntimeException(e);
126                 }
127                 break;
128             }
129 
130             case LOWER_TO_SSA:{
131                 var highLevelForm = Op.ofMethod(kernelCallGraph.entrypoint.method).orElseThrow();
132 
133 
134                 System.out.println("Initial code model");
135                 System.out.println(highLevelForm.toText());
136                 System.out.println("------------------");
137                 System.out.println("TRANSFORM dispatchKernel"+ DebugBackend.class);
138                 CoreOp.FuncOp loweredForm = OpTk.lower(here, highLevelForm);
139                 System.out.println("Lowered form which maintains original invokes and args");
140                 System.out.println(loweredForm.toText());
141                 System.out.println("-------------- ----");
142                 System.out.println("TRANSFORM dispatchKernel"+DebugBackend.class);
143 
144                 CoreOp.FuncOp ssaInvokeForm = OpTk.SSATransform(here, loweredForm);
145                 System.out.println("SSA form which maintains original invokes and args");
146                 System.out.println(ssaInvokeForm.toText());
147                 System.out.println("------------------");
148 
149             }
150 
151             case LOWER_TO_SSA_AND_MAP_PTRS:{
152                 var highLevelForm = Op.ofMethod(kernelCallGraph.entrypoint.method).orElseThrow();
153                 System.out.println("Initial code model");
154                 System.out.println(highLevelForm.toText());
155                 System.out.println("------------------");
156                 CoreOp.FuncOp loweredForm = OpTk.lower(here, highLevelForm);
157                 System.out.println("Lowered form which maintains original invokes and args");
158                 System.out.println(loweredForm.toText());
159                 System.out.println("-------------- ----");
160                 CoreOp.FuncOp ssaInvokeForm = OpTk.SSATransformLower(here, loweredForm);
161                 System.out.println("SSA form which maintains original invokes and args");
162                 System.out.println(ssaInvokeForm.toText());
163                 System.out.println("------------------");
164             }
165         }
166     }
167 }