1 /*
  2  * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.internal.foreign.abi.fallback;
 26 
 27 import jdk.internal.foreign.AbstractMemorySegmentImpl;
 28 import jdk.internal.foreign.MemorySessionImpl;
 29 import jdk.internal.foreign.abi.AbstractLinker;
 30 import jdk.internal.foreign.abi.CapturableState;
 31 import jdk.internal.foreign.abi.LinkerOptions;
 32 import jdk.internal.foreign.abi.SharedUtils;
 33 
 34 import java.lang.foreign.AddressLayout;
 35 import java.lang.foreign.Arena;
 36 import java.lang.foreign.FunctionDescriptor;
 37 import java.lang.foreign.GroupLayout;
 38 import java.lang.foreign.MemoryLayout;
 39 import java.lang.foreign.MemorySegment;
 40 import java.lang.foreign.SegmentAllocator;
 41 import java.lang.foreign.ValueLayout;
 42 import java.lang.invoke.MethodHandle;
 43 import java.lang.invoke.MethodHandles;
 44 import java.lang.invoke.MethodType;
 45 import java.lang.ref.Reference;
 46 import java.nio.ByteOrder;
 47 import java.util.ArrayList;
 48 import java.util.HashMap;
 49 import java.util.List;
 50 import java.util.Map;
 51 import java.util.function.Consumer;
 52 
 53 import static java.lang.foreign.ValueLayout.ADDRESS;
 54 import static java.lang.foreign.ValueLayout.JAVA_BOOLEAN;
 55 import static java.lang.foreign.ValueLayout.JAVA_BYTE;
 56 import static java.lang.foreign.ValueLayout.JAVA_CHAR;
 57 import static java.lang.foreign.ValueLayout.JAVA_DOUBLE;
 58 import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
 59 import static java.lang.foreign.ValueLayout.JAVA_INT;
 60 import static java.lang.foreign.ValueLayout.JAVA_LONG;
 61 import static java.lang.foreign.ValueLayout.JAVA_SHORT;
 62 import static java.lang.invoke.MethodHandles.foldArguments;
 63 
 64 public final class FallbackLinker extends AbstractLinker {
 65 
 66     private static final MethodHandle MH_DO_DOWNCALL;
 67     private static final MethodHandle MH_DO_UPCALL;
 68 
 69     static {
 70         try {
 71             MH_DO_DOWNCALL = MethodHandles.lookup().findStatic(FallbackLinker.class, "doDowncall",
 72                     MethodType.methodType(Object.class, SegmentAllocator.class, Object[].class, FallbackLinker.DowncallData.class));
 73             MH_DO_UPCALL = MethodHandles.lookup().findStatic(FallbackLinker.class, "doUpcall",
 74                     MethodType.methodType(void.class, MethodHandle.class, MemorySegment.class, MemorySegment.class, UpcallData.class));
 75         } catch (ReflectiveOperationException e) {
 76             throw new ExceptionInInitializerError(e);
 77         }
 78     }
 79 
 80     public static FallbackLinker getInstance() {
 81         class Holder {
 82             static final FallbackLinker INSTANCE = new FallbackLinker();
 83         }
 84         return Holder.INSTANCE;
 85     }
 86 
    /**
     * Reports whether the fallback linker can be used, as indicated by {@code LibFallback.SUPPORTED}.
     *
     * @return the value of {@code LibFallback.SUPPORTED}
     */
    public static boolean isSupported() {
        return LibFallback.SUPPORTED;
    }
 90 
 91     @Override
 92     protected MethodHandle arrangeDowncall(MethodType inferredMethodType, FunctionDescriptor function, LinkerOptions options) {
 93         MemorySegment cif = makeCif(inferredMethodType, function, options, Arena.ofAuto());
 94 
 95         int capturedStateMask = options.capturedCallState()
 96                 .mapToInt(CapturableState::mask)
 97                 .reduce(0, (a, b) -> a | b);
 98         DowncallData invData = new DowncallData(cif, function.returnLayout().orElse(null),
 99                 function.argumentLayouts(), capturedStateMask);
100 
101         MethodHandle target = MethodHandles.insertArguments(MH_DO_DOWNCALL, 2, invData);
102 
103         int leadingArguments = 1; // address
104         MethodType type = inferredMethodType.insertParameterTypes(0, SegmentAllocator.class, MemorySegment.class);
105         if (capturedStateMask != 0) {
106             leadingArguments++;
107             type = type.insertParameterTypes(2, MemorySegment.class);
108         }
109         target = target.asCollector(1, Object[].class, inferredMethodType.parameterCount() + leadingArguments);
110         target = target.asType(type);
111         target = foldArguments(target, 1, SharedUtils.MH_CHECK_SYMBOL);
112         target = SharedUtils.swapArguments(target, 0, 1); // normalize parameter order
113 
114         return target;
115     }
116 
117     @Override
118     protected UpcallStubFactory arrangeUpcall(MethodType targetType, FunctionDescriptor function, LinkerOptions options) {
119         MemorySegment cif = makeCif(targetType, function, options, Arena.ofAuto());
120 
121         UpcallData invData = new UpcallData(function.returnLayout().orElse(null), function.argumentLayouts(), cif);
122         MethodHandle doUpcallMH = MethodHandles.insertArguments(MH_DO_UPCALL, 3, invData);
123 
124         return (target, scope) -> {
125             target = MethodHandles.insertArguments(doUpcallMH, 0, target);
126             return LibFallback.createClosure(cif, target, scope);
127         };
128     }
129 
    /**
     * Returns the byte order this linker works with, which is the native byte order of the
     * platform the JVM is running on.
     *
     * @return {@link ByteOrder#nativeOrder()}
     */
    @Override
    protected ByteOrder linkerByteOrder() {
        return ByteOrder.nativeOrder();
    }
134 
135     private static MemorySegment makeCif(MethodType methodType, FunctionDescriptor function, LinkerOptions options, Arena scope) {
136         FFIABI abi = FFIABI.DEFAULT;
137 
138         MemorySegment argTypes = scope.allocate(function.argumentLayouts().size() * ADDRESS.byteSize());
139         List<MemoryLayout> argLayouts = function.argumentLayouts();
140         for (int i = 0; i < argLayouts.size(); i++) {
141             MemoryLayout layout = argLayouts.get(i);
142             argTypes.setAtIndex(ADDRESS, i, FFIType.toFFIType(layout, abi, scope));
143         }
144 
145         MemorySegment returnType = methodType.returnType() != void.class
146                 ? FFIType.toFFIType(function.returnLayout().orElseThrow(), abi, scope)
147                 : LibFallback.voidType();
148 
149         if (options.isVariadicFunction()) {
150             int numFixedArgs = options.firstVariadicArgIndex();
151             int numTotalArgs = argLayouts.size();
152             return LibFallback.prepCifVar(returnType, numFixedArgs, numTotalArgs, argTypes, abi, scope);
153         } else {
154             return LibFallback.prepCif(returnType, argLayouts.size(), argTypes, abi, scope);
155         }
156     }
157 
    /**
     * State bound into a downcall method handle by {@code arrangeDowncall} and read back by
     * {@code doDowncall} on every invocation.
     *
     * @param cif               the call interface segment produced by {@code makeCif}
     * @param returnLayout      the layout of the return value, or {@code null} if the function
     *                          descriptor has no return layout
     * @param argLayouts        the argument layouts as reported by the function descriptor
     * @param capturedStateMask the bitwise OR of the masks of all call states to capture;
     *                          {@code 0} if no state is captured
     */
    private record DowncallData(MemorySegment cif, MemoryLayout returnLayout, List<MemoryLayout> argLayouts,
                                int capturedStateMask) {}
160 
161     private static Object doDowncall(SegmentAllocator returnAllocator, Object[] args, DowncallData invData) {
162         List<MemorySessionImpl> acquiredSessions = new ArrayList<>();
163         try (Arena arena = Arena.ofConfined()) {
164             int argStart = 0;
165 
166             MemorySegment target = (MemorySegment) args[argStart++];
167             MemorySessionImpl targetImpl = ((AbstractMemorySegmentImpl) target).sessionImpl();
168             targetImpl.acquire0();
169             acquiredSessions.add(targetImpl);
170 
171             MemorySegment capturedState = null;
172             if (invData.capturedStateMask() != 0) {
173                 capturedState = SharedUtils.checkCaptureSegment((MemorySegment) args[argStart++]);
174                 MemorySessionImpl capturedStateImpl = ((AbstractMemorySegmentImpl) capturedState).sessionImpl();
175                 capturedStateImpl.acquire0();
176                 acquiredSessions.add(capturedStateImpl);
177             }
178 
179             List<MemoryLayout> argLayouts = invData.argLayouts();
180             MemorySegment argPtrs = arena.allocate(argLayouts.size() * ADDRESS.byteSize());
181             for (int i = 0; i < argLayouts.size(); i++) {
182                 Object arg = args[argStart + i];
183                 MemoryLayout layout = argLayouts.get(i);
184                 MemorySegment argSeg = arena.allocate(layout);
185                 writeValue(arg, layout, argSeg, addr -> {
186                     MemorySessionImpl sessionImpl = ((AbstractMemorySegmentImpl) addr).sessionImpl();
187                     sessionImpl.acquire0();
188                     acquiredSessions.add(sessionImpl);
189                 });
190                 argPtrs.setAtIndex(ADDRESS, i, argSeg);
191             }
192 
193             MemorySegment retSeg = null;
194             if (invData.returnLayout() != null) {
195                 retSeg = (invData.returnLayout() instanceof GroupLayout ? returnAllocator : arena).allocate(invData.returnLayout);
196             }
197 
198             LibFallback.doDowncall(invData.cif, target, retSeg, argPtrs, capturedState, invData.capturedStateMask());
199 
200             Reference.reachabilityFence(invData.cif());
201 
202             return readValue(retSeg, invData.returnLayout());
203         } finally {
204             for (MemorySessionImpl session : acquiredSessions) {
205                 session.release0();
206             }
207         }
208     }
209 
    /**
     * State bound into an upcall method handle by {@code arrangeUpcall} and read back by
     * {@code doUpcall} on every invocation.
     *
     * @param returnLayout the layout of the return value, or {@code null} if the function
     *                     descriptor has no return layout
     * @param argLayouts   the argument layouts as reported by the function descriptor
     * @param cif          the call interface segment; never read through this record, it is
     *                     stored here only to keep it alive
     */
    // note that cif is not used, but we store it here to keep it alive
    private record UpcallData(MemoryLayout returnLayout, List<MemoryLayout> argLayouts, MemorySegment cif) {}
212 
213     private static void doUpcall(MethodHandle target, MemorySegment retPtr, MemorySegment argPtrs, UpcallData data) throws Throwable {
214         List<MemoryLayout> argLayouts = data.argLayouts();
215         int numArgs = argLayouts.size();
216         MemoryLayout retLayout = data.returnLayout();
217         try (Arena upcallArena = Arena.ofConfined()) {
218             MemorySegment argsSeg = argPtrs.reinterpret(numArgs * ADDRESS.byteSize(), upcallArena, null);
219             MemorySegment retSeg = retLayout != null
220                 ? retPtr.reinterpret(retLayout.byteSize(), upcallArena, null)
221                 : null;
222 
223             Object[] args = new Object[numArgs];
224             for (int i = 0; i < numArgs; i++) {
225                 MemoryLayout argLayout = argLayouts.get(i);
226                 MemorySegment argPtr = argsSeg.getAtIndex(ADDRESS, i)
227                         .reinterpret(argLayout.byteSize(), upcallArena, null);
228                 args[i] = readValue(argPtr, argLayout);
229             }
230 
231             Object result = target.invokeWithArguments(args);
232 
233             writeValue(result, data.returnLayout(), retSeg);
234         }
235     }
236 
237     // where
238     private static void writeValue(Object arg, MemoryLayout layout, MemorySegment argSeg) {
239         writeValue(arg, layout, argSeg, addr -> {});
240     }
241 
242     private static void writeValue(Object arg, MemoryLayout layout, MemorySegment argSeg,
243                                    Consumer<MemorySegment> acquireCallback) {
244         switch (layout) {
245             case ValueLayout.OfBoolean bl -> argSeg.set(bl, 0, (Boolean) arg);
246             case ValueLayout.OfByte    bl -> argSeg.set(bl, 0, (Byte) arg);
247             case ValueLayout.OfShort   sl -> argSeg.set(sl, 0, (Short) arg);
248             case ValueLayout.OfChar    cl -> argSeg.set(cl, 0, (Character) arg);
249             case ValueLayout.OfInt     il -> argSeg.set(il, 0, (Integer) arg);
250             case ValueLayout.OfLong    ll -> argSeg.set(ll, 0, (Long) arg);
251             case ValueLayout.OfFloat   fl -> argSeg.set(fl, 0, (Float) arg);
252             case ValueLayout.OfDouble  dl -> argSeg.set(dl, 0, (Double) arg);
253             case AddressLayout         al -> {
254                 MemorySegment addrArg = (MemorySegment) arg;
255                 acquireCallback.accept(addrArg);
256                 argSeg.set(al, 0, addrArg);
257             }
258             case GroupLayout           __ ->
259                     MemorySegment.copy((MemorySegment) arg, 0, argSeg, 0, argSeg.byteSize()); // by-value struct
260             case null, default -> {
261                 assert layout == null;
262             }
263         }
264     }
265 
266     private static Object readValue(MemorySegment seg, MemoryLayout layout) {
267         if (layout instanceof ValueLayout.OfBoolean bl) {
268             return seg.get(bl, 0);
269         } else if (layout instanceof ValueLayout.OfByte bl) {
270             return seg.get(bl, 0);
271         } else if (layout instanceof ValueLayout.OfShort sl) {
272             return seg.get(sl, 0);
273         } else if (layout instanceof ValueLayout.OfChar cl) {
274             return seg.get(cl, 0);
275         } else if (layout instanceof ValueLayout.OfInt il) {
276             return seg.get(il, 0);
277         } else if (layout instanceof ValueLayout.OfLong ll) {
278             return seg.get(ll, 0);
279         } else if (layout instanceof ValueLayout.OfFloat fl) {
280             return seg.get(fl, 0);
281         } else if (layout instanceof ValueLayout.OfDouble dl) {
282             return seg.get(dl, 0);
283         } else if (layout instanceof AddressLayout al) {
284             return seg.get(al, 0);
285         } else if (layout instanceof GroupLayout) {
286             return seg;
287         }
288         assert layout == null;
289         return null;
290     }
291 
292     @Override
293     public Map<String, MemoryLayout> canonicalLayouts() {
294         return CANONICAL_LAYOUTS;
295     }
296 
297     static final Map<String, MemoryLayout> CANONICAL_LAYOUTS = new HashMap<>();
298 
299     static {
300         CANONICAL_LAYOUTS.put("bool", JAVA_BOOLEAN);
301         CANONICAL_LAYOUTS.put("char", JAVA_BYTE);
302         CANONICAL_LAYOUTS.put("float", JAVA_FLOAT);
303         CANONICAL_LAYOUTS.put("double", JAVA_DOUBLE);
304         CANONICAL_LAYOUTS.put("long long", JAVA_LONG);
305         CANONICAL_LAYOUTS.put("void*", ADDRESS);
306         // platform-dependent sizes
307         CANONICAL_LAYOUTS.put("size_t", FFIType.SIZE_T);
308         CANONICAL_LAYOUTS.put("short", FFIType.layoutFor(LibFallback.shortSize()));
309         CANONICAL_LAYOUTS.put("int", FFIType.layoutFor(LibFallback.intSize()));
310         CANONICAL_LAYOUTS.put("long", FFIType.layoutFor(LibFallback.longSize()));
311         int wchar_size = LibFallback.wcharSize();
312         if (wchar_size == 2) {
313             // prefer JAVA_CHAR
314             CANONICAL_LAYOUTS.put("wchar_t", JAVA_CHAR);
315         } else {
316             CANONICAL_LAYOUTS.put("wchar_t", FFIType.layoutFor(wchar_size));
317         }
318         // JNI types
319         CANONICAL_LAYOUTS.put("jboolean", JAVA_BOOLEAN);
320         CANONICAL_LAYOUTS.put("jchar", JAVA_CHAR);
321         CANONICAL_LAYOUTS.put("jbyte", JAVA_BYTE);
322         CANONICAL_LAYOUTS.put("jshort", JAVA_SHORT);
323         CANONICAL_LAYOUTS.put("jint", JAVA_INT);
324         CANONICAL_LAYOUTS.put("jlong", JAVA_LONG);
325         CANONICAL_LAYOUTS.put("jfloat", JAVA_FLOAT);
326         CANONICAL_LAYOUTS.put("jdouble", JAVA_DOUBLE);
327     }
328 }