1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package hat.backend.ffi;
27
28 import hat.Config;
29 import hat.KernelContext;
30 import hat.NDRange;
31 import hat.annotations.Kernel;
32 import hat.annotations.Preformatted;
33 import hat.annotations.TypeDef;
34 import hat.buffer.ArgArray;
35 import hat.buffer.KernelBufferContext;
36 import hat.callgraph.IfaceDataDag;
37 import hat.callgraph.KernelCallGraph;
38 import hat.callgraph.MethodCallDag;
39 import hat.codebuilders.C99HATKernelBuilder;
40 import hat.codebuilders.C99VecAndMatHandler;
41 import hat.device.DeviceSchema;
42 import hat.device.NonMappableIface;
43 import hat.types.BF16;
44 import hat.types.F16;
45 import jdk.incubator.code.CodeTransformer;
46 import optkl.ifacemapper.BoundSchema;
47 import optkl.ifacemapper.Buffer;
48 import optkl.ifacemapper.BufferState;
49 import optkl.ifacemapper.BufferTracker;
50 import optkl.ifacemapper.MappableIface;
51 import optkl.ifacemapper.Schema;
52
53 import java.lang.foreign.Arena;
54 import java.lang.invoke.MethodHandles;
55 import java.util.Arrays;
56 import java.util.HashMap;
57 import java.util.HashSet;
58 import java.util.List;
59 import java.util.Map;
60 import java.util.Set;
61
62 public abstract class C99FFIBackend extends FFIBackend implements BufferTracker {
/**
 * Creates a C99 FFI backend.
 *
 * <p>All four arguments are forwarded unchanged to the {@link FFIBackend} constructor;
 * this constructor does nothing else.
 *
 * @param arena   passed through to the superclass
 * @param lookup  passed through to the superclass
 * @param libName passed through to the superclass
 * @param config  passed through to the superclass
 */
public C99FFIBackend(
        Arena arena,
        MethodHandles.Lookup lookup,
        String libName,
        Config config) {
    super(arena, lookup, libName, config);
}
66
67 public static class CompiledKernel {
68 public final C99FFIBackend c99FFIBackend;
69 public final KernelCallGraph kernelCallGraph;
70 public final BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge;
71 public final ArgArray argArray;
72 public final KernelBufferContext kernelBufferContext;
73
74 public CompiledKernel(C99FFIBackend c99FFIBackend, KernelCallGraph kernelCallGraph, BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge, Object[] ndRangeAndArgs) {
75 this.c99FFIBackend = c99FFIBackend;
76 this.kernelCallGraph = kernelCallGraph;
77 this.kernelBridge = kernelBridge;
78 this.kernelBufferContext = KernelBufferContext.createDefault(kernelCallGraph.computeCallGraph.computeContext.accelerator());
79 ndRangeAndArgs[0] = this.kernelBufferContext;
80 this.argArray = ArgArray.create(kernelCallGraph.computeCallGraph.computeContext.accelerator(), kernelCallGraph, ndRangeAndArgs);
81 }
82
83 public void dispatch(KernelContext kernelContext, Object[] args) {
84 // Do we really need this? We never actually read these
85 kernelBufferContext.gsy(1);
86 kernelBufferContext.gsz(1);
87 switch (kernelContext.ndRange.global()) {
88 case NDRange.Global1D global1D -> {
89 kernelBufferContext.gsx(global1D.x());
90 kernelBufferContext.dimensions(global1D.dimension());
91 }
92 case NDRange.Global2D global2D -> {
93 kernelBufferContext.gsx(global2D.x());
94 kernelBufferContext.gsy(global2D.y());
95 kernelBufferContext.dimensions(global2D.dimension());
96 }
97 case NDRange.Global3D global3D -> {
98 kernelBufferContext.gsx(global3D.x());
99 kernelBufferContext.gsy(global3D.y());
100 kernelBufferContext.gsz(global3D.z());
101 kernelBufferContext.dimensions(global3D.dimension());
102 }
103 case null, default -> throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.global().getClass());
104 }
105
106 if (kernelContext.ndRange.hasLocal()) {
107 kernelBufferContext.lsy(1);
108 kernelBufferContext.lsz(1);
109 switch (kernelContext.ndRange.local()) {
110 case NDRange.Local1D local1D -> {
111 kernelBufferContext.lsx(local1D.x());
112 kernelBufferContext.dimensions(local1D.dimension());
113 }
114 case NDRange.Local2D local2D -> {
115 kernelBufferContext.lsx(local2D.x());
116 kernelBufferContext.lsy(local2D.y());
117 kernelBufferContext.dimensions(local2D.dimension());
118 }
119 case NDRange.Local3D local3D -> {
120 kernelBufferContext.lsx(local3D.x());
121 kernelBufferContext.lsy(local3D.y());
122 kernelBufferContext.lsz(local3D.z());
123 kernelBufferContext.dimensions(local3D.dimension());
124 }
125 case null, default -> throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.local().getClass());
126 }
127 } else {
128 kernelBufferContext.lsx(0);
129 kernelBufferContext.lsy(0);
130 kernelBufferContext.lsz(0);
131 }
132
133 args[0] = this.kernelBufferContext;
134 ArgArray.update(argArray, kernelCallGraph, args);
135 kernelBridge.ndRange(this.argArray);
136 }
137 }
138
139 public Map<KernelCallGraph, CompiledKernel> kernelCallGraphCompiledCodeMap = new HashMap<>();
140
141
/**
 * Builds the kernel source text for {@code kernelCallGraph} using {@code builder}.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>emit the builder's defines and types;</li>
 *   <li>emit one typedef per distinct iface type reachable from each {@link Buffer} in {@code args};</li>
 *   <li>if the entry-point method carries a {@link Kernel} annotation, emit its preformatted
 *       text (plus optional {@link TypeDef} / {@link Preformatted} text); otherwise generate
 *       typedefs, helper methods and the kernel entry point from the call graph.</li>
 * </ol>
 *
 * @param kernelCallGraph call graph of the kernel; some of its flags are mutated when a
 *                        {@link Kernel} annotation is present
 * @param builder         code builder that accumulates the output
 * @param args            kernel arguments; only those that are {@link Buffer}s are inspected
 * @param <T>             concrete builder type
 * @return {@code builder.toString()} after all code has been appended
 */
public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kernelCallGraph, T builder, Object... args) {
    builder.defines().types();

    // Shared across all arguments so a type referenced by several buffers is typedef'd once.
    var emittedIfaceTypes = new HashSet<Schema.IfaceType>();
    for (Object arg : args) {
        if (arg instanceof Buffer ifaceBuffer) {
            BoundSchema<?> boundSchema = MappableIface.getBoundSchema(ifaceBuffer);
            boundSchema.schema().rootIfaceType.visitUniqueTypes(ifaceType -> {
                if (emittedIfaceTypes.add(ifaceType)) { // add() is true only on first sight
                    builder.typedef(boundSchema, ifaceType);
                }
            });
        }
    }

    var kernelAnnotation = kernelCallGraph.callDag.entryPoint.method().getAnnotation(Kernel.class);
    if (kernelAnnotation == null) {
        emitGeneratedKernelSource(kernelCallGraph, builder);
    } else {
        emitPreformattedKernelSource(kernelCallGraph, builder, kernelAnnotation);
    }
    return builder.toString();
}

/** Emits the hand-written kernel text carried by the entry point's annotations. */
private <T extends C99HATKernelBuilder<T>> void emitPreformattedKernelSource(KernelCallGraph kernelCallGraph, T builder, Kernel kernelAnnotation) {
    // With a @Kernel annotation the call graph's own analysis cannot be trusted,
    // so every optional feature is switched on.
    kernelCallGraph.setUsesAtomics(true);
    kernelCallGraph.accessedFP16Classes.addAll(List.of(F16.class, BF16.class));
    kernelCallGraph.setUsesBarrier(true);

    var entryMethod = kernelCallGraph.callDag.entryPoint.method();

    var typedefAnnotation = entryMethod.getAnnotation(TypeDef.class);
    if (typedefAnnotation != null) {
        builder.lineComment("Preformatted typedef body from @Typedef annotation");
        builder.typedefStruct(typedefAnnotation.name(), _ -> builder.preformatted(typedefAnnotation.body())).semicolon().nl();
    }

    var preformattedAnnotation = entryMethod.getAnnotation(Preformatted.class);
    if (preformattedAnnotation != null) {
        builder.lineComment("Preformatted text from @Preformatted annotation");
        builder.preformatted(preformattedAnnotation.value());
    }

    builder.lineComment("Preformatted code body from @Kernel annotation");
    builder.preformatted(kernelAnnotation.value());
}

/** Emits typedefs, vec helpers, helper methods and the entry point derived from the call graph. */
private <T extends C99HATKernelBuilder<T>> void emitGeneratedKernelSource(KernelCallGraph kernelCallGraph, T builder) {
    // F16 and BF16 are pre-seeded so they are never typedef'd here.
    var typedeffed = new HashSet<Class<?>>();
    typedeffed.add(F16.class);
    typedeffed.add(BF16.class);

    kernelCallGraph.accessedNonMappableIfaceClasses.stream()
            .filter(candidate -> !typedeffed.contains(candidate))
            .map(candidate -> (Class<NonMappableIface>) candidate) // why do we need to do this.
            .forEach(ifaceClass -> {
                // Dag of iface references rooted at ifaceClass.
                var ifaceDataDag = new IfaceDataDag<NonMappableIface>(dag -> {
                    var root = dag.getNode(ifaceClass);
                    dag.methodsWithIfaceReturnTypes(ifaceClass).forEach(ifaceInfo ->
                            dag.addEdge(root, dag.getNode(ifaceInfo.clazz())) // this recurses with each added class
                    );
                });
                if (!ifaceDataDag.isDag()) {
                    typedeffed.add(DeviceSchema.getDeviceSchemaOrThrow(ifaceClass).typedef(builder).clazz());
                    return;
                }
                // Rank order puts inner typedefs ahead of the types that reference them.
                ifaceDataDag.rankOrdered.stream()
                        .filter(ifaceInfo -> !typedeffed.contains(ifaceInfo.clazz()))
                        .forEach(ifaceInfo -> typedeffed.add(
                                DeviceSchema.getDeviceSchemaOrThrow(ifaceInfo.clazz()).typedef(builder).clazz()));
            });

    // This is a slight hack for Shader support.
    if (!kernelCallGraph.accessedVecClasses.isEmpty()) {
        C99VecAndMatHandler.createVecFunctions(builder);
    }

    kernelCallGraph.callDag.rankOrdered.stream()
            .filter(call -> call instanceof MethodCallDag.OtherMethodCall)
            .forEach(call -> builder.nl().kernelMethod(call.funcOp()).nl());

    builder.nl().kernelEntrypoint().nl();

    if (config().showKernelModel()) {
        IO.println("Non Lowered");
        IO.println(kernelCallGraph.callDag.entryPoint.funcOp().toText());
    }
    if (config().showLoweredKernelModel()) {
        IO.println("Lowered");
        IO.println(kernelCallGraph.callDag.entryPoint.funcOp().transform(CodeTransformer.LOWERING_TRANSFORMER).toText());
    }
}
228
229 @Override
230 public void preMutate(MappableIface b) {
231 switch (b.getState()) {
232 case BufferState.NO_STATE:
233 case BufferState.NEW_STATE:
234 case BufferState.HOST_OWNED:
235 case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
236 if (config().showState()) {
237 System.out.println("in preMutate state = " + b.getStateString() + " no action to take");
238 }
239 break;
240 }
241 case BufferState.DEVICE_OWNED: {
242 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
243 if (config().showState()) {
244 System.out.print("in preMutate state = " + b.getStateString() + " we pulled from device ");
245 }
246 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
247 if (config().showState()) {
248 System.out.println("and switched to " + b.getStateString());
249 }
250 break;
251 }
252 default:
253 throw new IllegalStateException("Not expecting this state ");
254 }
255 }
256
257 @Override
258 public void postMutate(MappableIface b) {
259 if (config().showState()) {
260 System.out.print("in postMutate state = " + b.getStateString() + " no action to take ");
261 }
262 if (b.getState() != BufferState.NEW_STATE) {
263 b.setState(BufferState.HOST_OWNED);
264 }
265 if (config().showState()) {
266 System.out.println("and switched to (or stayed on) " + b.getStateString());
267 }
268 }
269
270 @Override
271 public void preAccess(MappableIface b) {
272 switch (b.getState()) {
273 case BufferState.NO_STATE:
274 case BufferState.NEW_STATE:
275 case BufferState.HOST_OWNED:
276 case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
277 if (config().showState()) {
278 System.out.println("in preAccess state = " + b.getStateString() + " no action to take");
279 }
280 break;
281 }
282 case BufferState.DEVICE_OWNED: {
283 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
284
285 if (config().showState()) {
286 System.out.print("in preAccess state = " + b.getStateString() + " we pulled from device ");
287 }
288 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
289 if (config().showState()) {
290 System.out.println("and switched to " + b.getStateString());
291 }
292 break;
293 }
294 default:
295 throw new IllegalStateException("Not expecting this state ");
296 }
297 }
298
299
300 @Override
301 public void postAccess(MappableIface b) {
302 if (config().showState()) {
303 System.out.println("in postAccess state = " + b.getStateString());
304 }
305 }
306 }