1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend.ffi;
 26 
 27 
 28 import hat.backend.Backend;
 29 import hat.buffer.ArgArray;
 30 import hat.buffer.Buffer;
 31 
 32 import java.lang.foreign.Arena;
 33 import java.lang.foreign.MemorySegment;
 34 import java.util.HashMap;
 35 import java.util.Map;
 36 
 37 public abstract class FFIBackendDriver implements Backend {
 38     public boolean isAvailable() {
 39         return ffiLib.available;
 40     }
 41     protected final Config config;
 42 
 43     public static class BackendBridge {
 44         // CUDA this combines Device+Stream+Context
 45         // OpenCL this combines Platform+Device+Queue+Context
 46         public static class CompilationUnitBridge {
 47             // CUDA calls this a Module
 48             // OpenCL calls this a program
 49             public static class KernelBridge {
 50                 // CUDA calls this a Function
 51                 // OpenCL calls this a Program
 52                 CompilationUnitBridge compilationUnitBridge;
 53                 long handle;
 54                 final FFILib.VoidHandleMethodPtr releaseKernel_MPtr;
 55                 String name;
 56                 final FFILib.LongHandleLongAddressMethodPtr ndrange_MPtr;
 57                 KernelBridge(CompilationUnitBridge compilationUnitBridge, String name, long handle) {
 58                     this.compilationUnitBridge = compilationUnitBridge;
 59                     this.handle = handle;
 60                     this.releaseKernel_MPtr = compilationUnitBridge.backendBridge.ffiLib.voidHandleFunc("releaseKernel");
 61                     this.ndrange_MPtr = compilationUnitBridge.backendBridge.ffiLib.longHandleLongAddressFunc("ndrange");
 62                     this.name = name;
 63                 }
 64                 public void ndRange(ArgArray argArray) {
 65                     this.ndrange_MPtr.invoke(handle, Buffer.getMemorySegment(argArray));
 66                 }
 67                 void release() {
 68                     releaseKernel_MPtr.invoke(handle);
 69                 }
 70             }
 71 
 72             BackendBridge backendBridge;
 73             String source;
 74             final FFILib.VoidHandleMethodPtr releaseCompilationUnit_MPtr;
 75             final FFILib.BooleanHandleMethodPtr compilationUnitOK_MPtr;
 76             final FFILib.LongHandleIntAddressMethodPtr getKernel_MPtr;
 77             long handle;
 78             Map<String, KernelBridge> kernels = new HashMap<>();
 79 
 80             CompilationUnitBridge(BackendBridge backendBridge, long handle, String source) {
 81                 this.backendBridge = backendBridge;
 82                 this.handle = handle;
 83                 this.source = source;
 84                 this.releaseCompilationUnit_MPtr = backendBridge.ffiLib.voidHandleFunc("releaseCompilationUnit");
 85                 this.compilationUnitOK_MPtr = backendBridge.ffiLib.booleanHandleFunc("compilationUnitOK");
 86                 this.getKernel_MPtr = backendBridge.ffiLib.longHandleIntAddressFunc("getKernel");
 87             }
 88             void release() {
 89                 this.releaseCompilationUnit_MPtr.invoke(handle);
 90             }
 91             boolean ok() {
 92                 return this.compilationUnitOK_MPtr.invoke(handle);
 93             }
 94             public KernelBridge getKernel(String kernelName) {
 95                 return kernels.computeIfAbsent(kernelName, _ ->
 96                         new KernelBridge(this, kernelName,
 97                                 getKernel_MPtr.invoke(handle, kernelName.length(), Arena.global().allocateFrom(kernelName)))
 98                 );
 99             }
100         }
101 
102         final FFILib ffiLib;
103         final long handle;
104 
105         final Map<Long, CompilationUnitBridge> compilationUnits = new HashMap<>();
106         final FFILib.LongHandleIntMethodPtr getBackend_MPtr;
107         final FFILib.LongHandleIntAddressMethodPtr compile_MPtr;
108         final FFILib.VoidHandleMethodPtr computeStart_MPtr;
109         final FFILib.VoidHandleMethodPtr computeEnd_MPtr;
110 
111         final FFILib.VoidHandleMethodPtr info_MPtr;
112         final FFILib.BooleanHandleAddressLongMethodPtr getBufferFromDeviceIfDirty_MPtr;
113         BackendBridge(FFILib ffiLib, Config config) {
114             this.ffiLib = ffiLib;
115             this.getBackend_MPtr = ffiLib.longHandleIntFunc("getBackend");
116             if (this.getBackend_MPtr.mh == null) {
117                 throw new RuntimeException("No getBackend()");
118             }
119             this.handle = getBackend(config.bits());
120             this.compile_MPtr = ffiLib.longHandleIntAddressFunc("compile");
121             this.info_MPtr = ffiLib.voidHandleFunc("info");
122             this.computeStart_MPtr = ffiLib.voidHandleFunc("computeStart");
123             this.computeEnd_MPtr = ffiLib.voidHandleFunc("computeEnd");
124             this.getBufferFromDeviceIfDirty_MPtr = ffiLib.booleanHandleAddressLongFunc("getBufferFromDeviceIfDirty");
125         }
126 
127 
128         void release() {}
129 
130         public long getBackend(int configBits) {
131             return getBackend_MPtr.invoke(configBits);
132         }
133 
134         private CompilationUnitBridge compilationUnit(long handle, String source) {
135             return compilationUnits.computeIfAbsent(handle, _ ->
136                     new CompilationUnitBridge(this, handle, source)
137             );
138         }
139 
140         public CompilationUnitBridge compile(String source) {
141             var compilationUnitHandle = compile_MPtr.invoke(handle, source.length(), Arena.global().allocateFrom(source));
142             return compilationUnit(compilationUnitHandle, source);
143         }
144 
145         public Buffer getBufferFromDeviceIfDirty(Buffer buffer) {
146             MemorySegment memorySegment = Buffer.getMemorySegment(buffer);
147             if (!getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize())){
148                 throw new IllegalStateException("Failed to get buffer from backend");
149             }
150             return buffer;
151 
152         }
153         public void computeStart() {
154             computeStart_MPtr.invoke(handle);
155         }
156         public void computeEnd() {
157             computeEnd_MPtr.invoke(handle);
158         }
159         public void info() {
160             info_MPtr.invoke(handle);
161         }
162     }
163 
164     public final FFILib ffiLib;
165     public final BackendBridge backendBridge;
166 
167     public FFIBackendDriver(String libName, Config config) {
168         this.ffiLib = new FFILib(libName);
169         this.config = config;
170         this.backendBridge = new BackendBridge(ffiLib, config);
171     }
172 
173 }