1 /*
  2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package jdk.internal.foreign.abi;
 27 
 28 import jdk.incubator.foreign.MemoryAddress;
 29 import jdk.incubator.foreign.MemorySegment;
 30 import jdk.incubator.foreign.NativeSymbol;
 31 import jdk.incubator.foreign.ResourceScope;
 32 import jdk.incubator.foreign.ValueLayout;
 33 import jdk.internal.foreign.MemoryAddressImpl;
 34 import sun.security.action.GetPropertyAction;
 35 
 36 import java.lang.invoke.MethodHandle;
 37 import java.lang.invoke.MethodHandles;
 38 import java.lang.invoke.MethodType;
 39 import java.lang.invoke.VarHandle;
 40 import java.util.Arrays;
 41 import java.util.List;
 42 import java.util.Map;
 43 import java.util.Objects;
 44 import java.util.stream.Stream;
 45 
 46 import static java.lang.invoke.MethodHandles.dropArguments;
 47 import static java.lang.invoke.MethodHandles.exactInvoker;
 48 import static java.lang.invoke.MethodHandles.filterReturnValue;
 49 import static java.lang.invoke.MethodHandles.identity;
 50 import static java.lang.invoke.MethodHandles.insertArguments;
 51 import static java.lang.invoke.MethodHandles.lookup;
 52 import static java.lang.invoke.MethodType.methodType;
 53 import static jdk.internal.foreign.abi.SharedUtils.mergeArguments;
 54 import static sun.security.action.GetBooleanAction.privilegedGetProperty;
 55 
 56 /**
 57  * This class implements upcall invocation from native code through a so called 'universal adapter'. A universal upcall adapter
 58  * takes an array of storage pointers, which describes the state of the CPU at the time of the upcall. This can be used
 59  * by the Java code to fetch the upcall arguments and to store the results to the desired location, as per system ABI.
 60  */
 61 public class ProgrammableUpcallHandler {
 62     private static final boolean DEBUG =
 63         privilegedGetProperty("jdk.internal.foreign.ProgrammableUpcallHandler.DEBUG");
 64     private static final boolean USE_SPEC = Boolean.parseBoolean(
 65         GetPropertyAction.privilegedGetProperty("jdk.internal.foreign.ProgrammableUpcallHandler.USE_SPEC", "true"));
 66     private static final boolean USE_INTRINSICS = Boolean.parseBoolean(
 67         GetPropertyAction.privilegedGetProperty("jdk.internal.foreign.ProgrammableUpcallHandler.USE_INTRINSICS", "true"));
 68 
 69     private static final VarHandle VH_LONG = ValueLayout.JAVA_LONG.varHandle();
 70 
 71     private static final MethodHandle MH_invokeMoves;
 72     private static final MethodHandle MH_invokeInterpBindings;
 73 
 74     static {
 75         try {
 76             MethodHandles.Lookup lookup = lookup();
 77             MH_invokeMoves = lookup.findStatic(ProgrammableUpcallHandler.class, "invokeMoves",
 78                     methodType(void.class, MemoryAddress.class, MethodHandle.class,
 79                                Binding.VMLoad[].class, Binding.VMStore[].class, ABIDescriptor.class, BufferLayout.class));
 80             MH_invokeInterpBindings = lookup.findStatic(ProgrammableUpcallHandler.class, "invokeInterpBindings",
 81                     methodType(Object.class, Object[].class, MethodHandle.class, Map.class, Map.class,
 82                             CallingSequence.class, long.class));
 83         } catch (ReflectiveOperationException e) {
 84             throw new InternalError(e);
 85         }
 86     }
 87 
 88     public static NativeSymbol make(ABIDescriptor abi, MethodHandle target, CallingSequence callingSequence, ResourceScope scope) {
 89         Binding.VMLoad[] argMoves = argMoveBindings(callingSequence);
 90         Binding.VMStore[] retMoves = retMoveBindings(callingSequence);
 91 
 92         boolean isSimple = !(retMoves.length > 1);
 93 
 94         Class<?> llReturn = !isSimple
 95             ? Object[].class
 96             : retMoves.length == 1
 97                 ? retMoves[0].type()
 98                 : void.class;
 99         Class<?>[] llParams = Arrays.stream(argMoves).map(Binding.Move::type).toArray(Class<?>[]::new);
100         MethodType llType = MethodType.methodType(llReturn, llParams);
101 
102         MethodHandle doBindings;
103         long bufferCopySize = SharedUtils.bufferCopySize(callingSequence);
104         if (USE_SPEC && isSimple) {
105             doBindings = specializedBindingHandle(target, callingSequence, llReturn, bufferCopySize);
106             assert doBindings.type() == llType;
107         } else {
108             Map<VMStorage, Integer> argIndices = SharedUtils.indexMap(argMoves);
109             Map<VMStorage, Integer> retIndices = SharedUtils.indexMap(retMoves);
110             target = target.asSpreader(Object[].class, callingSequence.methodType().parameterCount());
111             doBindings = insertArguments(MH_invokeInterpBindings, 1, target, argIndices, retIndices, callingSequence,
112                     bufferCopySize);
113             doBindings = doBindings.asCollector(Object[].class, llType.parameterCount());
114             doBindings = doBindings.asType(llType);
115         }
116 
117         long entryPoint;
118         if (USE_INTRINSICS && isSimple && supportsOptimizedUpcalls()) {
119             checkPrimitive(doBindings.type());
120             doBindings = insertArguments(exactInvoker(doBindings.type()), 0, doBindings);
121             VMStorage[] args = Arrays.stream(argMoves).map(Binding.Move::storage).toArray(VMStorage[]::new);
122             VMStorage[] rets = Arrays.stream(retMoves).map(Binding.Move::storage).toArray(VMStorage[]::new);
123             CallRegs conv = new CallRegs(args, rets);
124             entryPoint = allocateOptimizedUpcallStub(doBindings, abi, conv);
125         } else {
126             BufferLayout layout = BufferLayout.of(abi);
127             MethodHandle doBindingsErased = doBindings.asSpreader(Object[].class, doBindings.type().parameterCount());
128             MethodHandle invokeMoves = insertArguments(MH_invokeMoves, 1, doBindingsErased, argMoves, retMoves, abi, layout);
129             entryPoint = allocateUpcallStub(invokeMoves, abi, layout);
130         }
131         return UpcallStubs.makeUpcall(entryPoint, scope);
132     }
133 
134     private static void checkPrimitive(MethodType type) {
135         if (!type.returnType().isPrimitive()
136                 || type.parameterList().stream().anyMatch(p -> !p.isPrimitive()))
137             throw new IllegalArgumentException("MethodHandle type must be primitive: " + type);
138     }
139 
140     private static Stream<Binding.VMLoad> argMoveBindingsStream(CallingSequence callingSequence) {
141         return callingSequence.argumentBindings()
142                 .filter(Binding.VMLoad.class::isInstance)
143                 .map(Binding.VMLoad.class::cast);
144     }
145 
146     private static Binding.VMLoad[] argMoveBindings(CallingSequence callingSequence) {
147         return argMoveBindingsStream(callingSequence)
148                 .toArray(Binding.VMLoad[]::new);
149     }
150 
151     private static Binding.VMStore[] retMoveBindings(CallingSequence callingSequence) {
152         return callingSequence.returnBindings().stream()
153                 .filter(Binding.VMStore.class::isInstance)
154                 .map(Binding.VMStore.class::cast)
155                 .toArray(Binding.VMStore[]::new);
156     }
157 
158     private static MethodHandle specializedBindingHandle(MethodHandle target, CallingSequence callingSequence,
159                                                          Class<?> llReturn, long bufferCopySize) {
160         MethodType highLevelType = callingSequence.methodType();
161 
162         MethodHandle specializedHandle = target; // initial
163 
164         int argAllocatorPos = 0;
165         int argInsertPos = 1;
166         specializedHandle = dropArguments(specializedHandle, argAllocatorPos, Binding.Context.class);
167         for (int i = 0; i < highLevelType.parameterCount(); i++) {
168             MethodHandle filter = identity(highLevelType.parameterType(i));
169             int filterAllocatorPos = 0;
170             int filterInsertPos = 1; // +1 for allocator
171             filter = dropArguments(filter, filterAllocatorPos, Binding.Context.class);
172 
173             List<Binding> bindings = callingSequence.argumentBindings(i);
174             for (int j = bindings.size() - 1; j >= 0; j--) {
175                 Binding binding = bindings.get(j);
176                 filter = binding.specialize(filter, filterInsertPos, filterAllocatorPos);
177             }
178             specializedHandle = MethodHandles.collectArguments(specializedHandle, argInsertPos, filter);
179             specializedHandle = mergeArguments(specializedHandle, argAllocatorPos, argInsertPos + filterAllocatorPos);
180             argInsertPos += filter.type().parameterCount() - 1; // -1 for allocator
181         }
182 
183         if (llReturn != void.class) {
184             int retAllocatorPos = -1; // assumed not needed
185             int retInsertPos = 0;
186             MethodHandle filter = identity(llReturn);
187             List<Binding> bindings = callingSequence.returnBindings();
188             for (int j = bindings.size() - 1; j >= 0; j--) {
189                 Binding binding = bindings.get(j);
190                 filter = binding.specialize(filter, retInsertPos, retAllocatorPos);
191             }
192             specializedHandle = filterReturnValue(specializedHandle, filter);
193         }
194 
195         specializedHandle = SharedUtils.wrapWithAllocator(specializedHandle, argAllocatorPos, bufferCopySize, true);
196 
197         return specializedHandle;
198     }
199 
200     public static void invoke(MethodHandle mh, long address) throws Throwable {
201         mh.invokeExact(MemoryAddress.ofLong(address));
202     }
203 
204     private static void invokeMoves(MemoryAddress buffer, MethodHandle leaf,
205                                     Binding.VMLoad[] argBindings, Binding.VMStore[] returnBindings,
206                                     ABIDescriptor abi, BufferLayout layout) throws Throwable {
207         MemorySegment bufferBase = MemoryAddressImpl.ofLongUnchecked(buffer.toRawLongValue(), layout.size);
208 
209         if (DEBUG) {
210             System.err.println("Buffer state before:");
211             layout.dump(abi.arch, bufferBase, System.err);
212         }
213 
214         MemorySegment stackArgsBase = MemoryAddressImpl.ofLongUnchecked((long)VH_LONG.get(bufferBase.asSlice(layout.stack_args)));
215         Object[] moves = new Object[argBindings.length];
216         for (int i = 0; i < moves.length; i++) {
217             Binding.VMLoad binding = argBindings[i];
218             VMStorage storage = binding.storage();
219             MemorySegment ptr = abi.arch.isStackType(storage.type())
220                 ? stackArgsBase.asSlice(storage.index() * abi.arch.typeSize(abi.arch.stackType()))
221                 : bufferBase.asSlice(layout.argOffset(storage));
222             moves[i] = SharedUtils.read(ptr, binding.type());
223         }
224 
225         // invokeInterpBindings, and then actual target
226         Object o = leaf.invoke(moves);
227 
228         if (o == null) {
229             // nop
230         } else if (o instanceof Object[] returns) {
231             for (int i = 0; i < returnBindings.length; i++) {
232                 Binding.VMStore binding = returnBindings[i];
233                 VMStorage storage = binding.storage();
234                 MemorySegment ptr = bufferBase.asSlice(layout.retOffset(storage));
235                 SharedUtils.writeOverSized(ptr, binding.type(), returns[i]);
236             }
237         } else { // single Object
238             Binding.VMStore binding = returnBindings[0];
239             VMStorage storage = binding.storage();
240             MemorySegment ptr = bufferBase.asSlice(layout.retOffset(storage));
241             SharedUtils.writeOverSized(ptr, binding.type(), o);
242         }
243 
244         if (DEBUG) {
245             System.err.println("Buffer state after:");
246             layout.dump(abi.arch, bufferBase, System.err);
247         }
248     }
249 
250     private static Object invokeInterpBindings(Object[] moves, MethodHandle leaf,
251                                                Map<VMStorage, Integer> argIndexMap,
252                                                Map<VMStorage, Integer> retIndexMap,
253                                                CallingSequence callingSequence,
254                                                long bufferCopySize) throws Throwable {
255         Binding.Context allocator = bufferCopySize != 0
256                 ? Binding.Context.ofBoundedAllocator(bufferCopySize)
257                 : Binding.Context.ofScope();
258         try (allocator) {
259             /// Invoke interpreter, got array of high-level arguments back
260             Object[] args = new Object[callingSequence.methodType().parameterCount()];
261             for (int i = 0; i < args.length; i++) {
262                 args[i] = BindingInterpreter.box(callingSequence.argumentBindings(i),
263                         (storage, type) -> moves[argIndexMap.get(storage)], allocator);
264             }
265 
266             if (DEBUG) {
267                 System.err.println("Java arguments:");
268                 System.err.println(Arrays.toString(args).indent(2));
269             }
270 
271             // invoke our target
272             Object o = leaf.invoke(args);
273 
274             if (DEBUG) {
275                 System.err.println("Java return:");
276                 System.err.println(Objects.toString(o).indent(2));
277             }
278 
279             Object[] returnMoves = new Object[retIndexMap.size()];
280             if (leaf.type().returnType() != void.class) {
281                 BindingInterpreter.unbox(o, callingSequence.returnBindings(),
282                         (storage, type, value) -> returnMoves[retIndexMap.get(storage)] = value, null);
283             }
284 
285             if (returnMoves.length == 0) {
286                 return null;
287             } else if (returnMoves.length == 1) {
288                 return returnMoves[0];
289             } else {
290                 return returnMoves;
291             }
292         } catch(Throwable t) {
293             SharedUtils.handleUncaughtException(t);
294             return null;
295         }
296     }
297 
298     // used for transporting data into native code
299     private static record CallRegs(VMStorage[] argRegs, VMStorage[] retRegs) {}
300 
301     static native long allocateOptimizedUpcallStub(MethodHandle mh, ABIDescriptor abi, CallRegs conv);
302     static native long allocateUpcallStub(MethodHandle mh, ABIDescriptor abi, BufferLayout layout);
303     static native boolean supportsOptimizedUpcalls();
304 
305     private static native void registerNatives();
306     static {
307         registerNatives();
308     }
309 }