1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package hat.backend.ffi;
 27 
 28 import hat.Config;
 29 import hat.KernelContext;
 30 import hat.NDRange;
 31 import hat.annotations.Kernel;
 32 import hat.annotations.Preformatted;
 33 import hat.annotations.TypeDef;
 34 import hat.buffer.ArgArray;
 35 import hat.buffer.KernelBufferContext;
 36 import hat.callgraph.IfaceDataDag;
 37 import hat.callgraph.KernelCallGraph;
 38 import hat.callgraph.MethodCallDag;
 39 import hat.codebuilders.C99HATKernelBuilder;
 40 import hat.codebuilders.C99VecAndMatHandler;
 41 import hat.device.DeviceSchema;
 42 import hat.device.NonMappableIface;
 43 import hat.types.BF16;
 44 import hat.types.F16;
 45 import jdk.incubator.code.CodeTransformer;
 46 import optkl.ifacemapper.BoundSchema;
 47 import optkl.ifacemapper.Buffer;
 48 import optkl.ifacemapper.BufferState;
 49 import optkl.ifacemapper.BufferTracker;
 50 import optkl.ifacemapper.MappableIface;
 51 import optkl.ifacemapper.Schema;
 52 
 53 import java.lang.foreign.Arena;
 54 import java.lang.invoke.MethodHandles;
 55 import java.util.Arrays;
 56 import java.util.HashMap;
 57 import java.util.HashSet;
 58 import java.util.List;
 59 import java.util.Map;
 60 import java.util.Set;
 61 
 62 public abstract class C99FFIBackend extends FFIBackend implements BufferTracker {
 63     public C99FFIBackend(Arena arena, MethodHandles.Lookup lookup, String libName, Config config) {
 64         super(arena, lookup, libName, config);
 65     }
 66 
 67     public static class CompiledKernel {
 68         public final C99FFIBackend c99FFIBackend;
 69         public final KernelCallGraph kernelCallGraph;
 70         public final BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge;
 71         public final ArgArray argArray;
 72         public final KernelBufferContext kernelBufferContext;
 73 
 74         public CompiledKernel(C99FFIBackend c99FFIBackend, KernelCallGraph kernelCallGraph, BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge, Object[] ndRangeAndArgs) {
 75             this.c99FFIBackend = c99FFIBackend;
 76             this.kernelCallGraph = kernelCallGraph;
 77             this.kernelBridge = kernelBridge;
 78             this.kernelBufferContext = KernelBufferContext.createDefault(kernelCallGraph.computeCallGraph.computeContext.accelerator());
 79             ndRangeAndArgs[0] = this.kernelBufferContext;
 80             this.argArray = ArgArray.create(kernelCallGraph.computeCallGraph.computeContext.accelerator(), kernelCallGraph, ndRangeAndArgs);
 81         }
 82 
 83         public void dispatch(KernelContext kernelContext, Object[] args) {
 84             // Do we really need this?  We never actually read these
 85             kernelBufferContext.gsy(1);
 86             kernelBufferContext.gsz(1);
 87             switch (kernelContext.ndRange.global()) {
 88                 case NDRange.Global1D global1D -> {
 89                     kernelBufferContext.gsx(global1D.x());
 90                     kernelBufferContext.dimensions(global1D.dimension());
 91                 }
 92                 case NDRange.Global2D global2D -> {
 93                     kernelBufferContext.gsx(global2D.x());
 94                     kernelBufferContext.gsy(global2D.y());
 95                     kernelBufferContext.dimensions(global2D.dimension());
 96                 }
 97                 case NDRange.Global3D global3D -> {
 98                     kernelBufferContext.gsx(global3D.x());
 99                     kernelBufferContext.gsy(global3D.y());
100                     kernelBufferContext.gsz(global3D.z());
101                     kernelBufferContext.dimensions(global3D.dimension());
102                 }
103                 case null, default -> throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.global().getClass());
104             }
105 
106             if (kernelContext.ndRange.hasLocal()) {
107                 kernelBufferContext.lsy(1);
108                 kernelBufferContext.lsz(1);
109                 switch (kernelContext.ndRange.local()) {
110                     case NDRange.Local1D local1D -> {
111                         kernelBufferContext.lsx(local1D.x());
112                         kernelBufferContext.dimensions(local1D.dimension());
113                     }
114                     case NDRange.Local2D local2D -> {
115                         kernelBufferContext.lsx(local2D.x());
116                         kernelBufferContext.lsy(local2D.y());
117                         kernelBufferContext.dimensions(local2D.dimension());
118                     }
119                     case NDRange.Local3D local3D -> {
120                         kernelBufferContext.lsx(local3D.x());
121                         kernelBufferContext.lsy(local3D.y());
122                         kernelBufferContext.lsz(local3D.z());
123                         kernelBufferContext.dimensions(local3D.dimension());
124                     }
125                     case null, default -> throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.local().getClass());
126                 }
127             } else {
128                 kernelBufferContext.lsx(0);
129                 kernelBufferContext.lsy(0);
130                 kernelBufferContext.lsz(0);
131             }
132 
133             args[0] = this.kernelBufferContext;
134             ArgArray.update(argArray, kernelCallGraph, args);
135             kernelBridge.ndRange(this.argArray);
136         }
137     }
138 
139     public Map<KernelCallGraph, CompiledKernel> kernelCallGraphCompiledCodeMap = new HashMap<>();
140 
141 
142     public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kernelCallGraph, T builder, Object... args) {
143         builder.defines().types();
144 
145         var visitedAlready = new HashSet<Schema.IfaceType>();
146         Arrays.stream(args)
147                 .filter(arg -> arg instanceof Buffer)
148                 .map(arg -> (Buffer) arg)
149                 .forEach(ifaceBuffer -> {
150                     BoundSchema<?> boundSchema = MappableIface.getBoundSchema(ifaceBuffer);
151                     boundSchema.schema().rootIfaceType.visitUniqueTypes(t -> {
152                         if (visitedAlready.add(t)) { // true first time we see this type
153                             builder.typedef(boundSchema, t);
154                         }
155                     });
156                 });
157 
158 
159         var kernelAnnotation = kernelCallGraph.callDag.entryPoint.method().getAnnotation(Kernel.class);
160         if (kernelAnnotation != null) {
161             // If we find a kernelAnnotation we can't trust the data in kernelCallGraph's state.
162             kernelCallGraph.setUsesAtomics(true);
163             kernelCallGraph.accessedFP16Classes.addAll(List.of(F16.class, BF16.class));
164             kernelCallGraph.setUsesBarrier(true);
165 
166             var typedefAnnotation = kernelCallGraph.callDag.entryPoint.method().getAnnotation(TypeDef.class);
167             if (typedefAnnotation != null) {
168                 builder.lineComment("Preformatted typedef body from @Typedef annotation");
169                 builder.typedefStruct(typedefAnnotation.name(), _ -> builder.preformatted(typedefAnnotation.body())).semicolon().nl();
170             }
171             var preformattedAnnotation = kernelCallGraph.callDag.entryPoint.method().getAnnotation(Preformatted.class);
172             if (preformattedAnnotation != null) {
173                 builder.lineComment("Preformatted text from @Preformatted annotation");
174                 builder.preformatted(preformattedAnnotation.value());
175             }
176             builder.lineComment("Preformatted code body from @Kernel annotation");
177             builder.preformatted(kernelAnnotation.value());
178         } else {
179             Set<Class<?>> typedeffed = new HashSet<>();
180             typedeffed.add(F16.class);
181             typedeffed.add(BF16.class);
182             kernelCallGraph.accessedNonMappableIfaceClasses.stream()
183                     .filter(c->!typedeffed.contains(c))
184                     .map(c->(Class<NonMappableIface>) c) // why do we need to do this.
185                     .forEach(c -> {
186                         // We create a dag of iface references rooted at c
187                         var ifaceDataDag = new IfaceDataDag<NonMappableIface>(dag -> {
188                             var entryPoint = dag.getNode(c);
189                             dag.methodsWithIfaceReturnTypes(c).forEach(ifaceInfo ->
190                                  dag.addEdge(entryPoint, dag.getNode(ifaceInfo.clazz())) // this recurses with each added class
191                             );
192                         });
193                         // Now we can generate typedefs in rankOrder (so inner typedefs first)
194                         if (ifaceDataDag.isDag()) {
195                             ifaceDataDag.rankOrdered.stream()
196                                     .filter(ifaceInfo -> !typedeffed.contains(ifaceInfo.clazz()))
197                                     .forEach(ifaceInfo -> typedeffed.add(
198                                             DeviceSchema.getDeviceSchemaOrThrow(ifaceInfo.clazz()).typedef(builder).clazz()
199                                     )
200                             );
201                         } else  {
202                             typedeffed.add(DeviceSchema.getDeviceSchemaOrThrow(c).typedef(builder).clazz());
203                         }
204                     });
205 
206             // This is a slight hack for Shader support.
207             if (!kernelCallGraph.accessedVecClasses.isEmpty()) {
208                 C99VecAndMatHandler.createVecFunctions(builder);
209             }
210 
211             kernelCallGraph.callDag.rankOrdered.stream()
212                     .filter(m -> m instanceof MethodCallDag.OtherMethodCall)
213                     .forEach(m -> builder.nl().kernelMethod( m.funcOp()).nl());
214 
215             builder.nl().kernelEntrypoint().nl();
216 
217             if (config().showKernelModel()) {
218                 IO.println("Non Lowered");
219                 IO.println(kernelCallGraph.callDag.entryPoint.funcOp().toText());
220             }
221             if (config().showLoweredKernelModel()) {
222                 IO.println("Lowered");
223                 IO.println(kernelCallGraph.callDag.entryPoint.funcOp().transform(CodeTransformer.LOWERING_TRANSFORMER).toText());
224             }
225         }
226         return builder.toString();
227     }
228 
229     @Override
230     public void preMutate(MappableIface b) {
231         switch (b.getState()) {
232             case BufferState.NO_STATE:
233             case BufferState.NEW_STATE:
234             case BufferState.HOST_OWNED:
235             case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
236                 if (config().showState()) {
237                     System.out.println("in preMutate state = " + b.getStateString() + " no action to take");
238                 }
239                 break;
240             }
241             case BufferState.DEVICE_OWNED: {
242                 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
243                 if (config().showState()) {
244                     System.out.print("in preMutate state = " + b.getStateString() + " we pulled from device ");
245                 }
246                 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
247                 if (config().showState()) {
248                     System.out.println("and switched to " + b.getStateString());
249                 }
250                 break;
251             }
252             default:
253                 throw new IllegalStateException("Not expecting this state ");
254         }
255     }
256 
257     @Override
258     public void postMutate(MappableIface b) {
259         if (config().showState()) {
260             System.out.print("in postMutate state = " + b.getStateString() + " no action to take ");
261         }
262         if (b.getState() != BufferState.NEW_STATE) {
263             b.setState(BufferState.HOST_OWNED);
264         }
265         if (config().showState()) {
266             System.out.println("and switched to (or stayed on) " + b.getStateString());
267         }
268     }
269 
270     @Override
271     public void preAccess(MappableIface b) {
272         switch (b.getState()) {
273             case BufferState.NO_STATE:
274             case BufferState.NEW_STATE:
275             case BufferState.HOST_OWNED:
276             case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
277                 if (config().showState()) {
278                     System.out.println("in preAccess state = " + b.getStateString() + " no action to take");
279                 }
280                 break;
281             }
282             case BufferState.DEVICE_OWNED: {
283                 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
284 
285                 if (config().showState()) {
286                     System.out.print("in preAccess state = " + b.getStateString() + " we pulled from device ");
287                 }
288                 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
289                 if (config().showState()) {
290                     System.out.println("and switched to " + b.getStateString());
291                 }
292                 break;
293             }
294             default:
295                 throw new IllegalStateException("Not expecting this state ");
296         }
297     }
298 
299 
300     @Override
301     public void postAccess(MappableIface b) {
302         if (config().showState()) {
303             System.out.println("in postAccess state = " + b.getStateString());
304         }
305     }
306 }