1 /*
  2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package jdk.internal.foreign.abi;
 27 
 28 import jdk.incubator.foreign.MemoryAddress;
 29 import jdk.incubator.foreign.MemoryLayouts;
 30 import jdk.incubator.foreign.MemorySegment;
 31 import jdk.incubator.foreign.ResourceScope;
 32 import jdk.incubator.foreign.SegmentAllocator;
 33 import jdk.internal.access.JavaLangInvokeAccess;
 34 import jdk.internal.access.SharedSecrets;
 35 import jdk.internal.foreign.MemoryAddressImpl;
 36 import sun.security.action.GetPropertyAction;
 37 
 38 import java.lang.invoke.MethodHandle;
 39 import java.lang.invoke.MethodHandles;
 40 import java.lang.invoke.MethodType;
 41 import java.lang.invoke.VarHandle;
 42 import java.util.Arrays;
 43 import java.util.List;
 44 import java.util.Map;
 45 import java.util.Objects;
 46 import java.util.stream.Stream;
 47 
 48 import static java.lang.invoke.MethodHandles.dropArguments;
 49 import static java.lang.invoke.MethodHandles.filterReturnValue;
 50 import static java.lang.invoke.MethodHandles.identity;
 51 import static java.lang.invoke.MethodHandles.insertArguments;
 52 import static java.lang.invoke.MethodHandles.lookup;
 53 import static java.lang.invoke.MethodType.methodType;
 54 import static jdk.internal.foreign.abi.SharedUtils.mergeArguments;
 55 import static sun.security.action.GetBooleanAction.privilegedGetProperty;
 56 
 57 /**
 58  * This class implements upcall invocation from native code through a so called 'universal adapter'. A universal upcall adapter
 59  * takes an array of storage pointers, which describes the state of the CPU at the time of the upcall. This can be used
 60  * by the Java code to fetch the upcall arguments and to store the results to the desired location, as per system ABI.
 61  */
 62 public class ProgrammableUpcallHandler {
 63     private static final boolean DEBUG =
 64         privilegedGetProperty("jdk.internal.foreign.ProgrammableUpcallHandler.DEBUG");
 65     private static final boolean USE_SPEC = Boolean.parseBoolean(
 66         GetPropertyAction.privilegedGetProperty("jdk.internal.foreign.ProgrammableUpcallHandler.USE_SPEC", "true"));
 67     private static final boolean USE_INTRINSICS = Boolean.parseBoolean(
 68         GetPropertyAction.privilegedGetProperty("jdk.internal.foreign.ProgrammableUpcallHandler.USE_INTRINSICS", "true"));
 69 
 70     private static final JavaLangInvokeAccess JLI = SharedSecrets.getJavaLangInvokeAccess();
 71 
 72     private static final VarHandle VH_LONG = MemoryLayouts.JAVA_LONG.varHandle(long.class);
 73 
 74     private static final MethodHandle MH_invokeMoves;
 75     private static final MethodHandle MH_invokeInterpBindings;
 76 
 77     static {
 78         try {
 79             MethodHandles.Lookup lookup = lookup();
 80             MH_invokeMoves = lookup.findStatic(ProgrammableUpcallHandler.class, "invokeMoves",
 81                     methodType(void.class, MemoryAddress.class, MethodHandle.class,
 82                                Binding.VMLoad[].class, Binding.VMStore[].class, ABIDescriptor.class, BufferLayout.class));
 83             MH_invokeInterpBindings = lookup.findStatic(ProgrammableUpcallHandler.class, "invokeInterpBindings",
 84                     methodType(Object.class, Object[].class, MethodHandle.class, Map.class, Map.class,
 85                             CallingSequence.class, long.class));
 86         } catch (ReflectiveOperationException e) {
 87             throw new InternalError(e);
 88         }
 89     }
 90 
 91     public static UpcallHandler make(ABIDescriptor abi, MethodHandle target, CallingSequence callingSequence) {
 92         Binding.VMLoad[] argMoves = argMoveBindings(callingSequence);
 93         Binding.VMStore[] retMoves = retMoveBindings(callingSequence);
 94 
 95         boolean isSimple = !(retMoves.length > 1);
 96 
 97         Class<?> llReturn = !isSimple
 98             ? Object[].class
 99             : retMoves.length == 1
100                 ? retMoves[0].type()
101                 : void.class;
102         Class<?>[] llParams = Arrays.stream(argMoves).map(Binding.Move::type).toArray(Class<?>[]::new);
103         MethodType llType = MethodType.methodType(llReturn, llParams);
104 
105         MethodHandle doBindings;
106         long bufferCopySize = SharedUtils.bufferCopySize(callingSequence);
107         if (USE_SPEC && isSimple) {
108             doBindings = specializedBindingHandle(target, callingSequence, llReturn, bufferCopySize);
109             assert doBindings.type() == llType;
110         } else {
111             Map<VMStorage, Integer> argIndices = SharedUtils.indexMap(argMoves);
112             Map<VMStorage, Integer> retIndices = SharedUtils.indexMap(retMoves);
113             target = target.asSpreader(Object[].class, callingSequence.methodType().parameterCount());
114             doBindings = insertArguments(MH_invokeInterpBindings, 1, target, argIndices, retIndices, callingSequence,
115                     bufferCopySize);
116             doBindings = doBindings.asCollector(Object[].class, llType.parameterCount());
117             doBindings = doBindings.asType(llType);
118         }
119 
120         long entryPoint;
121         boolean usesStackArgs = argMoveBindingsStream(callingSequence)
122                 .map(Binding.VMLoad::storage)
123                 .anyMatch(s -> abi.arch.isStackType(s.type()));
124         if (USE_INTRINSICS && isSimple && !usesStackArgs && supportsOptimizedUpcalls()) {
125             checkPrimitive(doBindings.type());
126             JLI.ensureCustomized(doBindings);
127             VMStorage[] args = Arrays.stream(argMoves).map(Binding.Move::storage).toArray(VMStorage[]::new);
128             VMStorage[] rets = Arrays.stream(retMoves).map(Binding.Move::storage).toArray(VMStorage[]::new);
129             CallRegs conv = new CallRegs(args, rets);
130             entryPoint = allocateOptimizedUpcallStub(doBindings, abi, conv);
131         } else {
132             BufferLayout layout = BufferLayout.of(abi);
133             MethodHandle doBindingsErased = doBindings.asSpreader(Object[].class, doBindings.type().parameterCount());
134             MethodHandle invokeMoves = insertArguments(MH_invokeMoves, 1, doBindingsErased, argMoves, retMoves, abi, layout);
135             entryPoint = allocateUpcallStub(invokeMoves, abi, layout);
136         }
137         return () -> entryPoint;
138     }
139 
140     private static void checkPrimitive(MethodType type) {
141         if (!type.returnType().isPrimitive()
142                 || type.parameterList().stream().anyMatch(p -> !p.isPrimitive()))
143             throw new IllegalArgumentException("MethodHandle type must be primitive: " + type);
144     }
145 
146     private static Stream<Binding.VMLoad> argMoveBindingsStream(CallingSequence callingSequence) {
147         return callingSequence.argumentBindings()
148                 .filter(Binding.VMLoad.class::isInstance)
149                 .map(Binding.VMLoad.class::cast);
150     }
151 
152     private static Binding.VMLoad[] argMoveBindings(CallingSequence callingSequence) {
153         return argMoveBindingsStream(callingSequence)
154                 .toArray(Binding.VMLoad[]::new);
155     }
156 
157     private static Binding.VMStore[] retMoveBindings(CallingSequence callingSequence) {
158         return callingSequence.returnBindings().stream()
159                 .filter(Binding.VMStore.class::isInstance)
160                 .map(Binding.VMStore.class::cast)
161                 .toArray(Binding.VMStore[]::new);
162     }
163 
164     private static MethodHandle specializedBindingHandle(MethodHandle target, CallingSequence callingSequence,
165                                                          Class<?> llReturn, long bufferCopySize) {
166         MethodType highLevelType = callingSequence.methodType();
167 
168         MethodHandle specializedHandle = target; // initial
169 
170         int argAllocatorPos = 0;
171         int argInsertPos = 1;
172         specializedHandle = dropArguments(specializedHandle, argAllocatorPos, Binding.Context.class);
173         for (int i = 0; i < highLevelType.parameterCount(); i++) {
174             MethodHandle filter = identity(highLevelType.parameterType(i));
175             int filterAllocatorPos = 0;
176             int filterInsertPos = 1; // +1 for allocator
177             filter = dropArguments(filter, filterAllocatorPos, Binding.Context.class);
178 
179             List<Binding> bindings = callingSequence.argumentBindings(i);
180             for (int j = bindings.size() - 1; j >= 0; j--) {
181                 Binding binding = bindings.get(j);
182                 filter = binding.specialize(filter, filterInsertPos, filterAllocatorPos);
183             }
184             specializedHandle = MethodHandles.collectArguments(specializedHandle, argInsertPos, filter);
185             specializedHandle = mergeArguments(specializedHandle, argAllocatorPos, argInsertPos + filterAllocatorPos);
186             argInsertPos += filter.type().parameterCount() - 1; // -1 for allocator
187         }
188 
189         if (llReturn != void.class) {
190             int retAllocatorPos = -1; // assumed not needed
191             int retInsertPos = 0;
192             MethodHandle filter = identity(llReturn);
193             List<Binding> bindings = callingSequence.returnBindings();
194             for (int j = bindings.size() - 1; j >= 0; j--) {
195                 Binding binding = bindings.get(j);
196                 filter = binding.specialize(filter, retInsertPos, retAllocatorPos);
197             }
198             specializedHandle = filterReturnValue(specializedHandle, filter);
199         }
200 
201         specializedHandle = SharedUtils.wrapWithAllocator(specializedHandle, argAllocatorPos, bufferCopySize, true);
202 
203         return specializedHandle;
204     }
205 
206     public static void invoke(MethodHandle mh, long address) throws Throwable {
207         mh.invokeExact(MemoryAddress.ofLong(address));
208     }
209 
210     private static void invokeMoves(MemoryAddress buffer, MethodHandle leaf,
211                                     Binding.VMLoad[] argBindings, Binding.VMStore[] returnBindings,
212                                     ABIDescriptor abi, BufferLayout layout) throws Throwable {
213         MemorySegment bufferBase = MemoryAddressImpl.ofLongUnchecked(buffer.toRawLongValue(), layout.size);
214 
215         if (DEBUG) {
216             System.err.println("Buffer state before:");
217             layout.dump(abi.arch, bufferBase, System.err);
218         }
219 
220         MemorySegment stackArgsBase = MemoryAddressImpl.ofLongUnchecked((long)VH_LONG.get(bufferBase.asSlice(layout.stack_args)));
221         Object[] moves = new Object[argBindings.length];
222         for (int i = 0; i < moves.length; i++) {
223             Binding.VMLoad binding = argBindings[i];
224             VMStorage storage = binding.storage();
225             MemorySegment ptr = abi.arch.isStackType(storage.type())
226                 ? stackArgsBase.asSlice(storage.index() * abi.arch.typeSize(abi.arch.stackType()))
227                 : bufferBase.asSlice(layout.argOffset(storage));
228             moves[i] = SharedUtils.read(ptr, binding.type());
229         }
230 
231         // invokeInterpBindings, and then actual target
232         Object o = leaf.invoke(moves);
233 
234         if (o == null) {
235             // nop
236         } else if (o instanceof Object[] returns) {
237             for (int i = 0; i < returnBindings.length; i++) {
238                 Binding.VMStore binding = returnBindings[i];
239                 VMStorage storage = binding.storage();
240                 MemorySegment ptr = bufferBase.asSlice(layout.retOffset(storage));
241                 SharedUtils.writeOverSized(ptr, binding.type(), returns[i]);
242             }
243         } else { // single Object
244             Binding.VMStore binding = returnBindings[0];
245             VMStorage storage = binding.storage();
246             MemorySegment ptr = bufferBase.asSlice(layout.retOffset(storage));
247             SharedUtils.writeOverSized(ptr, binding.type(), o);
248         }
249 
250         if (DEBUG) {
251             System.err.println("Buffer state after:");
252             layout.dump(abi.arch, bufferBase, System.err);
253         }
254     }
255 
256     private static Object invokeInterpBindings(Object[] moves, MethodHandle leaf,
257                                                Map<VMStorage, Integer> argIndexMap,
258                                                Map<VMStorage, Integer> retIndexMap,
259                                                CallingSequence callingSequence,
260                                                long bufferCopySize) throws Throwable {
261         Binding.Context allocator = bufferCopySize != 0
262                 ? Binding.Context.ofBoundedAllocator(bufferCopySize)
263                 : Binding.Context.ofScope();
264         try (allocator) {
265             /// Invoke interpreter, got array of high-level arguments back
266             Object[] args = new Object[callingSequence.methodType().parameterCount()];
267             for (int i = 0; i < args.length; i++) {
268                 args[i] = BindingInterpreter.box(callingSequence.argumentBindings(i),
269                         (storage, type) -> moves[argIndexMap.get(storage)], allocator);
270             }
271 
272             if (DEBUG) {
273                 System.err.println("Java arguments:");
274                 System.err.println(Arrays.toString(args).indent(2));
275             }
276 
277             // invoke our target
278             Object o = leaf.invoke(args);
279 
280             if (DEBUG) {
281                 System.err.println("Java return:");
282                 System.err.println(Objects.toString(o).indent(2));
283             }
284 
285             Object[] returnMoves = new Object[retIndexMap.size()];
286             if (leaf.type().returnType() != void.class) {
287                 BindingInterpreter.unbox(o, callingSequence.returnBindings(),
288                         (storage, type, value) -> returnMoves[retIndexMap.get(storage)] = value, null);
289             }
290 
291             if (returnMoves.length == 0) {
292                 return null;
293             } else if (returnMoves.length == 1) {
294                 return returnMoves[0];
295             } else {
296                 return returnMoves;
297             }
298         } catch(Throwable t) {
299             SharedUtils.handleUncaughtException(t);
300             return null;
301         }
302     }
303 
304     // used for transporting data into native code
305     private static record CallRegs(VMStorage[] argRegs, VMStorage[] retRegs) {}
306 
307     static native long allocateOptimizedUpcallStub(MethodHandle mh, ABIDescriptor abi, CallRegs conv);
308     static native long allocateUpcallStub(MethodHandle mh, ABIDescriptor abi, BufferLayout layout);
309     static native boolean supportsOptimizedUpcalls();
310 
311     private static native void registerNatives();
312     static {
313         registerNatives();
314     }
315 }