1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.backend.ffi;
26
27
28 import hat.Config;
29 import hat.backend.Backend;
30 import hat.buffer.ArgArray;
31 import hat.buffer.Buffer;
32
33 import java.lang.foreign.Arena;
34 import java.lang.foreign.MemorySegment;
35 import java.util.HashMap;
36 import java.util.Map;
37
38 public abstract class FFIBackendDriver extends Backend {
39 public boolean isAvailable() {
40 return ffiLib.available;
41 }
42
43 public static class BackendBridge {
44 // CUDA this combines Device+Stream+Context
45 // OpenCL this combines Platform+Device+Queue+Context
46 public static class CompilationUnitBridge {
47 // CUDA calls this a Module
48 // OpenCL calls this a program
49 public static class KernelBridge {
50 // CUDA calls this a Function
51 // OpenCL calls this a Program
52 CompilationUnitBridge compilationUnitBridge;
53 long handle;
54 final FFILib.VoidHandleMethodPtr releaseKernel_MPtr;
55 String name;
56 final FFILib.LongHandleLongAddressMethodPtr ndrange_MPtr;
57 KernelBridge(CompilationUnitBridge compilationUnitBridge, String name, long handle) {
58 this.compilationUnitBridge = compilationUnitBridge;
59 this.handle = handle;
60 this.releaseKernel_MPtr = compilationUnitBridge.backendBridge.ffiLib.voidHandleFunc("releaseKernel");
61 this.ndrange_MPtr = compilationUnitBridge.backendBridge.ffiLib.longHandleLongAddressFunc("ndrange");
62 this.name = name;
63 }
64 public void ndRange(ArgArray argArray) {
65 this.ndrange_MPtr.invoke(handle, Buffer.getMemorySegment(argArray));
66 }
67 void release() {
68 releaseKernel_MPtr.invoke(handle);
69 }
70 }
71
72 BackendBridge backendBridge;
73 String source;
74 final FFILib.VoidHandleMethodPtr releaseCompilationUnit_MPtr;
75 final FFILib.BooleanHandleMethodPtr compilationUnitOK_MPtr;
76 final FFILib.LongHandleIntAddressMethodPtr getKernel_MPtr;
77 long handle;
78 Map<String, KernelBridge> kernels = new HashMap<>();
79
80 CompilationUnitBridge(BackendBridge backendBridge, long handle, String source) {
81 this.backendBridge = backendBridge;
82 this.handle = handle;
83 this.source = source;
84 this.releaseCompilationUnit_MPtr = backendBridge.ffiLib.voidHandleFunc("releaseCompilationUnit");
85 this.compilationUnitOK_MPtr = backendBridge.ffiLib.booleanHandleFunc("compilationUnitOK");
86 this.getKernel_MPtr = backendBridge.ffiLib.longHandleIntAddressFunc("getKernel");
87 }
88 void release() {
89 this.releaseCompilationUnit_MPtr.invoke(handle);
90 }
91 boolean ok() {
92 return this.compilationUnitOK_MPtr.invoke(handle);
93 }
94 public KernelBridge getKernel(String kernelName) {
95 return kernels.computeIfAbsent(kernelName, _ ->
96 new KernelBridge(this, kernelName,
97 getKernel_MPtr.invoke(handle, kernelName.length(), Arena.global().allocateFrom(kernelName)))
98 );
99 }
100 }
101
102 final FFILib ffiLib;
103 final long handle;
104
105 final Map<Long, CompilationUnitBridge> compilationUnits = new HashMap<>();
106 final FFILib.LongHandleIntMethodPtr getBackend_MPtr;
107 final FFILib.LongHandleIntAddressMethodPtr compile_MPtr;
108 final FFILib.VoidHandleMethodPtr computeStart_MPtr;
109 final FFILib.VoidHandleMethodPtr computeEnd_MPtr;
110
111 final FFILib.VoidHandleMethodPtr showDeviceInfo_MPtr;
112 final FFILib.BooleanHandleAddressLongMethodPtr getBufferFromDeviceIfDirty_MPtr;
113 BackendBridge(FFILib ffiLib, Config config) {
114 this.ffiLib = ffiLib;
115 this.getBackend_MPtr = ffiLib.longHandleIntFunc("getBackend");
116 if (this.getBackend_MPtr.mh == null) {
117 throw new RuntimeException("No getBackend()");
118 }
119 this.handle = getBackend(config.bits());
120 this.compile_MPtr = ffiLib.longHandleIntAddressFunc("compile");
121 this.showDeviceInfo_MPtr = ffiLib.voidHandleFunc("showDeviceInfo");
122 this.computeStart_MPtr = ffiLib.voidHandleFunc("computeStart");
123 this.computeEnd_MPtr = ffiLib.voidHandleFunc("computeEnd");
124 this.getBufferFromDeviceIfDirty_MPtr = ffiLib.booleanHandleAddressLongFunc("getBufferFromDeviceIfDirty");
125 }
126
127 void release() {}
128
129 public long getBackend(int configBits) {
130 return getBackend_MPtr.invoke(configBits);
131 }
132
133 private CompilationUnitBridge compilationUnit(long handle, String source) {
134 return compilationUnits.computeIfAbsent(handle, _ ->
135 new CompilationUnitBridge(this, handle, source)
136 );
137 }
138
139 public CompilationUnitBridge compile(String source) {
140 var compilationUnitHandle = compile_MPtr.invoke(handle, source.length(), Arena.global().allocateFrom(source));
141 return compilationUnit(compilationUnitHandle, source);
142 }
143
144 public Buffer getBufferFromDeviceIfDirty(Buffer buffer) {
145 MemorySegment memorySegment = Buffer.getMemorySegment(buffer);
146 if (!getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize())){
147 throw new IllegalStateException("Failed to get buffer from backend");
148 }
149 return buffer;
150
151 }
152 public void computeStart() {
153 computeStart_MPtr.invoke(handle);
154 }
155 public void computeEnd() {
156 computeEnd_MPtr.invoke(handle);
157 }
158 public void showDeviceInfo() {
159 showDeviceInfo_MPtr.invoke(handle);
160 }
161 }
162
163 public final FFILib ffiLib;
164 public final BackendBridge backendBridge;
165
166 public FFIBackendDriver(String libName, Config config) {
167 super(config);
168 this.ffiLib = new FFILib(libName);
169 this.backendBridge = new BackendBridge(ffiLib, config);
170 }
171
172 }