1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package hat.backend.ffi;
 27 
 28 import hat.NDRange;
 29 import hat.Config;
 30 import hat.KernelContext;
 31 import hat.annotations.Kernel;
 32 import hat.annotations.Preformatted;
 33 import hat.annotations.TypeDef;
 34 import hat.buffer.F16;
 35 import hat.buffer.KernelBufferContext;
 36 import hat.codebuilders.C99HATKernelBuilder;
 37 import hat.buffer.ArgArray;
 38 import hat.buffer.Buffer;
 39 import hat.buffer.BufferTracker;
 40 import hat.callgraph.KernelCallGraph;
 41 import hat.codebuilders.ScopedCodeBuilderContext;
 42 import hat.device.DeviceSchema;
 43 import hat.dialect.HATMemoryOp;
 44 import hat.ifacemapper.BoundSchema;
 45 import hat.ifacemapper.BufferState;
 46 import hat.ifacemapper.Schema;
 47 import hat.optools.OpTk;
 48 import hat.phases.HATFinalDetectionPhase;
 49 import jdk.incubator.code.TypeElement;
 50 import jdk.incubator.code.dialect.java.ClassType;
 51 
 52 import java.lang.reflect.Field;
 53 import java.lang.reflect.Method;
 54 import java.util.ArrayList;
 55 import java.util.Arrays;
 56 import java.util.HashMap;
 57 import java.util.HashSet;
 58 import java.util.LinkedHashSet;
 59 import java.util.List;
 60 import java.util.Map;
 61 import java.util.Objects;
 62 import java.util.Set;
 63 
 64 public abstract class C99FFIBackend extends FFIBackend  implements BufferTracker {
 65 
 66     public C99FFIBackend(String libName, Config config) {
 67         super(libName, config);
 68     }
 69 
 70     public static class CompiledKernel {
 71         public final C99FFIBackend c99FFIBackend;
 72         public final KernelCallGraph kernelCallGraph;
 73         public final BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge;
 74         public final ArgArray argArray;
 75         public final KernelBufferContext kernelBufferContext;
 76 
 77         public CompiledKernel(C99FFIBackend c99FFIBackend, KernelCallGraph kernelCallGraph, BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge, Object[] ndRangeAndArgs) {
 78             this.c99FFIBackend = c99FFIBackend;
 79             this.kernelCallGraph = kernelCallGraph;
 80             this.kernelBridge = kernelBridge;
 81             this.kernelBufferContext = KernelBufferContext.createDefault(kernelCallGraph.computeContext.accelerator);
 82             ndRangeAndArgs[0] = this.kernelBufferContext;
 83             this.argArray = ArgArray.create(kernelCallGraph.computeContext.accelerator,kernelCallGraph,  ndRangeAndArgs);
 84         }
 85 
 86         private void setGlobalMesh(NDRange.Global global) {
 87             kernelBufferContext.gsy(1);
 88             kernelBufferContext.gsz(1);
 89             switch (global) {
 90                 case NDRange.Global1D global1D -> {
 91                     kernelBufferContext.gsx(global1D.x());
 92                     kernelBufferContext.dimensions(global1D.dimension());
 93                 }
 94                 case NDRange.Global2D global2D -> {
 95                     kernelBufferContext.gsx(global2D.x());
 96                     kernelBufferContext.gsy(global2D.y());
 97                     kernelBufferContext.dimensions(global2D.dimension());
 98                 }
 99                 case NDRange.Global3D global3D -> {
100                     kernelBufferContext.gsx(global3D.x());
101                     kernelBufferContext.gsy(global3D.y());
102                     kernelBufferContext.gsz(global3D.z());
103                     kernelBufferContext.dimensions(global3D.dimension());
104                 }
105                 case null, default -> {
106                     throw new IllegalArgumentException("Unknown global range " + global.getClass());
107                 }
108             }
109         }
110 
111         private void setLocalMesh(NDRange.Local local) {
112             kernelBufferContext.lsy(1);
113             kernelBufferContext.lsz(1);
114             switch (local) {
115                 case NDRange.Local1D local1D -> {
116                     kernelBufferContext.lsx(local1D.x());
117                     kernelBufferContext.dimensions(local1D.dimension());
118                 }
119                 case NDRange.Local2D local2D -> {
120                     kernelBufferContext.lsx(local2D.x());
121                     kernelBufferContext.lsy(local2D.y());
122                     kernelBufferContext.dimensions(local2D.dimension());
123                 }
124                 case NDRange.Local3D local3D -> {
125                     kernelBufferContext.lsx(local3D.x());
126                     kernelBufferContext.lsy(local3D.y());
127                     kernelBufferContext.lsz(local3D.z());
128                     kernelBufferContext.dimensions(local3D.dimension());
129                 }
130                 case null, default -> {
131                     throw new IllegalArgumentException("Unknown global range " + local.getClass());
132                 }
133             }
134         }
135 
136         private void setDefaultLocalMesh() {
137             kernelBufferContext.lsx(0);
138             kernelBufferContext.lsy(0);
139             kernelBufferContext.lsz(0);
140         }
141 
142         private void setupComputeRange(KernelContext kernelContext) {
143             NDRange ndRange = kernelContext.getNDRange();
144             if (!(ndRange instanceof NDRange.Range range)) {
145                 throw new IllegalArgumentException("NDRange must be of type NDRange.Range");
146             }
147             boolean isLocalMeshDefined = kernelContext.hasLocalMesh();
148             NDRange.Global global = range.global();
149             setGlobalMesh(global);
150             if (isLocalMeshDefined) {
151                 setLocalMesh(range.local());
152             } else {
153                 setDefaultLocalMesh();
154             }
155         }
156 
157         public void dispatch(KernelContext kernelContext, Object[] args) {
158             setupComputeRange(kernelContext);
159             args[0] = this.kernelBufferContext;
160             ArgArray.update(argArray, kernelCallGraph, args);
161             kernelBridge.ndRange(this.argArray);
162         }
163     }
164 
165     public Map<KernelCallGraph, CompiledKernel> kernelCallGraphCompiledCodeMap = new HashMap<>();
166 
167     private <T extends C99HATKernelBuilder<T>> void generateDeviceTypeStructs(T builder, String toText, Set<String> typedefs) {
168         // From here is text processing
169         String[] split = toText.split(">");
170         // Each item is a data struct
171         for (String s : split) {
172             // curate: remove first character
173             s = s.substring(1);
174             String dsName = s.split(":")[0];
175             if (typedefs.contains(dsName)) {
176                 continue;
177             }
178             typedefs.add(dsName);
179             // sanitize dsName
180             dsName = sanitize(dsName);
181             builder.typedefKeyword()
182                     .space()
183                     .structKeyword()
184                     .space()
185                     .suffix_s(dsName)
186                     .obrace()
187                     .nl();
188 
189             String[] members = s.split(";");
190 
191             int j = 0;
192             builder.in();
193             for (int i = 0; i < members.length; i++) {
194                 String member = members[i];
195                 String[] field = member.split(":");
196                 if (i == 0) {
197                     j = 1;
198                 }
199                 String isArray = field[j++];
200                 String type = field[j++];
201                 String name = field[j++];
202                 String lenValue = "";
203                 if (isArray.equals("[")) {
204                     lenValue = field[j];
205                 }
206                 j = 0;
207                 if (typedefs.contains(type))
208                     type = sanitize(type) + "_t";
209                 else
210                     type = sanitize(type);
211 
212                 builder.typeName(type)
213                         .space()
214                         .identifier(name);
215 
216                 if (isArray.equals("[")) {
217                     builder.space()
218                             .osbrace()
219                             .identifier(lenValue)
220                             .csbrace();
221                 }
222                 builder.semicolon().nl();
223             }
224             builder.out();
225             builder.cbrace().suffix_t(dsName).semicolon().nl();
226         }
227     }
228 
229     public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kernelCallGraph, T builder, Object... args) {
230         var here = OpTk.CallSite.of(C99FFIBackend.class, "createCode");
231         builder.defines().types();
232         Set<Schema.IfaceType> already = new LinkedHashSet<>();
233         Arrays.stream(args)
234                 .filter(arg -> arg instanceof Buffer)
235                 .map(arg -> (Buffer) arg)
236                 .forEach(ifaceBuffer -> {
237                     BoundSchema<?> boundSchema = Buffer.getBoundSchema(ifaceBuffer);
238                     boundSchema.schema().rootIfaceType.visitTypes(0, t -> {
239                         if (!already.contains(t)) {
240                             builder.typedef(boundSchema, t);
241                             already.add(t);
242                         }
243                     });
244                 });
245 
246         var annotation = kernelCallGraph.entrypoint.method.getAnnotation(Kernel.class);
247 
248         if (annotation!=null){
249             var typedef = kernelCallGraph.entrypoint.method.getAnnotation(TypeDef.class);
250             if (typedef!=null){
251                 builder.lineComment("Preformatted typedef body from @Typedef annotation");
252                 builder.typedefKeyword().space().structKeyword().space().suffix_s(typedef.name()).braceNlIndented(_->
253                         builder.preformatted(typedef.body())
254                 ).suffix_t(typedef.name()).semicolon().nl();
255             }
256             var preformatted = kernelCallGraph.entrypoint.method.getAnnotation(Preformatted.class);
257             if (preformatted!=null){
258                 builder.lineComment("Preformatted text from @Preformatted annotation");
259                 builder.preformatted(preformatted.value());
260             }
261             builder.lineComment("Preformatted code body from @Kernel annotation");
262             builder.preformatted(annotation.value());
263         } else {
264             List<TypeElement> localIFaceList = new ArrayList<>();
265 
266             kernelCallGraph.getModuleOp()
267                     .elements()
268                     .filter(c -> Objects.requireNonNull(c) instanceof HATMemoryOp)
269                     .map(c -> ((HATMemoryOp) c).invokeType())
270                     .forEach(localIFaceList::add);
271 
272             kernelCallGraph.entrypoint.funcOp()
273                     .elements()
274                     .filter(c -> Objects.requireNonNull(c) instanceof HATMemoryOp)
275                     .map(c -> ((HATMemoryOp) c).invokeType())
276                     .forEach(localIFaceList::add);
277 
278             // Dynamically build the schema for the user data type we are creating within the kernel.
279             // This is because no allocation was done from the host. This is kernel code, and it is reflected
280             // using the code reflection API
281             // 1. Add for struct for iface objects
282             Set<String> typedefs = new HashSet<>();
283 
284             // Add HAT reserved types
285             typedefs.add(F16.class.getName());
286 
287             for (TypeElement typeElement : localIFaceList) {
288                 try {
289                     // Approach 1: The first approach support iFace and Buffer types to be used in Local and Private memory
290                     // TODO: Once we decide to move towards the DeviceType implementation, we will remove this part
291                     Class<?> clazz = (Class<?>) ((ClassType) typeElement).resolve(kernelCallGraph.computeContext.accelerator.lookup);
292                     Method method = clazz.getMethod("create", hat.Accelerator.class);
293                     method.setAccessible(true);
294                     Buffer invoke = (Buffer) method.invoke(null, kernelCallGraph.computeContext.accelerator);
295                     if (invoke != null) {
296                         // code gen of the struct
297                         BoundSchema<?> boundSchema = Buffer.getBoundSchema(invoke);
298                         boundSchema.schema().rootIfaceType.visitTypes(0, t -> {
299                             if (!already.contains(t)) {
300                                 builder.typedef(boundSchema, t);
301                                 already.add(t);
302                             }
303                         });
304                     } else {
305                         // new approach for supporting DeviceTypes
306                         Field schemaField = clazz.getDeclaredField("schema");
307                         schemaField.setAccessible(true);
308                         Object schema = schemaField.get(schemaField);
309 
310                         Class<?> deviceSchemaClass = Class.forName(DeviceSchema.class.getName());
311                         Method toTextMethod = deviceSchemaClass.getDeclaredMethod("toText");
312                         toTextMethod.setAccessible(true);
313                         String toText = (String) toTextMethod.invoke(schema);
314                         if (toText != null) {
315                             generateDeviceTypeStructs(builder, toText, typedefs);
316                         } else {
317                             throw new RuntimeException("[ERROR] Could not find valid device schema ");
318                         }
319                     }
320                 } catch (ReflectiveOperationException e) {
321                     throw new RuntimeException(e);
322                 }
323             }
324 
325             ScopedCodeBuilderContext buildContext =
326                     new ScopedCodeBuilderContext(kernelCallGraph.entrypoint.callGraph.computeContext.accelerator.lookup,
327                             kernelCallGraph.entrypoint.funcOp());
328 
329             // Sorting by rank ensures we don't need forward declarations
330             kernelCallGraph.getModuleOp().functionTable()
331                     .forEach((_, funcOp) -> {
332                         // TODO: did we just trash the callgraph sidetables?
333                         HATFinalDetectionPhase finals = new HATFinalDetectionPhase(kernelCallGraph.entrypoint.callGraph.computeContext.accelerator);
334                         finals.apply(funcOp);
335 
336                         // Update the build context for this method to use the right constants-map
337                         buildContext.setFinals(finals.getFinalVars());
338                         builder.nl().kernelMethod(buildContext, funcOp).nl();
339                     });
340 
341             // Update the constants-map for the main kernel
342             HATFinalDetectionPhase hatFinalDetectionPhase = new HATFinalDetectionPhase(kernelCallGraph.entrypoint.callGraph.computeContext.accelerator);
343             hatFinalDetectionPhase.apply(kernelCallGraph.entrypoint.funcOp());
344             buildContext.setFinals(hatFinalDetectionPhase.getFinalVars());
345             builder.nl().kernelEntrypoint(buildContext, args).nl();
346 
347             if (config().showKernelModel()) {
348                 IO.println("Original");
349                 IO.println(kernelCallGraph.entrypoint.funcOp().toText());
350             }
351             if (config().showLoweredKernelModel()) {
352                 IO.println("Lowered");
353                 IO.println(OpTk.lower(here, kernelCallGraph.entrypoint.funcOp()).toText());
354             }
355         }
356         return builder.toString();
357     }
358 
359 
360     private String sanitize(String s) {
361         String[] split1 = s.split("\\.");
362         if (split1.length == 1) {
363             return s;
364         }
365         s = split1[split1.length - 1];
366         if (s.split("\\$").length > 1) {
367             s = sanitize(s.split("\\$")[1]);
368         }
369         return s;
370     }
371 
372     @Override
373     public void preMutate(Buffer b) {
374         switch (b.getState()) {
375             case BufferState.NO_STATE:
376             case BufferState.NEW_STATE:
377             case BufferState.HOST_OWNED:
378             case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
379                 if (config().showState()) {
380                     System.out.println("in preMutate state = " + b.getStateString() + " no action to take");
381                 }
382                 break;
383             }
384             case BufferState.DEVICE_OWNED: {
385                 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
386                 if (config().showState()) {
387                     System.out.print("in preMutate state = " + b.getStateString() + " we pulled from device ");
388                 }
389                 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
390                 if (config().showState()) {
391                     System.out.println("and switched to " + b.getStateString());
392                 }
393                 break;
394             }
395             default:
396                 throw new IllegalStateException("Not expecting this state ");
397         }
398     }
399 
400     @Override
401     public void postMutate(Buffer b) {
402         if (config().showState()) {
403             System.out.print("in postMutate state = " + b.getStateString() + " no action to take ");
404         }
405         if (b.getState() != BufferState.NEW_STATE) {
406             b.setState(BufferState.HOST_OWNED);
407         }
408         if (config().showState()) {
409             System.out.println("and switched to (or stayed on) " + b.getStateString());
410         }
411     }
412 
413     @Override
414     public void preAccess(Buffer b) {
415         switch (b.getState()) {
416             case BufferState.NO_STATE:
417             case BufferState.NEW_STATE:
418             case BufferState.HOST_OWNED:
419             case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
420                 if (config().showState()) {
421                     System.out.println("in preAccess state = " + b.getStateString() + " no action to take");
422                 }
423                 break;
424             }
425             case BufferState.DEVICE_OWNED: {
426                 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
427 
428                 if (config().showState()) {
429                     System.out.print("in preAccess state = " + b.getStateString() + " we pulled from device ");
430                 }
431                 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
432                 if (config().showState()) {
433                     System.out.println("and switched to " + b.getStateString());
434                 }
435                 break;
436             }
437             default:
438                 throw new IllegalStateException("Not expecting this state ");
439         }
440     }
441 
442 
443     @Override
444     public void postAccess(Buffer b) {
445         if (config().showState()) {
446             System.out.println("in postAccess state = " + b.getStateString());
447         }
448     }
449 }