1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package hat.backend.ffi;
27
28 import hat.NDRange;
29 import hat.Config;
30 import hat.KernelContext;
31 import hat.types.BF16;
32 import hat.types.F16;
33 import jdk.incubator.code.CodeTransformer;
34 import hat.annotations.Kernel;
35 import hat.annotations.Preformatted;
36 import hat.annotations.TypeDef;
37 import hat.buffer.*;
38 import hat.codebuilders.C99HATKernelBuilder;
39 import hat.callgraph.KernelCallGraph;
40 import optkl.codebuilders.ScopedCodeBuilderContext;
41 import hat.device.DeviceSchema;
42 import hat.dialect.HATMemoryVarOp;
43 import optkl.ifacemapper.BoundSchema;
44 import optkl.ifacemapper.Buffer;
45 import optkl.ifacemapper.BufferState;
46 import optkl.ifacemapper.BufferTracker;
47 import optkl.ifacemapper.MappableIface;
48 import optkl.ifacemapper.Schema;
49 import hat.phases.HATFinalDetector;
50 import jdk.incubator.code.TypeElement;
51 import jdk.incubator.code.dialect.java.ClassType;
52
53 import java.lang.foreign.Arena;
54 import java.lang.invoke.MethodHandles;
55 import java.lang.reflect.Field;
56 import java.util.ArrayList;
57 import java.util.Arrays;
58 import java.util.HashMap;
59 import java.util.HashSet;
60 import java.util.LinkedHashSet;
61 import java.util.List;
62 import java.util.Map;
63 import java.util.Objects;
64 import java.util.Set;
65
66 public abstract class C99FFIBackend extends FFIBackend implements BufferTracker {
67 public C99FFIBackend(Arena arena, MethodHandles.Lookup lookup,String libName, Config config) {
68 super(arena,lookup,libName, config);
69 }
70 public static class CompiledKernel {
71 public final C99FFIBackend c99FFIBackend;
72 public final KernelCallGraph kernelCallGraph;
73 public final BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge;
74 public final ArgArray argArray;
75 public final KernelBufferContext kernelBufferContext;
76
77 public CompiledKernel(C99FFIBackend c99FFIBackend, KernelCallGraph kernelCallGraph, BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge, Object[] ndRangeAndArgs) {
78 this.c99FFIBackend = c99FFIBackend;
79 this.kernelCallGraph = kernelCallGraph;
80 this.kernelBridge = kernelBridge;
81 this.kernelBufferContext = KernelBufferContext.createDefault(kernelCallGraph.computeContext.accelerator());
82 ndRangeAndArgs[0] = this.kernelBufferContext;
83 this.argArray = ArgArray.create(kernelCallGraph.computeContext.accelerator(),kernelCallGraph, ndRangeAndArgs);
84 }
85
86 public void dispatch(KernelContext kernelContext, Object[] args) {
87 kernelBufferContext.gsy(1);
88 kernelBufferContext.gsz(1);
89 switch (kernelContext.ndRange.global()) {
90 case NDRange.Global1D global1D -> {
91 kernelBufferContext.gsx(global1D.x());
92 kernelBufferContext.dimensions(global1D.dimension());
93 }
94 case NDRange.Global2D global2D -> {
95 kernelBufferContext.gsx(global2D.x());
96 kernelBufferContext.gsy(global2D.y());
97 kernelBufferContext.dimensions(global2D.dimension());
98 }
99 case NDRange.Global3D global3D -> {
100 kernelBufferContext.gsx(global3D.x());
101 kernelBufferContext.gsy(global3D.y());
102 kernelBufferContext.gsz(global3D.z());
103 kernelBufferContext.dimensions(global3D.dimension());
104 }
105 case null, default -> {
106 throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.global().getClass());
107 }
108 }
109 if (kernelContext.ndRange.hasLocal()) {
110 kernelBufferContext.lsy(1);
111 kernelBufferContext.lsz(1);
112 switch (kernelContext.ndRange.local()) {
113 case NDRange.Local1D local1D -> {
114 kernelBufferContext.lsx(local1D.x());
115 kernelBufferContext.dimensions(local1D.dimension());
116 }
117 case NDRange.Local2D local2D -> {
118 kernelBufferContext.lsx(local2D.x());
119 kernelBufferContext.lsy(local2D.y());
120 kernelBufferContext.dimensions(local2D.dimension());
121 }
122 case NDRange.Local3D local3D -> {
123 kernelBufferContext.lsx(local3D.x());
124 kernelBufferContext.lsy(local3D.y());
125 kernelBufferContext.lsz(local3D.z());
126 kernelBufferContext.dimensions(local3D.dimension());
127 }
128 case null, default -> {
129 throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.local().getClass());
130 }
131 }
132 } else {
133 kernelBufferContext.lsx(0);
134 kernelBufferContext.lsy(0);
135 kernelBufferContext.lsz(0);
136 }
137 args[0] = this.kernelBufferContext;
138 ArgArray.update(argArray, kernelCallGraph, args);
139 kernelBridge.ndRange(this.argArray);
140 }
141 }
142
143 public Map<KernelCallGraph, CompiledKernel> kernelCallGraphCompiledCodeMap = new HashMap<>();
144
145 private <T extends C99HATKernelBuilder<T>> void generateDeviceTypeStructs(T builder, String toText, Set<String> typedefs) {
146 // From here is text processing
147 String[] split = toText.split(">");
148 // Each item is a data struct
149 for (String s : split) {
150 // curate: remove first character
151 s = s.substring(1);
152 String dsName = s.split(":")[0];
153 if (typedefs.contains(dsName)) {
154 continue;
155 }
156 typedefs.add(dsName);
157 // sanitize dsName
158 dsName = sanitize(dsName);
159 builder.typedefKeyword()
160 .space()
161 .structKeyword()
162 .space()
163 .suffix_s(dsName)
164 .obrace()
165 .nl();
166
167 String[] members = s.split(";");
168
169 int j = 0;
170 builder.in();
171 for (int i = 0; i < members.length; i++) {
172 String member = members[i];
173 String[] field = member.split(":");
174 if (i == 0) {
175 j = 1;
176 }
177 String isArray = field[j++];
178 String type = field[j++];
179 String name = field[j++];
180 String lenValue = "";
181 if (isArray.equals("[")) {
182 lenValue = field[j];
183 }
184 j = 0;
185 if (typedefs.contains(type))
186 type = sanitize(type) + "_t";
187 else
188 type = sanitize(type);
189
190 builder.typeName(type)
191 .space()
192 .identifier(name);
193
194 if (isArray.equals("[")) {
195 builder.space()
196 .osbrace()
197 .identifier(lenValue)
198 .csbrace();
199 }
200 builder.semicolon().nl();
201 }
202 builder.out();
203 builder.cbrace().suffix_t(dsName).semicolon().nl().nl();
204 }
205 }
206
207 public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kernelCallGraph, T builder, Object... args) {
208 builder.defines().types();
209 var visitedAlready= new HashSet<Schema.IfaceType>();
210 Arrays.stream(args)
211 .filter(arg -> arg instanceof Buffer)
212 .map(arg -> (Buffer) arg)
213 .forEach(ifaceBuffer -> {
214 BoundSchema<?> boundSchema = MappableIface.getBoundSchema(ifaceBuffer);
215 boundSchema.schema().rootIfaceType.visitUniqueTypes( t -> {
216 if (visitedAlready.add(t)) { // true first time we see this type
217 builder.typedef(boundSchema, t);
218 }
219 });
220 });
221
222 var annotation = kernelCallGraph.entrypoint.method().getAnnotation(Kernel.class);
223 if (annotation!=null){
224 var typedef = kernelCallGraph.entrypoint.method().getAnnotation(TypeDef.class);
225 if (typedef!=null){
226 builder.lineComment("Preformatted typedef body from @Typedef annotation");
227 builder.typedefKeyword().space().structKeyword().space().suffix_s(typedef.name()).braceNlIndented(_->
228 builder.preformatted(typedef.body())
229 ).suffix_t(typedef.name()).semicolon().nl();
230 }
231 var preformatted = kernelCallGraph.entrypoint.method().getAnnotation(Preformatted.class);
232 if (preformatted!=null){
233 builder.lineComment("Preformatted text from @Preformatted annotation");
234 builder.preformatted(preformatted.value());
235 }
236 builder.lineComment("Preformatted code body from @Kernel annotation");
237 builder.preformatted(annotation.value());
238 } else {
239 Set<String> typedefs = new HashSet<>();
240
241 // Add HAT reserved types
242 typedefs.add(F16.class.getName());
243 typedefs.add(BF16.class.getName());
244
245 /*
246 I think the kernelCallGraph module op was built before we inserted HATMemoryVarOps
247
248 So we will likely never get any matches from the module op
249
250 List<ClassType> localIFaceList = new ArrayList<>();
251 kernelCallGraph.getModuleOp()
252 .elements()
253 .filter(c -> Objects.requireNonNull(c) instanceof HATMemoryVarOp)
254 .map(c -> (ClassType)((HATMemoryVarOp) c).invokeType())
255 .forEach(localIFaceList::add);
256
257
258
259 However,the sentiment from above was correct as we may have kernel reachable methods that do indeed
260 have these HATMemoryVarOps. I think if we called a method from the entrypoint with Device type accesses
261 we would miss them
262 */
263
264 // Dynamically build the schema for the user data type we are creating within the kernel.
265 // This is because no allocation was done from the host. This is kernel code, and it is reflected
266 // using the code reflection API
267 // 1. Add for struct for iface objects
268 kernelCallGraph.entrypoint.funcOp()
269 .elements()
270 .filter(c -> Objects.requireNonNull(c) instanceof HATMemoryVarOp)
271 .map(c -> (ClassType)((HATMemoryVarOp) c).invokeType())
272 .forEach( classType-> {
273 try {
274 Class<?> clazz = (Class<?>) classType.resolve(kernelCallGraph.lookup());
275 Field schemaField = clazz.getDeclaredField("schema");
276 schemaField.setAccessible(true);
277 var schema = (DeviceSchema<?>)schemaField.get(schemaField);
278 // <1> We are creating text form of DeviceType schema
279 String toText = schema.toText();
280 if (toText != null) {
281 // <2> just to then parse the text from above.
282 // Lets get the model in a cleaner form
283 generateDeviceTypeStructs(builder, toText, typedefs);
284 } else {
285 throw new RuntimeException("[ERROR] Could not find valid device schema ");
286 }
287 } catch (ReflectiveOperationException e) {
288 throw new RuntimeException(e);
289 }
290 });
291
292 var buildContext = new ScopedCodeBuilderContext(kernelCallGraph.lookup(), kernelCallGraph.entrypoint.funcOp());
293
294 kernelCallGraph.getModuleOp().functionTable()
295 .forEach((_, funcOp) -> {
296 // TODO: did we just trash the callgraph sidetables?
297 // Why are we transforming the callgraph here
298 HATFinalDetector finals = new HATFinalDetector(kernelCallGraph);
299 // Update the build context for this method to use the right constants-map
300 buildContext.setFinals(finals.applied(funcOp));
301 builder.nl().kernelMethod(buildContext, funcOp).nl();
302 });
303
304 // Update the constants-map for the main kernel
305 // Why are we doing this here we should not be mutating the kernel callgraph at this point
306 HATFinalDetector hatFinalDetector = new HATFinalDetector(kernelCallGraph);
307 buildContext.setFinals(hatFinalDetector.applied(kernelCallGraph.entrypoint.funcOp()));
308
309 builder.nl().kernelEntrypoint(buildContext).nl();
310
311 if (config().showKernelModel()) {
312 IO.println("Non Lowered");
313 IO.println(kernelCallGraph.entrypoint.funcOp().toText());
314 }
315 if (config().showLoweredKernelModel()) {
316 IO.println("Lowered");
317 IO.println(kernelCallGraph.entrypoint.funcOp().transform(CodeTransformer.LOWERING_TRANSFORMER).toText());
318 }
319 }
320 return builder.toString();
321 }
322
323
324 private String sanitize(String s) {
325 String[] split1 = s.split("\\.");
326 if (split1.length == 1) {
327 return s;
328 }
329 s = split1[split1.length - 1];
330 if (s.split("\\$").length > 1) {
331 s = sanitize(s.split("\\$")[1]);
332 }
333 return s;
334 }
335
336 @Override
337 public void preMutate(MappableIface b) {
338 switch (b.getState()) {
339 case BufferState.NO_STATE:
340 case BufferState.NEW_STATE:
341 case BufferState.HOST_OWNED:
342 case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
343 if (config().showState()) {
344 System.out.println("in preMutate state = " + b.getStateString() + " no action to take");
345 }
346 break;
347 }
348 case BufferState.DEVICE_OWNED: {
349 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
350 if (config().showState()) {
351 System.out.print("in preMutate state = " + b.getStateString() + " we pulled from device ");
352 }
353 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
354 if (config().showState()) {
355 System.out.println("and switched to " + b.getStateString());
356 }
357 break;
358 }
359 default:
360 throw new IllegalStateException("Not expecting this state ");
361 }
362 }
363
364 @Override
365 public void postMutate(MappableIface b) {
366 if (config().showState()) {
367 System.out.print("in postMutate state = " + b.getStateString() + " no action to take ");
368 }
369 if (b.getState() != BufferState.NEW_STATE) {
370 b.setState(BufferState.HOST_OWNED);
371 }
372 if (config().showState()) {
373 System.out.println("and switched to (or stayed on) " + b.getStateString());
374 }
375 }
376
377 @Override
378 public void preAccess(MappableIface b) {
379 switch (b.getState()) {
380 case BufferState.NO_STATE:
381 case BufferState.NEW_STATE:
382 case BufferState.HOST_OWNED:
383 case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
384 if (config().showState()) {
385 System.out.println("in preAccess state = " + b.getStateString() + " no action to take");
386 }
387 break;
388 }
389 case BufferState.DEVICE_OWNED: {
390 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
391
392 if (config().showState()) {
393 System.out.print("in preAccess state = " + b.getStateString() + " we pulled from device ");
394 }
395 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
396 if (config().showState()) {
397 System.out.println("and switched to " + b.getStateString());
398 }
399 break;
400 }
401 default:
402 throw new IllegalStateException("Not expecting this state ");
403 }
404 }
405
406
407 @Override
408 public void postAccess(MappableIface b) {
409 if (config().showState()) {
410 System.out.println("in postAccess state = " + b.getStateString());
411 }
412 }
413 }