1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend.ffi;
 26 
 27 
 28 import hat.backend.Backend;
 29 import hat.buffer.ArgArray;
 30 import hat.buffer.Buffer;
 31 
 32 import java.lang.foreign.Arena;
 33 import java.lang.foreign.MemorySegment;
 34 import java.util.HashMap;
 35 import java.util.Map;
 36 
 37 public abstract class FFIBackendDriver implements Backend {
 38     public boolean isAvailable() {
 39         return ffiLib.available;
 40     }
 41 protected final Config config;
 42 
 43     public static class BackendBridge {
 44         // CUDA this combines Device+Stream+Context
 45         // OpenCL this combines Platform+Device+Queue+Context
 46         public static class CompilationUnitBridge {
 47             // CUDA calls this a Module
 48             // OpenCL calls this a program
 49             public static class KernelBridge {
 50                 // CUDA calls this a Function
 51                 // OpenCL calls this a Program
 52                 CompilationUnitBridge compilationUnitBridge;
 53                 long handle;
 54                 final FFILib.VoidHandleMethodPtr releaseKernel_MPtr;
 55                 String name;
 56                 final FFILib.LongHandleLongAddressMethodPtr ndrange_MPtr;
 57                 KernelBridge(CompilationUnitBridge compilationUnitBridge, String name, long handle) {
 58                     this.compilationUnitBridge = compilationUnitBridge;
 59                     this.handle = handle;
 60                     this.releaseKernel_MPtr = compilationUnitBridge.backendBridge.ffiLib.voidHandleFunc("releaseKernel");
 61                     this.ndrange_MPtr = compilationUnitBridge.backendBridge.ffiLib.longHandleLongAddressFunc("ndrange");
 62                     this.name = name;
 63                 }
 64 
 65 
 66 
 67                 public void ndRange(ArgArray argArray) {
 68                     this.ndrange_MPtr.invoke(handle, Buffer.getMemorySegment(argArray));
 69                 }
 70                 void release() {
 71                     releaseKernel_MPtr.invoke(handle);
 72                 }
 73             }
 74 
 75             BackendBridge backendBridge;
 76             String source;
 77             final FFILib.VoidHandleMethodPtr releaseCompilationUnit_MPtr;
 78             final FFILib.BooleanHandleMethodPtr compilationUnitOK_MPtr;
 79             final FFILib.LongHandleIntAddressMethodPtr getKernel_MPtr;
 80 
 81 
 82             long handle;
 83             Map<String, KernelBridge> kernels = new HashMap<>();
 84 
 85             CompilationUnitBridge(BackendBridge backendBridge, long handle, String source) {
 86                 this.backendBridge = backendBridge;
 87                 this.handle = handle;
 88                 this.source = source;
 89                 this.releaseCompilationUnit_MPtr = backendBridge.ffiLib.voidHandleFunc("releaseCompilationUnit");
 90                 this.compilationUnitOK_MPtr = backendBridge.ffiLib.booleanHandleFunc("compilationUnitOK");
 91                 this.getKernel_MPtr = backendBridge.ffiLib.longHandleIntAddressFunc("getKernel");
 92             }
 93 
 94             void release() {
 95                 this.releaseCompilationUnit_MPtr.invoke(handle);
 96             }
 97 
 98             boolean ok() {
 99                 return this.compilationUnitOK_MPtr.invoke(handle);
100             }
101 
102             public KernelBridge getKernel(String kernelName) {
103                 KernelBridge kernelBridge = kernels.computeIfAbsent(kernelName, _ ->
104                         new KernelBridge(this, kernelName,
105                                 getKernel_MPtr.invoke(handle, kernelName.length(), Arena.global().allocateFrom(kernelName)))
106                 );
107                 return kernelBridge;
108             }
109 
110 
111         }
112 
113         final FFILib ffiLib;
114         final long handle;
115 
116         final Map<Long, CompilationUnitBridge> compilationUnits = new HashMap<>();
117         final FFILib.LongHandleIntMethodPtr getBackend_MPtr;
118         final FFILib.LongHandleIntAddressMethodPtr compile_MPtr;
119         final FFILib.VoidHandleMethodPtr computeStart_MPtr;
120         final FFILib.VoidHandleMethodPtr computeEnd_MPtr;
121         final FFILib.VoidAddressMethodPtr dumpArgArray_MPtr;
122 /*
123   final FFILib.LongIntMethodPtr getBackend_MPtr;
124     getBackend_MPtr = ffiLib.longIntFunc("getBackend");
125     public long getBackend(int configBits) {
126         return backendBridge.handle = getBackend_MPtr.invoke(configBits);
127     }
128 
129  */
130         final FFILib.VoidHandleMethodPtr info_MPtr;
131         final FFILib.BooleanHandleAddressLongMethodPtr getBufferFromDeviceIfDirty_MPtr;
132         BackendBridge(FFILib ffiLib, Config config) {
133             this.ffiLib = ffiLib;
134             this.getBackend_MPtr = ffiLib.longHandleIntFunc("getBackend");
135             if (this.getBackend_MPtr.mh == null) {
136                 throw new RuntimeException("No getBackend()");
137             }
138             this.handle = getBackend(config.bits());
139             this.compile_MPtr = ffiLib.longHandleIntAddressFunc("compile");
140             this.dumpArgArray_MPtr = ffiLib.voidAddressFunc("dumpArgArray");
141             this.info_MPtr = ffiLib.voidHandleFunc("info");
142             this.computeStart_MPtr = ffiLib.voidHandleFunc("computeStart");
143             this.computeEnd_MPtr = ffiLib.voidHandleFunc("computeEnd");
144             this.getBufferFromDeviceIfDirty_MPtr = ffiLib.booleanHandleAddressLongFunc("getBufferFromDeviceIfDirty");
145         }
146 
147 
148         void release() {
149 
150         }
151 
152         public long getBackend(int configBits) {
153             return getBackend_MPtr.invoke(configBits);
154         }
155 
156         private CompilationUnitBridge compilationUnit(long handle, String source) {
157             return compilationUnits.computeIfAbsent(handle, _ ->
158                     new CompilationUnitBridge(this, handle, source)
159             );
160         }
161 
162         public CompilationUnitBridge compile(String source) {
163             var compilationUnitHandle = compile_MPtr.invoke(handle, source.length(), Arena.global().allocateFrom(source));
164             return compilationUnit(compilationUnitHandle, source);
165         }
166 
167         public Buffer getBufferFromDeviceIfDirty(Buffer buffer) {
168             MemorySegment memorySegment = Buffer.getMemorySegment(buffer);
169             boolean ok = getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize());
170             if (!ok) {
171                 throw new IllegalStateException("Failed to get buffer from backend");
172             }
173             return buffer;
174 
175         }
176 
177         public void computeStart() {
178             computeStart_MPtr.invoke(handle);
179         }
180 
181         public void computeEnd() {
182             computeEnd_MPtr.invoke(handle);
183         }
184 
185         public void info() {
186             info_MPtr.invoke(handle);
187         }
188 
189         public void dumpArgArray(ArgArray argArray) {
190             dumpArgArray_MPtr.invoke(Buffer.getMemorySegment(argArray));
191         }
192 
193 
194     }
195 
196 
197 
198     public final FFILib ffiLib;
199     public final BackendBridge backendBridge;
200 
201     public FFIBackendDriver(String libName, Config config) {
202         this.ffiLib = new FFILib(libName);
203         this.config = config;
204         this.backendBridge = new BackendBridge(ffiLib, config);
205 
206     }
207 
208 
209 }