1 /*
  2  * Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.internal.foreign.abi;
 26 
 27 import jdk.incubator.foreign.Addressable;
 28 import jdk.incubator.foreign.MemoryAddress;
 29 import jdk.incubator.foreign.MemoryLayouts;
 30 import jdk.incubator.foreign.MemorySegment;
 31 import jdk.incubator.foreign.ResourceScope;
 32 import jdk.incubator.foreign.SegmentAllocator;
 33 import jdk.internal.access.JavaLangInvokeAccess;
 34 import jdk.internal.access.SharedSecrets;
 35 import jdk.internal.invoke.NativeEntryPoint;
 36 import jdk.internal.invoke.VMStorageProxy;
 37 import sun.security.action.GetPropertyAction;
 38 
 39 import java.lang.invoke.MethodHandle;
 40 import java.lang.invoke.MethodHandles;
 41 import java.lang.invoke.MethodType;
 42 import java.lang.invoke.VarHandle;
 43 import java.lang.ref.Reference;
 44 import java.util.Arrays;
 45 import java.util.List;
 46 import java.util.Map;
 47 import java.util.concurrent.ConcurrentHashMap;
 48 import java.util.stream.Stream;
 49 
 50 import static java.lang.invoke.MethodHandles.collectArguments;
 51 import static java.lang.invoke.MethodHandles.dropArguments;
 52 import static java.lang.invoke.MethodHandles.filterArguments;
 53 import static java.lang.invoke.MethodHandles.identity;
 54 import static java.lang.invoke.MethodHandles.insertArguments;
 55 import static java.lang.invoke.MethodType.methodType;
 56 import static sun.security.action.GetBooleanAction.privilegedGetProperty;
 57 
 58 /**
 59  * This class implements native call invocation through a so called 'universal adapter'. A universal adapter takes
 60  * an array of longs together with a call 'recipe', which is used to move the arguments in the right places as
 61  * expected by the system ABI.
 62  */
 63 public class ProgrammableInvoker {
 64     private static final boolean DEBUG =
 65         privilegedGetProperty("jdk.internal.foreign.ProgrammableInvoker.DEBUG");
 66     private static final boolean USE_SPEC = Boolean.parseBoolean(
 67         GetPropertyAction.privilegedGetProperty("jdk.internal.foreign.ProgrammableInvoker.USE_SPEC", "true"));
 68     private static final boolean USE_INTRINSICS = Boolean.parseBoolean(
 69         GetPropertyAction.privilegedGetProperty("jdk.internal.foreign.ProgrammableInvoker.USE_INTRINSICS", "true"));
 70 
 71     private static final JavaLangInvokeAccess JLIA = SharedSecrets.getJavaLangInvokeAccess();
 72 
 73     private static final VarHandle VH_LONG = MemoryLayouts.JAVA_LONG.varHandle(long.class);
 74 
 75     private static final MethodHandle MH_INVOKE_MOVES;
 76     private static final MethodHandle MH_INVOKE_INTERP_BINDINGS;
 77     private static final MethodHandle MH_ADDR_TO_LONG;
 78     private static final MethodHandle MH_WRAP_ALLOCATOR;
 79 
 80     private static final Map<ABIDescriptor, Long> adapterStubs = new ConcurrentHashMap<>();
 81 
 82     private static final MethodHandle EMPTY_OBJECT_ARRAY_HANDLE = MethodHandles.constant(Object[].class, new Object[0]);
 83 
 84     static {
 85         try {
 86             MethodHandles.Lookup lookup = MethodHandles.lookup();
 87             MH_INVOKE_MOVES = lookup.findVirtual(ProgrammableInvoker.class, "invokeMoves",
 88                     methodType(Object.class, long.class, Object[].class, Binding.VMStore[].class, Binding.VMLoad[].class));
 89             MH_INVOKE_INTERP_BINDINGS = lookup.findVirtual(ProgrammableInvoker.class, "invokeInterpBindings",
 90                     methodType(Object.class, Addressable.class, SegmentAllocator.class, Object[].class, MethodHandle.class, Map.class, Map.class));
 91             MH_WRAP_ALLOCATOR = lookup.findStatic(Binding.Context.class, "ofAllocator",
 92                     methodType(Binding.Context.class, SegmentAllocator.class));
 93             MH_ADDR_TO_LONG = lookup.findStatic(ProgrammableInvoker.class, "unboxTargetAddress", methodType(long.class, Addressable.class));
 94         } catch (ReflectiveOperationException e) {
 95             throw new RuntimeException(e);
 96         }
 97     }
 98 
 99     private final ABIDescriptor abi;
100     private final BufferLayout layout;
101     private final long stackArgsBytes;
102 
103     private final CallingSequence callingSequence;
104 
105     private final long stubAddress;
106 
107     private final long bufferCopySize;
108 
109     public ProgrammableInvoker(ABIDescriptor abi, CallingSequence callingSequence) {
110         this.abi = abi;
111         this.layout = BufferLayout.of(abi);
112         this.stubAddress = adapterStubs.computeIfAbsent(abi, key -> generateAdapter(key, layout));
113 
114         this.callingSequence = callingSequence;
115 
116         this.stackArgsBytes = argMoveBindingsStream(callingSequence)
117                 .map(Binding.VMStore::storage)
118                 .filter(s -> abi.arch.isStackType(s.type()))
119                 .count()
120                 * abi.arch.typeSize(abi.arch.stackType());
121 
122         this.bufferCopySize = SharedUtils.bufferCopySize(callingSequence);
123     }
124 
125     public MethodHandle getBoundMethodHandle() {
126         Binding.VMStore[] argMoves = argMoveBindingsStream(callingSequence).toArray(Binding.VMStore[]::new);
127         Class<?>[] argMoveTypes = Arrays.stream(argMoves).map(Binding.VMStore::type).toArray(Class<?>[]::new);
128 
129         Binding.VMLoad[] retMoves = retMoveBindings(callingSequence);
130         Class<?> returnType = retMoves.length == 0
131                 ? void.class
132                 : retMoves.length == 1
133                     ? retMoves[0].type()
134                     : Object[].class;
135 
136         MethodType leafType = methodType(returnType, argMoveTypes);
137         MethodType leafTypeWithAddress = leafType.insertParameterTypes(0, long.class);
138 
139         MethodHandle handle = insertArguments(MH_INVOKE_MOVES.bindTo(this), 2, argMoves, retMoves);
140         MethodHandle collector = makeCollectorHandle(leafType);
141         handle = collectArguments(handle, 1, collector);
142         handle = handle.asType(leafTypeWithAddress);
143 
144         boolean isSimple = !(retMoves.length > 1);
145         boolean usesStackArgs = stackArgsBytes != 0;
146         if (USE_INTRINSICS && isSimple && !usesStackArgs) {
147             NativeEntryPoint nep = NativeEntryPoint.make(
148                 "native_call",
149                 abi,
150                 toStorageArray(argMoves),
151                 toStorageArray(retMoves),
152                 !callingSequence.isTrivial(),
153                 leafTypeWithAddress
154             );
155 
156             handle = JLIA.nativeMethodHandle(nep, handle);
157         }
158         handle = filterArguments(handle, 0, MH_ADDR_TO_LONG);
159 
160         if (USE_SPEC && isSimple) {
161             handle = specialize(handle);
162          } else {
163             Map<VMStorage, Integer> argIndexMap = SharedUtils.indexMap(argMoves);
164             Map<VMStorage, Integer> retIndexMap = SharedUtils.indexMap(retMoves);
165 
166             handle = insertArguments(MH_INVOKE_INTERP_BINDINGS.bindTo(this), 3, handle, argIndexMap, retIndexMap);
167             MethodHandle collectorInterp = makeCollectorHandle(callingSequence.methodType());
168             handle = collectArguments(handle, 2, collectorInterp);
169             handle = handle.asType(handle.type().changeReturnType(callingSequence.methodType().returnType()));
170          }
171 
172         return handle;
173     }
174 
175     private static long unboxTargetAddress(Addressable addr) {
176         MemoryAddress ma = SharedUtils.checkSymbol(addr);
177         return ma.toRawLongValue();
178     }
179 
180     // Funnel from type to Object[]
181     private static MethodHandle makeCollectorHandle(MethodType type) {
182         return type.parameterCount() == 0
183             ? EMPTY_OBJECT_ARRAY_HANDLE
184             : identity(Object[].class)
185                 .asCollector(Object[].class, type.parameterCount())
186                 .asType(type.changeReturnType(Object[].class));
187     }
188 
189     private Stream<Binding.VMStore> argMoveBindingsStream(CallingSequence callingSequence) {
190         return callingSequence.argumentBindings()
191                 .filter(Binding.VMStore.class::isInstance)
192                 .map(Binding.VMStore.class::cast);
193     }
194 
195     private Binding.VMLoad[] retMoveBindings(CallingSequence callingSequence) {
196         return callingSequence.returnBindings().stream()
197                 .filter(Binding.VMLoad.class::isInstance)
198                 .map(Binding.VMLoad.class::cast)
199                 .toArray(Binding.VMLoad[]::new);
200     }
201 
202 
203     private VMStorageProxy[] toStorageArray(Binding.Move[] moves) {
204         return Arrays.stream(moves).map(Binding.Move::storage).toArray(VMStorage[]::new);
205     }
206 
207     private MethodHandle specialize(MethodHandle leafHandle) {
208         MethodType highLevelType = callingSequence.methodType();
209 
210         int argInsertPos = 1;
211         int argContextPos = 1;
212 
213         MethodHandle specializedHandle = dropArguments(leafHandle, argContextPos, Binding.Context.class);
214 
215         for (int i = 0; i < highLevelType.parameterCount(); i++) {
216             List<Binding> bindings = callingSequence.argumentBindings(i);
217             argInsertPos += bindings.stream().filter(Binding.VMStore.class::isInstance).count() + 1;
218             // We interpret the bindings in reverse since we have to construct a MethodHandle from the bottom up
219             for (int j = bindings.size() - 1; j >= 0; j--) {
220                 Binding binding = bindings.get(j);
221                 if (binding.tag() == Binding.Tag.VM_STORE) {
222                     argInsertPos--;
223                 } else {
224                     specializedHandle = binding.specialize(specializedHandle, argInsertPos, argContextPos);
225                 }
226             }
227         }
228 
229         if (highLevelType.returnType() != void.class) {
230             MethodHandle returnFilter = identity(highLevelType.returnType());
231             int retContextPos = 0;
232             int retInsertPos = 1;
233             returnFilter = dropArguments(returnFilter, retContextPos, Binding.Context.class);
234             List<Binding> bindings = callingSequence.returnBindings();
235             for (int j = bindings.size() - 1; j >= 0; j--) {
236                 Binding binding = bindings.get(j);
237                 returnFilter = binding.specialize(returnFilter, retInsertPos, retContextPos);
238             }
239             returnFilter = MethodHandles.filterArguments(returnFilter, retContextPos, MH_WRAP_ALLOCATOR);
240             // (SegmentAllocator, Addressable, Context, ...) -> ...
241             specializedHandle = MethodHandles.collectArguments(returnFilter, retInsertPos, specializedHandle);
242             // (Addressable, SegmentAllocator, Context, ...) -> ...
243             specializedHandle = SharedUtils.swapArguments(specializedHandle, 0, 1); // normalize parameter order
244         } else {
245             specializedHandle = MethodHandles.dropArguments(specializedHandle, 1, SegmentAllocator.class);
246         }
247 
248         // now bind the internal context parameter
249 
250         argContextPos++; // skip over the return SegmentAllocator (inserted by the above code)
251         specializedHandle = SharedUtils.wrapWithAllocator(specializedHandle, argContextPos, bufferCopySize, false);
252         return specializedHandle;
253     }
254 
255     /**
256      * Does a native invocation by moving primitive values from the arg array into an intermediate buffer
257      * and calling the assembly stub that forwards arguments from the buffer to the target function
258      *
259      * @param args an array of primitive values to be copied in to the buffer
260      * @param argBindings Binding.Move values describing how arguments should be copied
261      * @param returnBindings Binding.Move values describing how return values should be copied
262      * @return null, a single primitive value, or an Object[] of primitive values
263      */
264     Object invokeMoves(long addr, Object[] args, Binding.VMStore[] argBindings, Binding.VMLoad[] returnBindings) {
265         MemorySegment stackArgsSeg = null;
266         try (ResourceScope scope = ResourceScope.newConfinedScope()) {
267             MemorySegment argBuffer = MemorySegment.allocateNative(layout.size, 64, scope);
268             if (stackArgsBytes > 0) {
269                 stackArgsSeg = MemorySegment.allocateNative(stackArgsBytes, 8, scope);
270             }
271 
272             VH_LONG.set(argBuffer.asSlice(layout.arguments_next_pc), addr);
273             VH_LONG.set(argBuffer.asSlice(layout.stack_args_bytes), stackArgsBytes);
274             VH_LONG.set(argBuffer.asSlice(layout.stack_args), stackArgsSeg == null ? 0L : stackArgsSeg.address().toRawLongValue());
275 
276             for (int i = 0; i < argBindings.length; i++) {
277                 Binding.VMStore binding = argBindings[i];
278                 VMStorage storage = binding.storage();
279                 MemorySegment ptr = abi.arch.isStackType(storage.type())
280                     ? stackArgsSeg.asSlice(storage.index() * abi.arch.typeSize(abi.arch.stackType()))
281                     : argBuffer.asSlice(layout.argOffset(storage));
282                 SharedUtils.writeOverSized(ptr, binding.type(), args[i]);
283             }
284 
285             if (DEBUG) {
286                 System.err.println("Buffer state before:");
287                 layout.dump(abi.arch, argBuffer, System.err);
288             }
289 
290             invokeNative(stubAddress, argBuffer.address().toRawLongValue());
291 
292             if (DEBUG) {
293                 System.err.println("Buffer state after:");
294                 layout.dump(abi.arch, argBuffer, System.err);
295             }
296 
297             if (returnBindings.length == 0) {
298                 return null;
299             } else if (returnBindings.length == 1) {
300                 Binding.VMLoad move = returnBindings[0];
301                 VMStorage storage = move.storage();
302                 return SharedUtils.read(argBuffer.asSlice(layout.retOffset(storage)), move.type());
303             } else { // length > 1
304                 Object[] returns = new Object[returnBindings.length];
305                 for (int i = 0; i < returnBindings.length; i++) {
306                     Binding.VMLoad move = returnBindings[i];
307                     VMStorage storage = move.storage();
308                     returns[i] = SharedUtils.read(argBuffer.asSlice(layout.retOffset(storage)), move.type());
309                 }
310                 return returns;
311             }
312         }
313     }
314 
315     Object invokeInterpBindings(Addressable address, SegmentAllocator allocator, Object[] args, MethodHandle leaf,
316                                 Map<VMStorage, Integer> argIndexMap,
317                                 Map<VMStorage, Integer> retIndexMap) throws Throwable {
318         Binding.Context unboxContext = bufferCopySize != 0
319                 ? Binding.Context.ofBoundedAllocator(bufferCopySize)
320                 : Binding.Context.DUMMY;
321         try (unboxContext) {
322             // do argument processing, get Object[] as result
323             Object[] leafArgs = new Object[leaf.type().parameterCount()];
324             leafArgs[0] = address; // addr
325             for (int i = 0; i < args.length; i++) {
326                 Object arg = args[i];
327                 BindingInterpreter.unbox(arg, callingSequence.argumentBindings(i),
328                         (storage, type, value) -> {
329                             leafArgs[argIndexMap.get(storage) + 1] = value; // +1 to skip addr
330                         }, unboxContext);
331             }
332 
333             // call leaf
334             Object o = leaf.invokeWithArguments(leafArgs);
335             // make sure arguments are reachable during the call
336             // technically we only need to do all Addressable parameters here
337             Reference.reachabilityFence(address);
338             Reference.reachabilityFence(args);
339 
340             // return value processing
341             if (o == null) {
342                 return null;
343             } else if (o instanceof Object[]) {
344                 Object[] oArr = (Object[]) o;
345                 return BindingInterpreter.box(callingSequence.returnBindings(),
346                         (storage, type) -> oArr[retIndexMap.get(storage)], Binding.Context.ofAllocator(allocator));
347             } else {
348                 return BindingInterpreter.box(callingSequence.returnBindings(), (storage, type) -> o,
349                         Binding.Context.ofAllocator(allocator));
350             }
351         }
352     }
353 
354     //natives
355 
356     static native void invokeNative(long adapterStub, long buff);
357     static native long generateAdapter(ABIDescriptor abi, BufferLayout layout);
358 
359     private static native void registerNatives();
360     static {
361         registerNatives();
362     }
363 }
364