1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package hat.backend.ffi;
27
28 import hat.NDRange;
29 import hat.Config;
30 import hat.KernelContext;
31 import hat.annotations.Kernel;
32 import hat.annotations.Preformatted;
33 import hat.annotations.TypeDef;
34 import hat.buffer.F16;
35 import hat.buffer.KernelBufferContext;
36 import hat.codebuilders.C99HATKernelBuilder;
37 import hat.buffer.ArgArray;
38 import hat.buffer.Buffer;
39 import hat.buffer.BufferTracker;
40 import hat.callgraph.KernelCallGraph;
41 import hat.codebuilders.ScopedCodeBuilderContext;
42 import hat.device.DeviceSchema;
43 import hat.dialect.HATMemoryOp;
44 import hat.ifacemapper.BoundSchema;
45 import hat.ifacemapper.BufferState;
46 import hat.ifacemapper.Schema;
47 import hat.optools.OpTk;
48 import hat.phases.HATFinalDetectionPhase;
49 import jdk.incubator.code.TypeElement;
50 import jdk.incubator.code.dialect.java.ClassType;
51
52 import java.lang.reflect.Field;
53 import java.lang.reflect.Method;
54 import java.util.ArrayList;
55 import java.util.Arrays;
56 import java.util.HashMap;
57 import java.util.HashSet;
58 import java.util.LinkedHashSet;
59 import java.util.List;
60 import java.util.Map;
61 import java.util.Objects;
62 import java.util.Set;
63
64 public abstract class C99FFIBackend extends FFIBackend implements BufferTracker {
65
/**
 * Creates a C99-emitting FFI backend.
 *
 * <p>Both arguments are forwarded unchanged to the {@link FFIBackend} constructor.
 *
 * @param libName passed through to the superclass (presumably the native library name — confirm against {@code FFIBackend})
 * @param config  passed through to the superclass; later read via {@code config()} for the tracing flags used in this class
 */
public C99FFIBackend(String libName, Config config) {
    super(libName, config);
}
69
70 public static class CompiledKernel {
71 public final C99FFIBackend c99FFIBackend;
72 public final KernelCallGraph kernelCallGraph;
73 public final BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge;
74 public final ArgArray argArray;
75 public final KernelBufferContext kernelBufferContext;
76
77 public CompiledKernel(C99FFIBackend c99FFIBackend, KernelCallGraph kernelCallGraph, BackendBridge.CompilationUnitBridge.KernelBridge kernelBridge, Object[] ndRangeAndArgs) {
78 this.c99FFIBackend = c99FFIBackend;
79 this.kernelCallGraph = kernelCallGraph;
80 this.kernelBridge = kernelBridge;
81 this.kernelBufferContext = KernelBufferContext.createDefault(kernelCallGraph.computeContext.accelerator);
82 ndRangeAndArgs[0] = this.kernelBufferContext;
83 this.argArray = ArgArray.create(kernelCallGraph.computeContext.accelerator,kernelCallGraph, ndRangeAndArgs);
84 }
85
86 private void setGlobalMesh(NDRange.Global global) {
87 kernelBufferContext.gsy(1);
88 kernelBufferContext.gsz(1);
89 switch (global) {
90 case NDRange.Global1D global1D -> {
91 kernelBufferContext.gsx(global1D.x());
92 kernelBufferContext.dimensions(global1D.dimension());
93 }
94 case NDRange.Global2D global2D -> {
95 kernelBufferContext.gsx(global2D.x());
96 kernelBufferContext.gsy(global2D.y());
97 kernelBufferContext.dimensions(global2D.dimension());
98 }
99 case NDRange.Global3D global3D -> {
100 kernelBufferContext.gsx(global3D.x());
101 kernelBufferContext.gsy(global3D.y());
102 kernelBufferContext.gsz(global3D.z());
103 kernelBufferContext.dimensions(global3D.dimension());
104 }
105 case null, default -> {
106 throw new IllegalArgumentException("Unknown global range " + global.getClass());
107 }
108 }
109 }
110
111 private void setLocalMesh(NDRange.Local local) {
112 kernelBufferContext.lsy(1);
113 kernelBufferContext.lsz(1);
114 switch (local) {
115 case NDRange.Local1D local1D -> {
116 kernelBufferContext.lsx(local1D.x());
117 kernelBufferContext.dimensions(local1D.dimension());
118 }
119 case NDRange.Local2D local2D -> {
120 kernelBufferContext.lsx(local2D.x());
121 kernelBufferContext.lsy(local2D.y());
122 kernelBufferContext.dimensions(local2D.dimension());
123 }
124 case NDRange.Local3D local3D -> {
125 kernelBufferContext.lsx(local3D.x());
126 kernelBufferContext.lsy(local3D.y());
127 kernelBufferContext.lsz(local3D.z());
128 kernelBufferContext.dimensions(local3D.dimension());
129 }
130 case null, default -> {
131 throw new IllegalArgumentException("Unknown global range " + local.getClass());
132 }
133 }
134 }
135
136 private void setDefaultLocalMesh() {
137 kernelBufferContext.lsx(0);
138 kernelBufferContext.lsy(0);
139 kernelBufferContext.lsz(0);
140 }
141
142 private void setupComputeRange(KernelContext kernelContext) {
143 NDRange ndRange = kernelContext.getNDRange();
144 if (!(ndRange instanceof NDRange.Range range)) {
145 throw new IllegalArgumentException("NDRange must be of type NDRange.Range");
146 }
147 boolean isLocalMeshDefined = kernelContext.hasLocalMesh();
148 NDRange.Global global = range.global();
149 setGlobalMesh(global);
150 if (isLocalMeshDefined) {
151 setLocalMesh(range.local());
152 } else {
153 setDefaultLocalMesh();
154 }
155 }
156
157 public void dispatch(KernelContext kernelContext, Object[] args) {
158 setupComputeRange(kernelContext);
159 args[0] = this.kernelBufferContext;
160 ArgArray.update(argArray, kernelCallGraph, args);
161 kernelBridge.ndRange(this.argArray);
162 }
163 }
164
165 public Map<KernelCallGraph, CompiledKernel> kernelCallGraphCompiledCodeMap = new HashMap<>();
166
167 private <T extends C99HATKernelBuilder<T>> void generateDeviceTypeStructs(T builder, String toText, Set<String> typedefs) {
168 // From here is text processing
169 String[] split = toText.split(">");
170 // Each item is a data struct
171 for (String s : split) {
172 // curate: remove first character
173 s = s.substring(1);
174 String dsName = s.split(":")[0];
175 if (typedefs.contains(dsName)) {
176 continue;
177 }
178 typedefs.add(dsName);
179 // sanitize dsName
180 dsName = sanitize(dsName);
181 builder.typedefKeyword()
182 .space()
183 .structKeyword()
184 .space()
185 .suffix_s(dsName)
186 .obrace()
187 .nl();
188
189 String[] members = s.split(";");
190
191 int j = 0;
192 builder.in();
193 for (int i = 0; i < members.length; i++) {
194 String member = members[i];
195 String[] field = member.split(":");
196 if (i == 0) {
197 j = 1;
198 }
199 String isArray = field[j++];
200 String type = field[j++];
201 String name = field[j++];
202 String lenValue = "";
203 if (isArray.equals("[")) {
204 lenValue = field[j];
205 }
206 j = 0;
207 if (typedefs.contains(type))
208 type = sanitize(type) + "_t";
209 else
210 type = sanitize(type);
211
212 builder.typeName(type)
213 .space()
214 .identifier(name);
215
216 if (isArray.equals("[")) {
217 builder.space()
218 .osbrace()
219 .identifier(lenValue)
220 .csbrace();
221 }
222 builder.semicolon().nl();
223 }
224 builder.out();
225 builder.cbrace().suffix_t(dsName).semicolon().nl();
226 }
227 }
228
/**
 * Generates the kernel source text for {@code kernelCallGraph} into {@code builder}.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>emit the builder's defines and types;</li>
 *   <li>emit a typedef for every schema type reachable from each {@link Buffer} argument
 *       (each type once);</li>
 *   <li>if the entrypoint method carries {@link Kernel}, emit the preformatted text from
 *       its {@link TypeDef}, {@link Preformatted} and {@link Kernel} annotations;</li>
 *   <li>otherwise emit typedefs for the types used by {@link HATMemoryOp}s in the module
 *       and entrypoint, then every function of the module's function table, then the
 *       kernel entrypoint, and optionally print the kernel models as configured.</li>
 * </ol>
 *
 * @param kernelCallGraph call graph of the kernel to generate
 * @param builder         code builder that accumulates the output
 * @param args            kernel arguments; only {@link Buffer} instances contribute typedefs
 * @return {@code builder.toString()} after generation
 * @throws RuntimeException if reflective access to a kernel-local type fails, or no
 *                          device schema text can be obtained for it
 */
public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kernelCallGraph, T builder, Object... args) {
    var here = OpTk.CallSite.of(C99FFIBackend.class, "createCode");
    builder.defines().types();
    Set<Schema.IfaceType> already = new LinkedHashSet<>();
    for (Object arg : args) {
        if (arg instanceof Buffer ifaceBuffer) {
            typedefUnseenSchemaTypes(builder, Buffer.getBoundSchema(ifaceBuffer), already);
        }
    }

    var annotation = kernelCallGraph.entrypoint.method.getAnnotation(Kernel.class);

    if (annotation != null) {
        var typedef = kernelCallGraph.entrypoint.method.getAnnotation(TypeDef.class);
        if (typedef != null) {
            builder.lineComment("Preformatted typedef body from @Typedef annotation");
            builder.typedefKeyword().space().structKeyword().space().suffix_s(typedef.name()).braceNlIndented(_ ->
                    builder.preformatted(typedef.body())
            ).suffix_t(typedef.name()).semicolon().nl();
        }
        var preformatted = kernelCallGraph.entrypoint.method.getAnnotation(Preformatted.class);
        if (preformatted != null) {
            builder.lineComment("Preformatted text from @Preformatted annotation");
            builder.preformatted(preformatted.value());
        }
        builder.lineComment("Preformatted code body from @Kernel annotation");
        builder.preformatted(annotation.value());
    } else {
        List<TypeElement> localIFaceList = new ArrayList<>();

        kernelCallGraph.getModuleOp()
                .elements()
                .filter(c -> Objects.requireNonNull(c) instanceof HATMemoryOp)
                .map(c -> ((HATMemoryOp) c).invokeType())
                .forEach(localIFaceList::add);

        kernelCallGraph.entrypoint.funcOp()
                .elements()
                .filter(c -> Objects.requireNonNull(c) instanceof HATMemoryOp)
                .map(c -> ((HATMemoryOp) c).invokeType())
                .forEach(localIFaceList::add);

        // Dynamically build the schema for the user data type we are creating within the kernel.
        // This is because no allocation was done from the host. This is kernel code, and it is reflected
        // using the code reflection API
        // 1. Add for struct for iface objects
        Set<String> typedefs = new HashSet<>();

        // Add HAT reserved types
        typedefs.add(F16.class.getName());

        for (TypeElement typeElement : localIFaceList) {
            emitKernelLocalTypedefs(kernelCallGraph, builder, typeElement, already, typedefs);
        }

        ScopedCodeBuilderContext buildContext =
                new ScopedCodeBuilderContext(kernelCallGraph.entrypoint.callGraph.computeContext.accelerator.lookup,
                        kernelCallGraph.entrypoint.funcOp());

        // Sorting by rank ensures we don't need forward declarations
        kernelCallGraph.getModuleOp().functionTable()
                .forEach((_, funcOp) -> {
                    // TODO: did we just trash the callgraph sidetables?
                    HATFinalDetectionPhase finals = new HATFinalDetectionPhase(kernelCallGraph.entrypoint.callGraph.computeContext.accelerator);
                    finals.apply(funcOp);

                    // Update the build context for this method to use the right constants-map
                    buildContext.setFinals(finals.getFinalVars());
                    builder.nl().kernelMethod(buildContext, funcOp).nl();
                });

        // Update the constants-map for the main kernel
        HATFinalDetectionPhase hatFinalDetectionPhase = new HATFinalDetectionPhase(kernelCallGraph.entrypoint.callGraph.computeContext.accelerator);
        hatFinalDetectionPhase.apply(kernelCallGraph.entrypoint.funcOp());
        buildContext.setFinals(hatFinalDetectionPhase.getFinalVars());
        builder.nl().kernelEntrypoint(buildContext, args).nl();

        if (config().showKernelModel()) {
            IO.println("Original");
            IO.println(kernelCallGraph.entrypoint.funcOp().toText());
        }
        if (config().showLoweredKernelModel()) {
            IO.println("Lowered");
            IO.println(OpTk.lower(here, kernelCallGraph.entrypoint.funcOp()).toText());
        }
    }
    return builder.toString();
}

/**
 * Emits a typedef for every type reachable from the root of {@code boundSchema}
 * that is not yet in {@code already}, recording each emitted type in that set.
 */
private <T extends C99HATKernelBuilder<T>> void typedefUnseenSchemaTypes(T builder, BoundSchema<?> boundSchema, Set<Schema.IfaceType> already) {
    boundSchema.schema().rootIfaceType.visitTypes(0, t -> {
        if (!already.contains(t)) {
            builder.typedef(boundSchema, t);
            already.add(t);
        }
    });
}

/**
 * Emits the struct declarations for one type used by a {@link HATMemoryOp} inside the kernel.
 *
 * <p>The type's static {@code create(Accelerator)} method is invoked reflectively. A non-null
 * result is treated as a {@link Buffer} and its bound schema is typedef'd. A null result
 * selects the device-type path: the class's {@code schema} field is read and its
 * {@code toText()} output is handed to {@link #generateDeviceTypeStructs}.
 *
 * @throws RuntimeException wrapping any {@link ReflectiveOperationException}, or when the
 *                          schema or its text form is null
 */
private <T extends C99HATKernelBuilder<T>> void emitKernelLocalTypedefs(KernelCallGraph kernelCallGraph, T builder, TypeElement typeElement, Set<Schema.IfaceType> already, Set<String> typedefs) {
    try {
        // Approach 1: The first approach support iFace and Buffer types to be used in Local and Private memory
        // TODO: Once we decide to move towards the DeviceType implementation, we will remove this part
        Class<?> clazz = (Class<?>) ((ClassType) typeElement).resolve(kernelCallGraph.computeContext.accelerator.lookup);
        Method method = clazz.getMethod("create", hat.Accelerator.class);
        method.setAccessible(true);
        Buffer invoke = (Buffer) method.invoke(null, kernelCallGraph.computeContext.accelerator);
        if (invoke != null) {
            // code gen of the struct
            typedefUnseenSchemaTypes(builder, Buffer.getBoundSchema(invoke), already);
            return;
        }

        // new approach for supporting DeviceTypes
        Field schemaField = clazz.getDeclaredField("schema");
        schemaField.setAccessible(true);
        // The schema is expected to be a static field, so no receiver is needed
        // (Field.get ignores its argument for static fields).
        Object schema = schemaField.get(null);

        Method toTextMethod = DeviceSchema.class.getDeclaredMethod("toText");
        toTextMethod.setAccessible(true);
        // Guard before invoke: a null receiver for an instance method would only
        // produce an uninformative NullPointerException.
        String toText = (schema == null) ? null : (String) toTextMethod.invoke(schema);
        if (toText == null) {
            throw new RuntimeException("[ERROR] Could not find valid device schema " + clazz.getName());
        }
        generateDeviceTypeStructs(builder, toText, typedefs);
    } catch (ReflectiveOperationException e) {
        throw new RuntimeException(e);
    }
}
358
359
360 private String sanitize(String s) {
361 String[] split1 = s.split("\\.");
362 if (split1.length == 1) {
363 return s;
364 }
365 s = split1[split1.length - 1];
366 if (s.split("\\$").length > 1) {
367 s = sanitize(s.split("\\$")[1]);
368 }
369 return s;
370 }
371
372 @Override
373 public void preMutate(Buffer b) {
374 switch (b.getState()) {
375 case BufferState.NO_STATE:
376 case BufferState.NEW_STATE:
377 case BufferState.HOST_OWNED:
378 case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
379 if (config().showState()) {
380 System.out.println("in preMutate state = " + b.getStateString() + " no action to take");
381 }
382 break;
383 }
384 case BufferState.DEVICE_OWNED: {
385 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
386 if (config().showState()) {
387 System.out.print("in preMutate state = " + b.getStateString() + " we pulled from device ");
388 }
389 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
390 if (config().showState()) {
391 System.out.println("and switched to " + b.getStateString());
392 }
393 break;
394 }
395 default:
396 throw new IllegalStateException("Not expecting this state ");
397 }
398 }
399
400 @Override
401 public void postMutate(Buffer b) {
402 if (config().showState()) {
403 System.out.print("in postMutate state = " + b.getStateString() + " no action to take ");
404 }
405 if (b.getState() != BufferState.NEW_STATE) {
406 b.setState(BufferState.HOST_OWNED);
407 }
408 if (config().showState()) {
409 System.out.println("and switched to (or stayed on) " + b.getStateString());
410 }
411 }
412
413 @Override
414 public void preAccess(Buffer b) {
415 switch (b.getState()) {
416 case BufferState.NO_STATE:
417 case BufferState.NEW_STATE:
418 case BufferState.HOST_OWNED:
419 case BufferState.DEVICE_VALID_HOST_HAS_COPY: {
420 if (config().showState()) {
421 System.out.println("in preAccess state = " + b.getStateString() + " no action to take");
422 }
423 break;
424 }
425 case BufferState.DEVICE_OWNED: {
426 backendBridge.getBufferFromDeviceIfDirty(b);// calls through FFI and might block when fetching from device
427
428 if (config().showState()) {
429 System.out.print("in preAccess state = " + b.getStateString() + " we pulled from device ");
430 }
431 b.setState(BufferState.DEVICE_VALID_HOST_HAS_COPY);
432 if (config().showState()) {
433 System.out.println("and switched to " + b.getStateString());
434 }
435 break;
436 }
437 default:
438 throw new IllegalStateException("Not expecting this state ");
439 }
440 }
441
442
443 @Override
444 public void postAccess(Buffer b) {
445 if (config().showState()) {
446 System.out.println("in postAccess state = " + b.getStateString());
447 }
448 }
449 }