1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 package hat.backend.ffi; 26 27 28 import hat.backend.Backend; 29 import hat.buffer.ArgArray; 30 import hat.buffer.Buffer; 31 32 import java.lang.foreign.Arena; 33 import java.lang.foreign.MemorySegment; 34 import java.util.HashMap; 35 import java.util.Map; 36 37 public abstract class FFIBackendDriver implements Backend { 38 public boolean isAvailable() { 39 return ffiLib.available; 40 } 41 protected final Config config; 42 43 public static class BackendBridge { 44 // CUDA this combines Device+Stream+Context 45 // OpenCL this combines Platform+Device+Queue+Context 46 public static class CompilationUnitBridge { 47 // CUDA calls this a Module 48 // OpenCL calls this a program 49 public static class KernelBridge { 50 // CUDA calls this a Function 51 // OpenCL calls this a Program 52 CompilationUnitBridge compilationUnitBridge; 53 long handle; 54 final FFILib.VoidHandleMethodPtr releaseKernel_MPtr; 55 String name; 56 final FFILib.LongHandleLongAddressMethodPtr ndrange_MPtr; 57 KernelBridge(CompilationUnitBridge compilationUnitBridge, String name, long handle) { 58 this.compilationUnitBridge = compilationUnitBridge; 59 this.handle = handle; 60 this.releaseKernel_MPtr = compilationUnitBridge.backendBridge.ffiLib.voidHandleFunc("releaseKernel"); 61 this.ndrange_MPtr = compilationUnitBridge.backendBridge.ffiLib.longHandleLongAddressFunc("ndrange"); 62 this.name = name; 63 } 64 65 66 67 public void ndRange(ArgArray argArray) { 68 this.ndrange_MPtr.invoke(handle, Buffer.getMemorySegment(argArray)); 69 } 70 void release() { 71 releaseKernel_MPtr.invoke(handle); 72 } 73 } 74 75 BackendBridge backendBridge; 76 String source; 77 final FFILib.VoidHandleMethodPtr releaseCompilationUnit_MPtr; 78 final FFILib.BooleanHandleMethodPtr compilationUnitOK_MPtr; 79 final FFILib.LongHandleIntAddressMethodPtr getKernel_MPtr; 80 81 82 long handle; 83 Map<String, KernelBridge> kernels = new HashMap<>(); 84 85 CompilationUnitBridge(BackendBridge backendBridge, long handle, String source) { 86 this.backendBridge = backendBridge; 87 this.handle = handle; 88 this.source = source; 89 this.releaseCompilationUnit_MPtr = backendBridge.ffiLib.voidHandleFunc("releaseCompilationUnit"); 90 this.compilationUnitOK_MPtr = backendBridge.ffiLib.booleanHandleFunc("compilationUnitOK"); 91 this.getKernel_MPtr = backendBridge.ffiLib.longHandleIntAddressFunc("getKernel"); 92 } 93 94 void release() { 95 this.releaseCompilationUnit_MPtr.invoke(handle); 96 } 97 98 boolean ok() { 99 return this.compilationUnitOK_MPtr.invoke(handle); 100 } 101 102 public KernelBridge getKernel(String kernelName) { 103 KernelBridge kernelBridge = kernels.computeIfAbsent(kernelName, _ -> 104 new KernelBridge(this, kernelName, 105 getKernel_MPtr.invoke(handle, kernelName.length(), Arena.global().allocateFrom(kernelName))) 106 ); 107 return kernelBridge; 108 } 109 110 111 } 112 113 final FFILib ffiLib; 114 final long handle; 115 116 final Map<Long, CompilationUnitBridge> compilationUnits = new HashMap<>(); 117 final FFILib.LongHandleIntMethodPtr getBackend_MPtr; 118 final FFILib.LongHandleIntAddressMethodPtr compile_MPtr; 119 final FFILib.VoidHandleMethodPtr computeStart_MPtr; 120 final FFILib.VoidHandleMethodPtr computeEnd_MPtr; 121 final FFILib.VoidAddressMethodPtr dumpArgArray_MPtr; 122 /* 123 final FFILib.LongIntMethodPtr getBackend_MPtr; 124 getBackend_MPtr = ffiLib.longIntFunc("getBackend"); 125 public long getBackend(int configBits) { 126 return backendBridge.handle = getBackend_MPtr.invoke(configBits); 127 } 128 129 */ 130 final FFILib.VoidHandleMethodPtr info_MPtr; 131 final FFILib.BooleanHandleAddressLongMethodPtr getBufferFromDeviceIfDirty_MPtr; 132 BackendBridge(FFILib ffiLib, Config config) { 133 this.ffiLib = ffiLib; 134 this.getBackend_MPtr = ffiLib.longHandleIntFunc("getBackend"); 135 if (this.getBackend_MPtr.mh == null) { 136 throw new RuntimeException("No getBackend()"); 137 } 138 this.handle = getBackend(config.bits()); 139 this.compile_MPtr = ffiLib.longHandleIntAddressFunc("compile"); 140 this.dumpArgArray_MPtr = ffiLib.voidAddressFunc("dumpArgArray"); 141 this.info_MPtr = ffiLib.voidHandleFunc("info"); 142 this.computeStart_MPtr = ffiLib.voidHandleFunc("computeStart"); 143 this.computeEnd_MPtr = ffiLib.voidHandleFunc("computeEnd"); 144 this.getBufferFromDeviceIfDirty_MPtr = ffiLib.booleanHandleAddressLongFunc("getBufferFromDeviceIfDirty"); 145 } 146 147 148 void release() { 149 150 } 151 152 public long getBackend(int configBits) { 153 return getBackend_MPtr.invoke(configBits); 154 } 155 156 private CompilationUnitBridge compilationUnit(long handle, String source) { 157 return compilationUnits.computeIfAbsent(handle, _ -> 158 new CompilationUnitBridge(this, handle, source) 159 ); 160 } 161 162 public CompilationUnitBridge compile(String source) { 163 var compilationUnitHandle = compile_MPtr.invoke(handle, source.length(), Arena.global().allocateFrom(source)); 164 return compilationUnit(compilationUnitHandle, source); 165 } 166 167 public Buffer getBufferFromDeviceIfDirty(Buffer buffer) { 168 MemorySegment memorySegment = Buffer.getMemorySegment(buffer); 169 boolean ok = getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize()); 170 if (!ok) { 171 throw new IllegalStateException("Failed to get buffer from backend"); 172 } 173 return buffer; 174 175 } 176 177 public void computeStart() { 178 computeStart_MPtr.invoke(handle); 179 } 180 181 public void computeEnd() { 182 computeEnd_MPtr.invoke(handle); 183 } 184 185 public void info() { 186 info_MPtr.invoke(handle); 187 } 188 189 public void dumpArgArray(ArgArray argArray) { 190 dumpArgArray_MPtr.invoke(Buffer.getMemorySegment(argArray)); 191 } 192 193 194 } 195 196 197 198 public final FFILib ffiLib; 199 public final BackendBridge backendBridge; 200 201 public FFIBackendDriver(String libName, Config config) { 202 this.ffiLib = new FFILib(libName); 203 this.config = config; 204 this.backendBridge = new BackendBridge(ffiLib, config); 205 206 } 207 208 209 }