/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
24 */ 25 package hat.backend.ffi; 26 27 28 import hat.backend.Backend; 29 import hat.buffer.ArgArray; 30 import hat.buffer.Buffer; 31 32 import java.lang.foreign.Arena; 33 import java.lang.foreign.MemorySegment; 34 import java.util.HashMap; 35 import java.util.Map; 36 37 public abstract class FFIBackendDriver implements Backend { 38 public boolean isAvailable() { 39 return ffiLib.available; 40 } 41 protected final Config config; 42 43 public static class BackendBridge { 44 // CUDA this combines Device+Stream+Context 45 // OpenCL this combines Platform+Device+Queue+Context 46 public static class CompilationUnitBridge { 47 // CUDA calls this a Module 48 // OpenCL calls this a program 49 public static class KernelBridge { 50 // CUDA calls this a Function 51 // OpenCL calls this a Program 52 CompilationUnitBridge compilationUnitBridge; 53 long handle; 54 final FFILib.VoidHandleMethodPtr releaseKernel_MPtr; 55 String name; 56 final FFILib.LongHandleLongAddressMethodPtr ndrange_MPtr; 57 KernelBridge(CompilationUnitBridge compilationUnitBridge, String name, long handle) { 58 this.compilationUnitBridge = compilationUnitBridge; 59 this.handle = handle; 60 this.releaseKernel_MPtr = compilationUnitBridge.backendBridge.ffiLib.voidHandleFunc("releaseKernel"); 61 this.ndrange_MPtr = compilationUnitBridge.backendBridge.ffiLib.longHandleLongAddressFunc("ndrange"); 62 this.name = name; 63 } 64 public void ndRange(ArgArray argArray) { 65 this.ndrange_MPtr.invoke(handle, Buffer.getMemorySegment(argArray)); 66 } 67 void release() { 68 releaseKernel_MPtr.invoke(handle); 69 } 70 } 71 72 BackendBridge backendBridge; 73 String source; 74 final FFILib.VoidHandleMethodPtr releaseCompilationUnit_MPtr; 75 final FFILib.BooleanHandleMethodPtr compilationUnitOK_MPtr; 76 final FFILib.LongHandleIntAddressMethodPtr getKernel_MPtr; 77 long handle; 78 Map<String, KernelBridge> kernels = new HashMap<>(); 79 80 CompilationUnitBridge(BackendBridge backendBridge, long handle, String source) { 81 
this.backendBridge = backendBridge; 82 this.handle = handle; 83 this.source = source; 84 this.releaseCompilationUnit_MPtr = backendBridge.ffiLib.voidHandleFunc("releaseCompilationUnit"); 85 this.compilationUnitOK_MPtr = backendBridge.ffiLib.booleanHandleFunc("compilationUnitOK"); 86 this.getKernel_MPtr = backendBridge.ffiLib.longHandleIntAddressFunc("getKernel"); 87 } 88 void release() { 89 this.releaseCompilationUnit_MPtr.invoke(handle); 90 } 91 boolean ok() { 92 return this.compilationUnitOK_MPtr.invoke(handle); 93 } 94 public KernelBridge getKernel(String kernelName) { 95 return kernels.computeIfAbsent(kernelName, _ -> 96 new KernelBridge(this, kernelName, 97 getKernel_MPtr.invoke(handle, kernelName.length(), Arena.global().allocateFrom(kernelName))) 98 ); 99 } 100 } 101 102 final FFILib ffiLib; 103 final long handle; 104 105 final Map<Long, CompilationUnitBridge> compilationUnits = new HashMap<>(); 106 final FFILib.LongHandleIntMethodPtr getBackend_MPtr; 107 final FFILib.LongHandleIntAddressMethodPtr compile_MPtr; 108 final FFILib.VoidHandleMethodPtr computeStart_MPtr; 109 final FFILib.VoidHandleMethodPtr computeEnd_MPtr; 110 111 final FFILib.VoidHandleMethodPtr info_MPtr; 112 final FFILib.BooleanHandleAddressLongMethodPtr getBufferFromDeviceIfDirty_MPtr; 113 BackendBridge(FFILib ffiLib, Config config) { 114 this.ffiLib = ffiLib; 115 this.getBackend_MPtr = ffiLib.longHandleIntFunc("getBackend"); 116 if (this.getBackend_MPtr.mh == null) { 117 throw new RuntimeException("No getBackend()"); 118 } 119 this.handle = getBackend(config.bits()); 120 this.compile_MPtr = ffiLib.longHandleIntAddressFunc("compile"); 121 this.info_MPtr = ffiLib.voidHandleFunc("info"); 122 this.computeStart_MPtr = ffiLib.voidHandleFunc("computeStart"); 123 this.computeEnd_MPtr = ffiLib.voidHandleFunc("computeEnd"); 124 this.getBufferFromDeviceIfDirty_MPtr = ffiLib.booleanHandleAddressLongFunc("getBufferFromDeviceIfDirty"); 125 } 126 127 128 void release() {} 129 130 public long 
getBackend(int configBits) { 131 return getBackend_MPtr.invoke(configBits); 132 } 133 134 private CompilationUnitBridge compilationUnit(long handle, String source) { 135 return compilationUnits.computeIfAbsent(handle, _ -> 136 new CompilationUnitBridge(this, handle, source) 137 ); 138 } 139 140 public CompilationUnitBridge compile(String source) { 141 var compilationUnitHandle = compile_MPtr.invoke(handle, source.length(), Arena.global().allocateFrom(source)); 142 return compilationUnit(compilationUnitHandle, source); 143 } 144 145 public Buffer getBufferFromDeviceIfDirty(Buffer buffer) { 146 MemorySegment memorySegment = Buffer.getMemorySegment(buffer); 147 if (!getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize())){ 148 throw new IllegalStateException("Failed to get buffer from backend"); 149 } 150 return buffer; 151 152 } 153 public void computeStart() { 154 computeStart_MPtr.invoke(handle); 155 } 156 public void computeEnd() { 157 computeEnd_MPtr.invoke(handle); 158 } 159 public void info() { 160 info_MPtr.invoke(handle); 161 } 162 } 163 164 public final FFILib ffiLib; 165 public final BackendBridge backendBridge; 166 167 public FFIBackendDriver(String libName, Config config) { 168 this.ffiLib = new FFILib(libName); 169 this.config = config; 170 this.backendBridge = new BackendBridge(ffiLib, config); 171 } 172 173 }