1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend.ffi;
 26 
 27 
 28 import hat.Config;
 29 import hat.backend.Backend;
 30 import hat.buffer.ArgArray;
 31 import optkl.ifacemapper.MappableIface;
 32 
 33 import java.lang.foreign.Arena;
 34 import java.lang.foreign.MemorySegment;
 35 import java.lang.invoke.MethodHandles;
 36 import java.util.HashMap;
 37 import java.util.Map;
 38 
 39 public abstract class FFIBackendDriver extends Backend {
 40     public boolean isAvailable() {
 41         return ffiLib.available;
 42     }
 43 
 44     public static class BackendBridge {
 45         // CUDA this combines Device+Stream+Context
 46         // OpenCL this combines Platform+Device+Queue+Context
 47         public static class CompilationUnitBridge {
 48             // CUDA calls this a Module
 49             // OpenCL calls this a program
 50             public static class KernelBridge {
 51                 // CUDA calls this a Function
 52                 // OpenCL calls this a Program
 53                 CompilationUnitBridge compilationUnitBridge;
 54                 long handle;
 55                 final FFILib.VoidHandleMethodPtr releaseKernel_MPtr;
 56                 String name;
 57                 final FFILib.LongHandleLongAddressMethodPtr ndrange_MPtr;
 58                 KernelBridge(CompilationUnitBridge compilationUnitBridge, String name, long handle) {
 59                     this.compilationUnitBridge = compilationUnitBridge;
 60                     this.handle = handle;
 61                     this.releaseKernel_MPtr = compilationUnitBridge.backendBridge.ffiLib.voidHandleFunc("releaseKernel");
 62                     this.ndrange_MPtr = compilationUnitBridge.backendBridge.ffiLib.longHandleLongAddressFunc("ndrange");
 63                     this.name = name;
 64                 }
 65                 public void ndRange(ArgArray argArray) {
 66                     this.ndrange_MPtr.invoke(handle, MappableIface.getMemorySegment(argArray));
 67                 }
 68                 void release() {
 69                     releaseKernel_MPtr.invoke(handle);
 70                 }
 71             }
 72 
 73             BackendBridge backendBridge;
 74             String source;
 75             final FFILib.VoidHandleMethodPtr releaseCompilationUnit_MPtr;
 76             final FFILib.BooleanHandleMethodPtr compilationUnitOK_MPtr;
 77             final FFILib.LongHandleIntAddressMethodPtr getKernel_MPtr;
 78             long handle;
 79             Map<String, KernelBridge> kernels = new HashMap<>();
 80 
 81             CompilationUnitBridge(BackendBridge backendBridge, long handle, String source) {
 82                 this.backendBridge = backendBridge;
 83                 this.handle = handle;
 84                 this.source = source;
 85                 this.releaseCompilationUnit_MPtr = backendBridge.ffiLib.voidHandleFunc("releaseCompilationUnit");
 86                 this.compilationUnitOK_MPtr = backendBridge.ffiLib.booleanHandleFunc("compilationUnitOK");
 87                 this.getKernel_MPtr = backendBridge.ffiLib.longHandleIntAddressFunc("getKernel");
 88             }
 89             void release() {
 90                 this.releaseCompilationUnit_MPtr.invoke(handle);
 91             }
 92             boolean ok() {
 93                 return this.compilationUnitOK_MPtr.invoke(handle);
 94             }
 95             public KernelBridge getKernel(String kernelName) {
 96                 return kernels.computeIfAbsent(kernelName, _ ->
 97                         new KernelBridge(this, kernelName,
 98                                 getKernel_MPtr.invoke(handle, kernelName.length(), Arena.global().allocateFrom(kernelName)))
 99                 );
100             }
101         }
102 
103         final FFILib ffiLib;
104         final long handle;
105 
106         final Map<Long, CompilationUnitBridge> compilationUnits = new HashMap<>();
107         final FFILib.LongHandleIntMethodPtr getBackend_MPtr;
108         final FFILib.LongHandleIntAddressMethodPtr compile_MPtr;
109         final FFILib.VoidHandleMethodPtr computeStart_MPtr;
110         final FFILib.VoidHandleMethodPtr computeEnd_MPtr;
111 
112         final FFILib.VoidHandleMethodPtr showDeviceInfo_MPtr;
113         final FFILib.BooleanHandleAddressLongMethodPtr getBufferFromDeviceIfDirty_MPtr;
114         BackendBridge(FFILib ffiLib, Config config) {
115             this.ffiLib = ffiLib;
116             this.getBackend_MPtr = ffiLib.longHandleIntFunc("getBackend");
117             if (this.getBackend_MPtr.mh == null) {
118                 throw new RuntimeException("No getBackend()");
119             }
120             this.handle = getBackend(config.bits());
121             this.compile_MPtr = ffiLib.longHandleIntAddressFunc("compile");
122             this.showDeviceInfo_MPtr = ffiLib.voidHandleFunc("showDeviceInfo");
123             this.computeStart_MPtr = ffiLib.voidHandleFunc("computeStart");
124             this.computeEnd_MPtr = ffiLib.voidHandleFunc("computeEnd");
125             this.getBufferFromDeviceIfDirty_MPtr = ffiLib.booleanHandleAddressLongFunc("getBufferFromDeviceIfDirty");
126         }
127 
128         void release() {}
129 
130         public long getBackend(int configBits) {
131             return getBackend_MPtr.invoke(configBits);
132         }
133 
134         private CompilationUnitBridge compilationUnit(long handle, String source) {
135             return compilationUnits.computeIfAbsent(handle, _ ->
136                     new CompilationUnitBridge(this, handle, source)
137             );
138         }
139 
140         public CompilationUnitBridge compile(String source) {
141             var compilationUnitHandle = compile_MPtr.invoke(handle, source.length(), Arena.global().allocateFrom(source));
142             return compilationUnit(compilationUnitHandle, source);
143         }
144 
145         public MappableIface getBufferFromDeviceIfDirty(MappableIface buffer) {
146             MemorySegment memorySegment = MappableIface.getMemorySegment(buffer);
147             if (!getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize())){
148                 throw new IllegalStateException("Failed to get buffer from backend");
149             }
150             return buffer;
151 
152         }
153         public void computeStart() {
154             computeStart_MPtr.invoke(handle);
155         }
156         public void computeEnd() {
157             computeEnd_MPtr.invoke(handle);
158         }
159         public void showDeviceInfo() {
160             showDeviceInfo_MPtr.invoke(handle);
161         }
162     }
163 
164     public final FFILib ffiLib;
165     public final BackendBridge backendBridge;
166 
167     public FFIBackendDriver(Arena arena, MethodHandles.Lookup lookup,String libName, Config config) {
168         super(arena,lookup,config);
169         this.ffiLib = new FFILib(libName);
170         this.backendBridge = new BackendBridge(ffiLib, config);
171     }
172 
173 }