1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.backend.ffi;
26
27
28 import hat.Config;
29 import hat.backend.Backend;
30 import hat.buffer.ArgArray;
31 import optkl.ifacemapper.MappableIface;
32
33 import java.lang.foreign.Arena;
34 import java.lang.foreign.MemorySegment;
35 import java.lang.invoke.MethodHandles;
36 import java.util.HashMap;
37 import java.util.Map;
38
39 public abstract class FFIBackendDriver extends Backend {
40 public boolean isAvailable() {
41 return ffiLib.available;
42 }
43
44 public static class BackendBridge {
45 // CUDA this combines Device+Stream+Context
46 // OpenCL this combines Platform+Device+Queue+Context
47 public static class CompilationUnitBridge {
48 // CUDA calls this a Module
49 // OpenCL calls this a program
50 public static class KernelBridge {
51 // CUDA calls this a Function
52 // OpenCL calls this a Program
53 CompilationUnitBridge compilationUnitBridge;
54 long handle;
55 final FFILib.VoidHandleMethodPtr releaseKernel_MPtr;
56 String name;
57 final FFILib.LongHandleLongAddressMethodPtr ndrange_MPtr;
58 KernelBridge(CompilationUnitBridge compilationUnitBridge, String name, long handle) {
59 this.compilationUnitBridge = compilationUnitBridge;
60 this.handle = handle;
61 this.releaseKernel_MPtr = compilationUnitBridge.backendBridge.ffiLib.voidHandleFunc("releaseKernel");
62 this.ndrange_MPtr = compilationUnitBridge.backendBridge.ffiLib.longHandleLongAddressFunc("ndrange");
63 this.name = name;
64 }
65 public void ndRange(ArgArray argArray) {
66 this.ndrange_MPtr.invoke(handle, MappableIface.getMemorySegment(argArray));
67 }
68 void release() {
69 releaseKernel_MPtr.invoke(handle);
70 }
71 }
72
73 BackendBridge backendBridge;
74 String source;
75 final FFILib.VoidHandleMethodPtr releaseCompilationUnit_MPtr;
76 final FFILib.BooleanHandleMethodPtr compilationUnitOK_MPtr;
77 final FFILib.LongHandleIntAddressMethodPtr getKernel_MPtr;
78 long handle;
79 Map<String, KernelBridge> kernels = new HashMap<>();
80
81 CompilationUnitBridge(BackendBridge backendBridge, long handle, String source) {
82 this.backendBridge = backendBridge;
83 this.handle = handle;
84 this.source = source;
85 this.releaseCompilationUnit_MPtr = backendBridge.ffiLib.voidHandleFunc("releaseCompilationUnit");
86 this.compilationUnitOK_MPtr = backendBridge.ffiLib.booleanHandleFunc("compilationUnitOK");
87 this.getKernel_MPtr = backendBridge.ffiLib.longHandleIntAddressFunc("getKernel");
88 }
89 void release() {
90 this.releaseCompilationUnit_MPtr.invoke(handle);
91 }
92 boolean ok() {
93 return this.compilationUnitOK_MPtr.invoke(handle);
94 }
95 public KernelBridge getKernel(String kernelName) {
96 return kernels.computeIfAbsent(kernelName, _ ->
97 new KernelBridge(this, kernelName,
98 getKernel_MPtr.invoke(handle, kernelName.length(), Arena.global().allocateFrom(kernelName)))
99 );
100 }
101 }
102
103 final FFILib ffiLib;
104 final long handle;
105
106 final Map<Long, CompilationUnitBridge> compilationUnits = new HashMap<>();
107 final FFILib.LongHandleIntMethodPtr getBackend_MPtr;
108 final FFILib.LongHandleIntAddressMethodPtr compile_MPtr;
109 final FFILib.VoidHandleMethodPtr computeStart_MPtr;
110 final FFILib.VoidHandleMethodPtr computeEnd_MPtr;
111
112 final FFILib.VoidHandleMethodPtr showDeviceInfo_MPtr;
113 final FFILib.BooleanHandleAddressLongMethodPtr getBufferFromDeviceIfDirty_MPtr;
114 BackendBridge(FFILib ffiLib, Config config) {
115 this.ffiLib = ffiLib;
116 this.getBackend_MPtr = ffiLib.longHandleIntFunc("getBackend");
117 if (this.getBackend_MPtr.mh == null) {
118 throw new RuntimeException("No getBackend()");
119 }
120 this.handle = getBackend(config.bits());
121 this.compile_MPtr = ffiLib.longHandleIntAddressFunc("compile");
122 this.showDeviceInfo_MPtr = ffiLib.voidHandleFunc("showDeviceInfo");
123 this.computeStart_MPtr = ffiLib.voidHandleFunc("computeStart");
124 this.computeEnd_MPtr = ffiLib.voidHandleFunc("computeEnd");
125 this.getBufferFromDeviceIfDirty_MPtr = ffiLib.booleanHandleAddressLongFunc("getBufferFromDeviceIfDirty");
126 }
127
128 void release() {}
129
130 public long getBackend(int configBits) {
131 return getBackend_MPtr.invoke(configBits);
132 }
133
134 private CompilationUnitBridge compilationUnit(long handle, String source) {
135 return compilationUnits.computeIfAbsent(handle, _ ->
136 new CompilationUnitBridge(this, handle, source)
137 );
138 }
139
140 public CompilationUnitBridge compile(String source) {
141 var compilationUnitHandle = compile_MPtr.invoke(handle, source.length(), Arena.global().allocateFrom(source));
142 return compilationUnit(compilationUnitHandle, source);
143 }
144
145 public MappableIface getBufferFromDeviceIfDirty(MappableIface buffer) {
146 MemorySegment memorySegment = MappableIface.getMemorySegment(buffer);
147 if (!getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize())){
148 throw new IllegalStateException("Failed to get buffer from backend");
149 }
150 return buffer;
151
152 }
153 public void computeStart() {
154 computeStart_MPtr.invoke(handle);
155 }
156 public void computeEnd() {
157 computeEnd_MPtr.invoke(handle);
158 }
159 public void showDeviceInfo() {
160 showDeviceInfo_MPtr.invoke(handle);
161 }
162 }
163
164 public final FFILib ffiLib;
165 public final BackendBridge backendBridge;
166
167 public FFIBackendDriver(Arena arena, MethodHandles.Lookup lookup,String libName, Config config) {
168 super(arena,lookup,config);
169 this.ffiLib = new FFILib(libName);
170 this.backendBridge = new BackendBridge(ffiLib, config);
171 }
172
173 }