1 /*
  2  * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.internal.foreign.abi.fallback;
 26 
 27 import jdk.internal.foreign.AbstractMemorySegmentImpl;
 28 import jdk.internal.foreign.MemorySessionImpl;
 29 import jdk.internal.foreign.abi.AbstractLinker;
 30 import jdk.internal.foreign.abi.CapturableState;
 31 import jdk.internal.foreign.abi.LinkerOptions;
 32 import jdk.internal.foreign.abi.SharedUtils;
 33 
 34 import java.lang.foreign.AddressLayout;
 35 import java.lang.foreign.Arena;
 36 import java.lang.foreign.FunctionDescriptor;
 37 import java.lang.foreign.GroupLayout;
 38 import java.lang.foreign.MemoryLayout;
 39 import java.lang.foreign.MemorySegment;
 40 import java.lang.foreign.SegmentAllocator;
 41 import java.lang.foreign.ValueLayout;
 42 import java.lang.invoke.MethodHandle;
 43 import java.lang.invoke.MethodHandles;
 44 import java.lang.invoke.MethodType;
 45 import java.lang.ref.Reference;
 46 import java.nio.ByteOrder;
 47 import java.util.ArrayList;
 48 import java.util.List;
 49 import java.util.function.Consumer;
 50 
 51 import static java.lang.foreign.ValueLayout.ADDRESS;
 52 import static java.lang.invoke.MethodHandles.foldArguments;
 53 
 54 public final class FallbackLinker extends AbstractLinker {
 55 
 56     private static final MethodHandle MH_DO_DOWNCALL;
 57     private static final MethodHandle MH_DO_UPCALL;
 58 
 59     static {
 60         try {
 61             MH_DO_DOWNCALL = MethodHandles.lookup().findStatic(FallbackLinker.class, "doDowncall",
 62                     MethodType.methodType(Object.class, SegmentAllocator.class, Object[].class, FallbackLinker.DowncallData.class));
 63             MH_DO_UPCALL = MethodHandles.lookup().findStatic(FallbackLinker.class, "doUpcall",
 64                     MethodType.methodType(void.class, MethodHandle.class, MemorySegment.class, MemorySegment.class, UpcallData.class));
 65         } catch (ReflectiveOperationException e) {
 66             throw new ExceptionInInitializerError(e);
 67         }
 68     }
 69 
 70     public static FallbackLinker getInstance() {
 71         class Holder {
 72             static final FallbackLinker INSTANCE = new FallbackLinker();
 73         }
 74         return Holder.INSTANCE;
 75     }
 76 
 77     public static boolean isSupported() {
 78         return LibFallback.SUPPORTED;
 79     }
 80 
 81     @Override
 82     protected MethodHandle arrangeDowncall(MethodType inferredMethodType, FunctionDescriptor function, LinkerOptions options) {
 83         MemorySegment cif = makeCif(inferredMethodType, function, options, Arena.ofAuto());
 84 
 85         int capturedStateMask = options.capturedCallState()
 86                 .mapToInt(CapturableState::mask)
 87                 .reduce(0, (a, b) -> a | b);
 88         DowncallData invData = new DowncallData(cif, function.returnLayout().orElse(null),
 89                 function.argumentLayouts(), capturedStateMask);
 90 
 91         MethodHandle target = MethodHandles.insertArguments(MH_DO_DOWNCALL, 2, invData);
 92 
 93         int leadingArguments = 1; // address
 94         MethodType type = inferredMethodType.insertParameterTypes(0, SegmentAllocator.class, MemorySegment.class);
 95         if (capturedStateMask != 0) {
 96             leadingArguments++;
 97             type = type.insertParameterTypes(2, MemorySegment.class);
 98         }
 99         target = target.asCollector(1, Object[].class, inferredMethodType.parameterCount() + leadingArguments);
100         target = target.asType(type);
101         target = foldArguments(target, 1, SharedUtils.MH_CHECK_SYMBOL);
102         target = SharedUtils.swapArguments(target, 0, 1); // normalize parameter order
103 
104         return target;
105     }
106 
107     @Override
108     protected UpcallStubFactory arrangeUpcall(MethodType targetType, FunctionDescriptor function, LinkerOptions options) {
109         MemorySegment cif = makeCif(targetType, function, options, Arena.ofAuto());
110 
111         UpcallData invData = new UpcallData(function.returnLayout().orElse(null), function.argumentLayouts(), cif);
112         MethodHandle doUpcallMH = MethodHandles.insertArguments(MH_DO_UPCALL, 3, invData);
113 
114         return (target, scope) -> {
115             target = MethodHandles.insertArguments(doUpcallMH, 0, target);
116             return LibFallback.createClosure(cif, target, scope);
117         };
118     }
119 
120     @Override
121     protected ByteOrder linkerByteOrder() {
122         return ByteOrder.nativeOrder();
123     }
124 
125     private static MemorySegment makeCif(MethodType methodType, FunctionDescriptor function, LinkerOptions options, Arena scope) {
126         FFIABI abi = FFIABI.DEFAULT;
127 
128         MemorySegment argTypes = scope.allocate(function.argumentLayouts().size() * ADDRESS.byteSize());
129         List<MemoryLayout> argLayouts = function.argumentLayouts();
130         for (int i = 0; i < argLayouts.size(); i++) {
131             MemoryLayout layout = argLayouts.get(i);
132             argTypes.setAtIndex(ADDRESS, i, FFIType.toFFIType(layout, abi, scope));
133         }
134 
135         MemorySegment returnType = methodType.returnType() != void.class
136                 ? FFIType.toFFIType(function.returnLayout().orElseThrow(), abi, scope)
137                 : LibFallback.voidType();
138 
139         if (options.isVariadicFunction()) {
140             int numFixedArgs = options.firstVariadicArgIndex();
141             int numTotalArgs = argLayouts.size();
142             return LibFallback.prepCifVar(returnType, numFixedArgs, numTotalArgs, argTypes, abi, scope);
143         } else {
144             return LibFallback.prepCif(returnType, argLayouts.size(), argTypes, abi, scope);
145         }
146     }
147 
148     private record DowncallData(MemorySegment cif, MemoryLayout returnLayout, List<MemoryLayout> argLayouts,
149                                 int capturedStateMask) {}
150 
151     private static Object doDowncall(SegmentAllocator returnAllocator, Object[] args, DowncallData invData) {
152         List<MemorySessionImpl> acquiredSessions = new ArrayList<>();
153         try (Arena arena = Arena.ofConfined()) {
154             int argStart = 0;
155 
156             MemorySegment target = (MemorySegment) args[argStart++];
157             MemorySessionImpl targetImpl = ((AbstractMemorySegmentImpl) target).sessionImpl();
158             targetImpl.acquire0();
159             acquiredSessions.add(targetImpl);
160 
161             MemorySegment capturedState = null;
162             if (invData.capturedStateMask() != 0) {
163                 capturedState = SharedUtils.checkCaptureSegment((MemorySegment) args[argStart++]);
164                 MemorySessionImpl capturedStateImpl = ((AbstractMemorySegmentImpl) capturedState).sessionImpl();
165                 capturedStateImpl.acquire0();
166                 acquiredSessions.add(capturedStateImpl);
167             }
168 
169             List<MemoryLayout> argLayouts = invData.argLayouts();
170             MemorySegment argPtrs = arena.allocate(argLayouts.size() * ADDRESS.byteSize());
171             for (int i = 0; i < argLayouts.size(); i++) {
172                 Object arg = args[argStart + i];
173                 MemoryLayout layout = argLayouts.get(i);
174                 MemorySegment argSeg = arena.allocate(layout);
175                 writeValue(arg, layout, argSeg, addr -> {
176                     MemorySessionImpl sessionImpl = ((AbstractMemorySegmentImpl) addr).sessionImpl();
177                     sessionImpl.acquire0();
178                     acquiredSessions.add(sessionImpl);
179                 });
180                 argPtrs.setAtIndex(ADDRESS, i, argSeg);
181             }
182 
183             MemorySegment retSeg = null;
184             if (invData.returnLayout() != null) {
185                 retSeg = (invData.returnLayout() instanceof GroupLayout ? returnAllocator : arena).allocate(invData.returnLayout);
186             }
187 
188             LibFallback.doDowncall(invData.cif, target, retSeg, argPtrs, capturedState, invData.capturedStateMask());
189 
190             Reference.reachabilityFence(invData.cif());
191 
192             return readValue(retSeg, invData.returnLayout());
193         } finally {
194             for (MemorySessionImpl session : acquiredSessions) {
195                 session.release0();
196             }
197         }
198     }
199 
200     // note that cif is not used, but we store it here to keep it alive
201     private record UpcallData(MemoryLayout returnLayout, List<MemoryLayout> argLayouts, MemorySegment cif) {}
202 
203     private static void doUpcall(MethodHandle target, MemorySegment retPtr, MemorySegment argPtrs, UpcallData data) throws Throwable {
204         List<MemoryLayout> argLayouts = data.argLayouts();
205         int numArgs = argLayouts.size();
206         MemoryLayout retLayout = data.returnLayout();
207         try (Arena upcallArena = Arena.ofConfined()) {
208             MemorySegment argsSeg = argPtrs.reinterpret(numArgs * ADDRESS.byteSize(), upcallArena, null);
209             MemorySegment retSeg = retLayout != null
210                 ? retPtr.reinterpret(retLayout.byteSize(), upcallArena, null)
211                 : null;
212 
213             Object[] args = new Object[numArgs];
214             for (int i = 0; i < numArgs; i++) {
215                 MemoryLayout argLayout = argLayouts.get(i);
216                 MemorySegment argPtr = argsSeg.getAtIndex(ADDRESS, i)
217                         .reinterpret(argLayout.byteSize(), upcallArena, null);
218                 args[i] = readValue(argPtr, argLayout);
219             }
220 
221             Object result = target.invokeWithArguments(args);
222 
223             writeValue(result, data.returnLayout(), retSeg);
224         }
225     }
226 
227     // where
228     private static void writeValue(Object arg, MemoryLayout layout, MemorySegment argSeg) {
229         writeValue(arg, layout, argSeg, addr -> {});
230     }
231 
232     private static void writeValue(Object arg, MemoryLayout layout, MemorySegment argSeg,
233                                    Consumer<MemorySegment> acquireCallback) {
234         if (layout instanceof ValueLayout.OfBoolean bl) {
235             argSeg.set(bl, 0, (Boolean) arg);
236         } else if (layout instanceof ValueLayout.OfByte bl) {
237             argSeg.set(bl, 0, (Byte) arg);
238         } else if (layout instanceof ValueLayout.OfShort sl) {
239             argSeg.set(sl, 0, (Short) arg);
240         } else if (layout instanceof ValueLayout.OfChar cl) {
241             argSeg.set(cl, 0, (Character) arg);
242         } else if (layout instanceof ValueLayout.OfInt il) {
243             argSeg.set(il, 0, (Integer) arg);
244         } else if (layout instanceof ValueLayout.OfLong ll) {
245             argSeg.set(ll, 0, (Long) arg);
246         } else if (layout instanceof ValueLayout.OfFloat fl) {
247             argSeg.set(fl, 0, (Float) arg);
248         } else if (layout instanceof ValueLayout.OfDouble dl) {
249             argSeg.set(dl, 0, (Double) arg);
250         } else if (layout instanceof AddressLayout al) {
251             MemorySegment addrArg = (MemorySegment) arg;
252             acquireCallback.accept(addrArg);
253             argSeg.set(al, 0, addrArg);
254         } else if (layout instanceof GroupLayout) {
255             MemorySegment.copy((MemorySegment) arg, 0, argSeg, 0, argSeg.byteSize()); // by-value struct
256         } else {
257             assert layout == null;
258         }
259     }
260 
261     private static Object readValue(MemorySegment seg, MemoryLayout layout) {
262         if (layout instanceof ValueLayout.OfBoolean bl) {
263             return seg.get(bl, 0);
264         } else if (layout instanceof ValueLayout.OfByte bl) {
265             return seg.get(bl, 0);
266         } else if (layout instanceof ValueLayout.OfShort sl) {
267             return seg.get(sl, 0);
268         } else if (layout instanceof ValueLayout.OfChar cl) {
269             return seg.get(cl, 0);
270         } else if (layout instanceof ValueLayout.OfInt il) {
271             return seg.get(il, 0);
272         } else if (layout instanceof ValueLayout.OfLong ll) {
273             return seg.get(ll, 0);
274         } else if (layout instanceof ValueLayout.OfFloat fl) {
275             return seg.get(fl, 0);
276         } else if (layout instanceof ValueLayout.OfDouble dl) {
277             return seg.get(dl, 0);
278         } else if (layout instanceof AddressLayout al) {
279             return seg.get(al, 0);
280         } else if (layout instanceof GroupLayout) {
281             return seg;
282         }
283         assert layout == null;
284         return null;
285     }
286 }