1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package hat.backend.ffi;
 27 
 28 import hat.NDRange;
 29 import hat.Config;
 30 import hat.KernelContext;
 31 import hat.types.BF16;
 32 import hat.types.F16;
 33 import jdk.incubator.code.CodeTransformer;
 34 import hat.annotations.Kernel;
 35 import hat.annotations.Preformatted;
 36 import hat.annotations.TypeDef;
 37 import hat.buffer.*;
 38 import hat.codebuilders.C99HATKernelBuilder;
 39 import hat.callgraph.KernelCallGraph;
 40 import optkl.codebuilders.ScopedCodeBuilderContext;
 41 import hat.device.DeviceSchema;
 42 import hat.dialect.HATMemoryVarOp;
 43 import optkl.ifacemapper.BoundSchema;
 44 import optkl.ifacemapper.Buffer;
 45 import optkl.ifacemapper.BufferState;
 46 import optkl.ifacemapper.BufferTracker;
 47 import optkl.ifacemapper.MappableIface;
 48 import optkl.ifacemapper.Schema;
 49 import hat.phases.HATFinalDetector;
 50 import jdk.incubator.code.TypeElement;
 51 import jdk.incubator.code.dialect.java.ClassType;
 52 
 53 import java.lang.foreign.Arena;
 54 import java.lang.invoke.MethodHandles;
 55 import java.lang.reflect.Field;
 56 import java.util.ArrayList;
 57 import java.util.Arrays;
 58 import java.util.HashMap;
 59 import java.util.HashSet;
 60 import java.util.LinkedHashSet;
 61 import java.util.List;
 62 import java.util.Map;
 63 import java.util.Objects;
 64 import java.util.Set;
 65 
 66 public abstract class C99FFIBackend extends FFIBackend  implements BufferTracker {
 67     public C99FFIBackend(Arena arena, MethodHandles.Lookup lookup,String libName, Config config) {
 68         super(arena,lookup,libName, config);
 69     }
 70     public static class CompiledKernel {
 71         public final C99FFIBackend c99FFIBackend;
 72         public final KernelCallGraph kernelCallGraph;
 73         public final BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge;
 74         public final ArgArray argArray;
 75         public final KernelBufferContext kernelBufferContext;
 76 
 77         public CompiledKernel(C99FFIBackend c99FFIBackend, KernelCallGraph kernelCallGraph, BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge, Object[] ndRangeAndArgs) {
 78             this.c99FFIBackend = c99FFIBackend;
 79             this.kernelCallGraph = kernelCallGraph;
 80             this.kernelBridge = kernelBridge;
 81             this.kernelBufferContext = KernelBufferContext.createDefault(kernelCallGraph.computeContext.accelerator());
 82             ndRangeAndArgs[0] = this.kernelBufferContext;
 83             this.argArray = ArgArray.create(kernelCallGraph.computeContext.accelerator(),kernelCallGraph,  ndRangeAndArgs);
 84         }
 85 
 86         public void dispatch(KernelContext kernelContext, Object[] args) {
 87             kernelBufferContext.gsy(1);
 88             kernelBufferContext.gsz(1);
 89             switch (kernelContext.ndRange.global()) {
 90                 case NDRange.Global1D global1D -> {
 91                     kernelBufferContext.gsx(global1D.x());
 92                     kernelBufferContext.dimensions(global1D.dimension());
 93                 }
 94                 case NDRange.Global2D global2D -> {
 95                     kernelBufferContext.gsx(global2D.x());
 96                     kernelBufferContext.gsy(global2D.y());
 97                     kernelBufferContext.dimensions(global2D.dimension());
 98                 }
 99                 case NDRange.Global3D global3D -> {
100                     kernelBufferContext.gsx(global3D.x());
101                     kernelBufferContext.gsy(global3D.y());
102                     kernelBufferContext.gsz(global3D.z());
103                     kernelBufferContext.dimensions(global3D.dimension());
104                 }
105                 case null, default -> {
106                     throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.global().getClass());
107                 }
108             }
109             if (kernelContext.ndRange.hasLocal()) {
110                 kernelBufferContext.lsy(1);
111                 kernelBufferContext.lsz(1);
112                 switch (kernelContext.ndRange.local()) {
113                     case NDRange.Local1D local1D -> {
114                         kernelBufferContext.lsx(local1D.x());
115                         kernelBufferContext.dimensions(local1D.dimension());
116                     }
117                     case NDRange.Local2D local2D -> {
118                         kernelBufferContext.lsx(local2D.x());
119                         kernelBufferContext.lsy(local2D.y());
120                         kernelBufferContext.dimensions(local2D.dimension());
121                     }
122                     case NDRange.Local3D local3D -> {
123                         kernelBufferContext.lsx(local3D.x());
124                         kernelBufferContext.lsy(local3D.y());
125                         kernelBufferContext.lsz(local3D.z());
126                         kernelBufferContext.dimensions(local3D.dimension());
127                     }
128                     case null, default -> {
129                         throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.local().getClass());
130                     }
131                 }
132             } else {
133                 kernelBufferContext.lsx(0);
134                 kernelBufferContext.lsy(0);
135                 kernelBufferContext.lsz(0);
136             }
137             args[0] = this.kernelBufferContext;
138             ArgArray.update(argArray, kernelCallGraph, args);
139             kernelBridge.ndRange(this.argArray);
140         }
141     }
142 
143     public Map<KernelCallGraph, CompiledKernel> kernelCallGraphCompiledCodeMap = new HashMap<>();
144 
145     private <T extends C99HATKernelBuilder<T>> void generateDeviceTypeStructs(T builder, String toText, Set<String> typedefs) {
146         // From here is text processing
147         String[] split = toText.split(">");
148         // Each item is a data struct
149         for (String s : split) {
150             // curate: remove first character
151             s = s.substring(1);
152             String dsName = s.split(":")[0];
153             if (typedefs.contains(dsName)) {
154                 continue;
155             }
156             typedefs.add(dsName);
157             // sanitize dsName
158             dsName = sanitize(dsName);
159             builder.typedefKeyword()
160                     .space()
161                     .structKeyword()
162                     .space()
163                     .suffix_s(dsName)
164                     .obrace()
165                     .nl();
166 
167             String[] members = s.split(";");
168 
169             int j = 0;
170             builder.in();
171             for (int i = 0; i < members.length; i++) {
172                 String member = members[i];
173                 String[] field = member.split(":");
174                 if (i == 0) {
175                     j = 1;
176                 }
177                 String isArray = field[j++];
178                 String type = field[j++];
179                 String name = field[j++];
180                 String lenValue = "";
181                 if (isArray.equals("[")) {
182                     lenValue = field[j];
183                 }
184                 j = 0;
185                 if (typedefs.contains(type))
186                     type = sanitize(type) + "_t";
187                 else
188                     type = sanitize(type);
189 
190                 builder.typeName(type)
191                         .space()
192                         .identifier(name);
193 
194                 if (isArray.equals("[")) {
195                     builder.space()
196                             .osbrace()
197                             .identifier(lenValue)
198                             .csbrace();
199                 }
200                 builder.semicolon().nl();
201             }
202             builder.out();
203             builder.cbrace().suffix_t(dsName).semicolon().nl().nl();
204         }
205     }
206 
207     public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kernelCallGraph, T builder, Object... args) {
208         builder.defines().types();
209         var visitedAlready=  new HashSet<Schema.IfaceType>();
210         Arrays.stream(args)
211                 .filter(arg -> arg instanceof Buffer)
212                 .map(arg -> (Buffer) arg)
213                 .forEach(ifaceBuffer -> {
214                     BoundSchema<?> boundSchema = MappableIface.getBoundSchema(ifaceBuffer);
215                     boundSchema.schema().rootIfaceType.visitUniqueTypes( t -> {
216                         if (visitedAlready.add(t)) { // true first time we see this type
217                             builder.typedef(boundSchema, t);
218                         }
219                     });
220                 });
221 
222         var annotation = kernelCallGraph.entrypoint.method().getAnnotation(Kernel.class);
223         if (annotation!=null){
224             var typedef = kernelCallGraph.entrypoint.method().getAnnotation(TypeDef.class);
225             if (typedef!=null){
226                 builder.lineComment("Preformatted typedef body from @Typedef annotation");
227                 builder.typedefKeyword().space().structKeyword().space().suffix_s(typedef.name()).braceNlIndented(_->
228                         builder.preformatted(typedef.body())
229                 ).suffix_t(typedef.name()).semicolon().nl();
230             }
231             var preformatted = kernelCallGraph.entrypoint.method().getAnnotation(Preformatted.class);
232             if (preformatted!=null){
233                 builder.lineComment("Preformatted text from @Preformatted annotation");
234                 builder.preformatted(preformatted.value());
235             }
236             builder.lineComment("Preformatted code body from @Kernel annotation");
237             builder.preformatted(annotation.value());
238         } else {
239             Set<String> typedefs = new HashSet<>();
240 
241             // Add HAT reserved types
242             typedefs.add(F16.class.getName());
243             typedefs.add(BF16.class.getName());
244 
245             /*
246              I think the kernelCallGraph module op was built before we inserted HATMemoryVarOps
247 
248              So we will likely never get any matches from the module op
249 
250              List<ClassType> localIFaceList = new ArrayList<>();
251              kernelCallGraph.getModuleOp()
252                     .elements()
253                     .filter(c -> Objects.requireNonNull(c) instanceof HATMemoryVarOp)
254                     .map(c -> (ClassType)((HATMemoryVarOp) c).invokeType())
255                     .forEach(localIFaceList::add);
256 
257 
258 
259              However,the sentiment from above was correct as we may have kernel reachable methods that do indeed
260              have these HATMemoryVarOps.  I think if we called a method from the entrypoint with Device type accesses
261              we would miss them
262              */
263 
264             // Dynamically build the schema for the user data type we are creating within the kernel.
265             // This is because no allocation was done from the host. This is kernel code, and it is reflected
266             // using the code reflection API
267             // 1. Add for struct for iface objects
268             kernelCallGraph.entrypoint.funcOp()
269                     .elements()
270                     .filter(c -> Objects.requireNonNull(c) instanceof HATMemoryVarOp)
271                     .map(c -> (ClassType)((HATMemoryVarOp) c).invokeType())
272                     .forEach( classType-> {
273                          try {
274                              Class<?> clazz = (Class<?>) classType.resolve(kernelCallGraph.lookup());
275                              Field schemaField = clazz.getDeclaredField("schema");
276                              schemaField.setAccessible(true);
277                              var schema = (DeviceSchema<?>)schemaField.get(schemaField);
278                              // <1> We are creating text form of DeviceType schema
279                              String toText = schema.toText();
280                              if (toText != null) {
281                                  // <2> just to then parse the text from above.
282                                  // Lets get the model in a cleaner form
283                                  generateDeviceTypeStructs(builder, toText, typedefs);
284                              } else {
285                                  throw new RuntimeException("[ERROR] Could not find valid device schema ");
286                              }
287                          } catch (ReflectiveOperationException e) {
288                              throw new RuntimeException(e);
289                          }
290             });
291 
292             var buildContext = new ScopedCodeBuilderContext(kernelCallGraph.lookup(), kernelCallGraph.entrypoint.funcOp());
293 
294             kernelCallGraph.getModuleOp().functionTable()
295                     .forEach((_, funcOp) -> {
296                         // TODO: did we just trash the callgraph sidetables?
297                         //  Why are we transforming the callgraph here
298                         HATFinalDetector finals = new HATFinalDetector(kernelCallGraph);
299                         // Update the build context for this method to use the right constants-map
300                         buildContext.setFinals(finals.applied(funcOp));
301                         builder.nl().kernelMethod(buildContext, funcOp).nl();
302                     });
303 
304             // Update the constants-map for the main kernel
305             // Why are we doing this here we should not be mutating the kernel callgraph at this point
306             HATFinalDetector hatFinalDetector = new HATFinalDetector(kernelCallGraph);
307             buildContext.setFinals(hatFinalDetector.applied(kernelCallGraph.entrypoint.funcOp()));
308 
309             builder.nl().kernelEntrypoint(buildContext).nl();
310 
311             if (config().showKernelModel()) {
312                 IO.println("Non Lowered");
313                 IO.println(kernelCallGraph.entrypoint.funcOp().toText());
314             }
315             if (config().showLoweredKernelModel()) {
316                 IO.println("Lowered");
317                 IO.println(kernelCallGraph.entrypoint.funcOp().transform(CodeTransformer.LOWERING_TRANSFORMER).toText());
318             }
319         }
320         return builder.toString();
321     }
322 
323 
324     private String sanitize(String s) {
325         String[] split1 = s.split("\\.");
326         if (split1.length == 1) {
327             return s;
328         }
329         s = split1[split1.length - 1];
330         if (s.split("\\$").length > 1) {
331             s = sanitize(s.split("\\$")[1]);
332         }
333         return s;
334     }
335 
336     @Override
337     public void preMutate(MappableIface b) {
338         switch (b.getState()) {
339             case BufferState.NO_STATE:
340             case BufferState.NEW_STATE:
341             case BufferState.HOST_OWNED:
342             case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
343                 if (config().showState()) {
344                     System.out.println("in preMutate state = " + b.getStateString() + " no action to take");
345                 }
346                 break;
347             }
348             case BufferState.DEVICE_OWNED: {
349                 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
350                 if (config().showState()) {
351                     System.out.print("in preMutate state = " + b.getStateString() + " we pulled from device ");
352                 }
353                 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
354                 if (config().showState()) {
355                     System.out.println("and switched to " + b.getStateString());
356                 }
357                 break;
358             }
359             default:
360                 throw new IllegalStateException("Not expecting this state ");
361         }
362     }
363 
364     @Override
365     public void postMutate(MappableIface b) {
366         if (config().showState()) {
367             System.out.print("in postMutate state = " + b.getStateString() + " no action to take ");
368         }
369         if (b.getState() != BufferState.NEW_STATE) {
370             b.setState(BufferState.HOST_OWNED);
371         }
372         if (config().showState()) {
373             System.out.println("and switched to (or stayed on) " + b.getStateString());
374         }
375     }
376 
377     @Override
378     public void preAccess(MappableIface b) {
379         switch (b.getState()) {
380             case BufferState.NO_STATE:
381             case BufferState.NEW_STATE:
382             case BufferState.HOST_OWNED:
383             case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
384                 if (config().showState()) {
385                     System.out.println("in preAccess state = " + b.getStateString() + " no action to take");
386                 }
387                 break;
388             }
389             case BufferState.DEVICE_OWNED: {
390                 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
391 
392                 if (config().showState()) {
393                     System.out.print("in preAccess state = " + b.getStateString() + " we pulled from device ");
394                 }
395                 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
396                 if (config().showState()) {
397                     System.out.println("and switched to " + b.getStateString());
398                 }
399                 break;
400             }
401             default:
402                 throw new IllegalStateException("Not expecting this state ");
403         }
404     }
405 
406 
407     @Override
408     public void postAccess(MappableIface b) {
409         if (config().showState()) {
410             System.out.println("in postAccess state = " + b.getStateString());
411         }
412     }
413 }