1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package hat.backend.ffi;
 27 
 28 import hat.NDRange;
 29 import hat.backend.codebuilders.C99HATKernelBuilder;
 30 import hat.buffer.ArgArray;
 31 import hat.buffer.Buffer;
 32 import hat.buffer.BufferTracker;
 33 import hat.buffer.KernelContext;
 34 import hat.callgraph.KernelCallGraph;
 35 import hat.ifacemapper.BoundSchema;
 36 import hat.ifacemapper.BufferState;
 37 import hat.ifacemapper.Schema;
 38 
 39 import java.util.Arrays;
 40 import java.util.HashMap;
 41 import java.util.LinkedHashSet;
 42 import java.util.Map;
 43 import java.util.Set;
 44 
 45 public abstract class C99FFIBackend extends FFIBackend  implements BufferTracker {
 46 
 47 
 48     public C99FFIBackend(String libName, Config config) {
 49         super(libName, config);
 50     }
 51 
 52     public static class CompiledKernel {
 53         public final C99FFIBackend c99FFIBackend;
 54         public final KernelCallGraph kernelCallGraph;
 55         public final BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge;
 56         public final ArgArray argArray;
 57         public final KernelContext kernelContext;
 58 
 59         public CompiledKernel(C99FFIBackend c99FFIBackend, KernelCallGraph kernelCallGraph, BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge, Object[] ndRangeAndArgs) {
 60             this.c99FFIBackend = c99FFIBackend;
 61             this.kernelCallGraph = kernelCallGraph;
 62             this.kernelBridge = kernelBridge;
 63             this.kernelContext = KernelContext.create(kernelCallGraph.computeContext.accelerator, 0, 0);
 64             ndRangeAndArgs[0] = this.kernelContext;
 65             this.argArray = ArgArray.create(kernelCallGraph.computeContext.accelerator,kernelCallGraph,  ndRangeAndArgs);
 66         }
 67 
 68         public void dispatch(NDRange ndRange, Object[] args) {
 69           //  long ns = System.nanoTime();
 70             kernelContext.maxX(ndRange.kid.maxX);
 71             args[0] = this.kernelContext;
 72             ArgArray.update(argArray,kernelCallGraph, args);
 73             //System.out.println("argupdate  "+((System.nanoTime()-ns)/1000)+" us");
 74            // ns = System.nanoTime();
 75             kernelBridge.ndRange(this.argArray);
 76            // System.out.println("dispatch time "+((System.nanoTime()-ns)/1000)+" us");
 77         }
 78     }
 79 
 80     public Map<KernelCallGraph, CompiledKernel> kernelCallGraphCompiledCodeMap = new HashMap<>();
 81 
 82     public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kernelCallGraph, T builder, Object[] args, boolean show) {
 83         builder.defines().pragmas().types();
 84         Set<Schema.IfaceType> already = new LinkedHashSet<>();
 85         Arrays.stream(args)
 86                 .filter(arg -> arg instanceof Buffer)
 87                 .map(arg -> (Buffer) arg)
 88                 .forEach(ifaceBuffer -> {
 89                     BoundSchema<?> boundSchema = Buffer.getBoundSchema(ifaceBuffer);
 90                     boundSchema.schema().rootIfaceType.visitTypes(0, t -> {
 91                         if (!already.contains(t)) {
 92                             builder.typedef(boundSchema, t);
 93                             already.add(t);
 94                         }
 95                     });
 96                 });
 97 
 98         // Sorting by rank ensures we don't need forward declarations
 99         kernelCallGraph.kernelReachableResolvedStream().sorted((lhs, rhs) -> rhs.rank - lhs.rank)
100                 .forEach(kernelReachableResolvedMethod -> builder.nl().kernelMethod(kernelReachableResolvedMethod).nl());
101 
102         builder.nl().kernelEntrypoint(kernelCallGraph.entrypoint, args).nl();
103 
104         if (show) {
105             System.out.println("Original");
106             System.out.println(kernelCallGraph.entrypoint.funcOpWrapper().op().toText());
107             System.out.println("Lowered");
108             System.out.println(kernelCallGraph.entrypoint.funcOpWrapper().lower().op().toText());
109         }
110         return builder.toString();
111     }
112 
113 
114     @Override
115     public void preMutate(Buffer b) {
116         switch (b.getState()) {
117             case BufferState.NO_STATE:
118             case BufferState.NEW_STATE:
119             case BufferState.HOST_OWNED:
120             case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
121                 if (config.isSHOW_STATE()) {
122                     System.out.println("in preMutate state = " + b.getStateString() + " no action to take");
123                 }
124                 break;
125             }
126             case BufferState.DEVICE_OWNED: {
127                 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
128                 if (config.isSHOW_STATE()) {
129                     System.out.print("in preMutate state = " + b.getStateString() + " we pulled from device ");
130                 }
131                 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
132                 if (config.isSHOW_STATE()) {
133                     System.out.println("and switched to " + b.getStateString());
134                 }
135                 break;
136             }
137             default:
138                 throw new IllegalStateException("Not expecting this state ");
139         }
140     }
141 
142     @Override
143     public void postMutate(Buffer b) {
144         if (config.isSHOW_STATE()) {
145             System.out.print("in postMutate state = " + b.getStateString() + " no action to take ");
146         }
147         if (b.getState() != BufferState.NEW_STATE) {
148             b.setState(BufferState.HOST_OWNED);
149         }
150         if (config.isSHOW_STATE()) {
151             System.out.println("and switched to (or stayed on) " + b.getStateString());
152         }
153     }
154 
155     @Override
156     public void preAccess(Buffer b) {
157         switch (b.getState()) {
158             case BufferState.NO_STATE:
159             case BufferState.NEW_STATE:
160             case BufferState.HOST_OWNED:
161             case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
162                 if (config.isSHOW_STATE()) {
163                     System.out.println("in preAccess state = " + b.getStateString() + " no action to take");
164                 }
165                 break;
166             }
167             case BufferState.DEVICE_OWNED: {
168                 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
169 
170                 if (config.isSHOW_STATE()) {
171                     System.out.print("in preAccess state = " + b.getStateString() + " we pulled from device ");
172                 }
173                 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
174                 if (config.isSHOW_STATE()) {
175                     System.out.println("and switched to " + b.getStateString());
176                 }
177                 break;
178             }
179             default:
180                 throw new IllegalStateException("Not expecting this state ");
181         }
182     }
183 
184 
185     @Override
186     public void postAccess(Buffer b) {
187         if (config.isSHOW_STATE()) {
188             System.out.println("in postAccess state = " + b.getStateString());
189         }
190     }
191 }