1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.internal.foreign.abi;
 26 
 27 import jdk.internal.access.JavaLangAccess;
 28 import jdk.internal.access.JavaLangInvokeAccess;
 29 import jdk.internal.access.SharedSecrets;
 30 import jdk.internal.foreign.CABI;
 31 import jdk.internal.foreign.abi.AbstractLinker.UpcallStubFactory;
 32 import jdk.internal.foreign.abi.aarch64.linux.LinuxAArch64Linker;
 33 import jdk.internal.foreign.abi.aarch64.macos.MacOsAArch64Linker;
 34 import jdk.internal.foreign.abi.aarch64.windows.WindowsAArch64Linker;
 35 import jdk.internal.foreign.abi.fallback.FallbackLinker;
 36 import jdk.internal.foreign.abi.ppc64.linux.LinuxPPC64Linker;
 37 import jdk.internal.foreign.abi.ppc64.linux.LinuxPPC64leLinker;
 38 import jdk.internal.foreign.abi.riscv64.linux.LinuxRISCV64Linker;
 39 import jdk.internal.foreign.abi.s390.linux.LinuxS390Linker;
 40 import jdk.internal.foreign.abi.x64.sysv.SysVx64Linker;
 41 import jdk.internal.foreign.abi.x64.windows.Windowsx64Linker;
 42 import jdk.internal.vm.annotation.ForceInline;
 43 
 44 import java.lang.foreign.AddressLayout;
 45 import java.lang.foreign.Arena;
 46 import java.lang.foreign.Linker;
 47 import java.lang.foreign.FunctionDescriptor;
 48 import java.lang.foreign.GroupLayout;
 49 import java.lang.foreign.MemoryLayout;
 50 import java.lang.foreign.MemorySegment;
 51 import java.lang.foreign.MemorySegment.Scope;
 52 import java.lang.foreign.SegmentAllocator;
 53 import java.lang.foreign.ValueLayout;
 54 import java.lang.invoke.MethodHandle;
 55 import java.lang.invoke.MethodHandles;
 56 import java.lang.invoke.MethodType;
 57 import java.lang.ref.Reference;
 58 import java.nio.ByteOrder;
 59 import java.util.Arrays;
 60 import java.util.Map;
 61 import java.util.Objects;
 62 import java.util.stream.Collectors;
 63 import java.util.stream.IntStream;
 64 
 65 import static java.lang.foreign.ValueLayout.*;
 66 import static java.lang.invoke.MethodHandles.*;
 67 import static java.lang.invoke.MethodType.methodType;
 68 
 69 public final class SharedUtils {
 70 
 71     private SharedUtils() {
 72     }
 73 
 74     private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();
 75     private static final JavaLangInvokeAccess JLIA = SharedSecrets.getJavaLangInvokeAccess();
 76 
 77     private static final MethodHandle MH_ALLOC_BUFFER;
 78     private static final MethodHandle MH_BUFFER_COPY;
 79     private static final MethodHandle MH_REACHABILITY_FENCE;
 80     public static final MethodHandle MH_CHECK_SYMBOL;
 81     private static final MethodHandle MH_CHECK_CAPTURE_SEGMENT;
 82 
 83     public static final AddressLayout C_POINTER = ADDRESS
 84             .withTargetLayout(MemoryLayout.sequenceLayout(Long.MAX_VALUE, JAVA_BYTE));
 85 
 86     public static final Arena DUMMY_ARENA = new Arena() {
 87         @Override
 88         public Scope scope() {
 89             throw new UnsupportedOperationException();
 90         }
 91 
 92         @Override
 93         public MemorySegment allocate(long byteSize, long byteAlignment) {
 94             throw new UnsupportedOperationException();
 95         }
 96 
 97         @Override
 98         public void close() {
 99             // do nothing
100         }
101     };
102 
103     static {
104         try {
105             MethodHandles.Lookup lookup = MethodHandles.lookup();
106             MH_ALLOC_BUFFER = lookup.findVirtual(SegmentAllocator.class, "allocate",
107                     methodType(MemorySegment.class, MemoryLayout.class));
108             MH_BUFFER_COPY = lookup.findStatic(SharedUtils.class, "bufferCopy",
109                     methodType(MemorySegment.class, MemorySegment.class, MemorySegment.class));
110             MH_REACHABILITY_FENCE = lookup.findStatic(Reference.class, "reachabilityFence",
111                     methodType(void.class, Object.class));
112             MH_CHECK_SYMBOL = lookup.findStatic(SharedUtils.class, "checkSymbol",
113                     methodType(void.class, MemorySegment.class));
114             MH_CHECK_CAPTURE_SEGMENT = lookup.findStatic(SharedUtils.class, "checkCaptureSegment",
115                     methodType(MemorySegment.class, MemorySegment.class));
116         } catch (ReflectiveOperationException e) {
117             throw new BootstrapMethodError(e);
118         }
119     }
120 
121     // this allocator should be used when no allocation is expected
122     public static final SegmentAllocator THROWING_ALLOCATOR = (size, align) -> {
123         throw new IllegalStateException("Cannot get here");
124     };
125 
126     public static long alignUp(long addr, long alignment) {
127         return ((addr - 1) | (alignment - 1)) + 1;
128     }
129 
130     public static long remainsToAlignment(long addr, long alignment) {
131         return alignUp(addr, alignment) - addr;
132     }
133 
134     /**
135      * Takes a MethodHandle that takes an input buffer as a first argument (a MemorySegment), and returns nothing,
136      * and adapts it to return a MemorySegment, by allocating a MemorySegment for the input
137      * buffer, calling the target MethodHandle, and then returning the allocated MemorySegment.
138      *
139      * This allows viewing a MethodHandle that makes use of in memory return (IMR) as a MethodHandle that just returns
140      * a MemorySegment without requiring a pre-allocated buffer as an explicit input.
141      *
142      * @param handle the target handle to adapt
143      * @param cDesc the function descriptor of the native function (with actual return layout)
144      * @return the adapted handle
145      */
146     public static MethodHandle adaptDowncallForIMR(MethodHandle handle, FunctionDescriptor cDesc, CallingSequence sequence) {
147         if (handle.type().returnType() != void.class)
148             throw new IllegalArgumentException("return expected to be void for in memory returns: " + handle.type());
149         int imrAddrIdx = sequence.numLeadingParams();
150         if (handle.type().parameterType(imrAddrIdx) != MemorySegment.class)
151             throw new IllegalArgumentException("MemorySegment expected as third param: " + handle.type());
152         if (cDesc.returnLayout().isEmpty())
153             throw new IllegalArgumentException("Return layout needed: " + cDesc);
154 
155         MethodHandle ret = identity(MemorySegment.class); // (MemorySegment) MemorySegment
156         handle = collectArguments(ret, 1, handle); // (MemorySegment, MemorySegment, SegmentAllocator, MemorySegment, ...) MemorySegment
157         handle = mergeArguments(handle, 0, 1 + imrAddrIdx);  // (MemorySegment, MemorySegment, SegmentAllocator, ...) MemorySegment
158         handle = collectArguments(handle, 0, insertArguments(MH_ALLOC_BUFFER, 1, cDesc.returnLayout().get())); // (SegmentAllocator, MemorySegment, SegmentAllocator, ...) MemorySegment
159         handle = mergeArguments(handle, 0, 2);  // (SegmentAllocator, MemorySegment, ...) MemorySegment
160         handle = swapArguments(handle, 0, 1); // (MemorySegment, SegmentAllocator, ...) MemorySegment
161         return handle;
162     }
163 
164     /**
165      * Takes a MethodHandle that returns a MemorySegment, and adapts it to take an input buffer as a first argument
166      * (a MemorySegment), and upon invocation, copies the contents of the returned MemorySegment into the input buffer
167      * passed as the first argument.
168      *
169      * @param target the target handle to adapt
170      * @return the adapted handle
171      */
172     private static MethodHandle adaptUpcallForIMR(MethodHandle target, boolean dropReturn) {
173         if (target.type().returnType() != MemorySegment.class)
174             throw new IllegalArgumentException("Must return MemorySegment for IMR");
175 
176         target = collectArguments(MH_BUFFER_COPY, 1, target); // (MemorySegment, ...) MemorySegment
177 
178         if (dropReturn) { // no handling for return value, need to drop it
179             target = dropReturn(target);
180         } else {
181             // adjust return type so it matches the inferred type of the effective
182             // function descriptor
183             target = target.asType(target.type().changeReturnType(MemorySegment.class));
184         }
185 
186         return target;
187     }
188 
189     public static UpcallStubFactory arrangeUpcallHelper(MethodType targetType, boolean isInMemoryReturn, boolean dropReturn,
190                                                         ABIDescriptor abi, CallingSequence callingSequence) {
191         if (isInMemoryReturn) {
192             // simulate the adaptation to get the type
193             MethodHandle fakeTarget = MethodHandles.empty(targetType);
194             targetType = adaptUpcallForIMR(fakeTarget, dropReturn).type();
195         }
196 
197         UpcallStubFactory factory = UpcallLinker.makeFactory(targetType, abi, callingSequence);
198 
199         if (isInMemoryReturn) {
200             final UpcallStubFactory finalFactory = factory;
201             factory = (target, scope) -> {
202                 target = adaptUpcallForIMR(target, dropReturn);
203                 return finalFactory.makeStub(target, scope);
204             };
205         }
206 
207         return factory;
208     }
209 
210     private static MemorySegment bufferCopy(MemorySegment dest, MemorySegment buffer) {
211         return dest.copyFrom(buffer);
212     }
213 
214     public static Class<?> primitiveCarrierForSize(long size, boolean useFloat) {
215         return primitiveLayoutForSize(size, useFloat).carrier();
216     }
217 
218     public static ValueLayout primitiveLayoutForSize(long size, boolean useFloat) {
219         if (useFloat) {
220             if (size == 4) {
221                 return JAVA_FLOAT;
222             } else if (size == 8) {
223                 return JAVA_DOUBLE;
224             }
225         } else {
226             if (size == 1) {
227                 return JAVA_BYTE;
228             } else if (size == 2) {
229                 return JAVA_SHORT;
230             } else if (size <= 4) {
231                 return JAVA_INT;
232             } else if (size <= 8) {
233                 return JAVA_LONG;
234             }
235         }
236 
237         throw new IllegalArgumentException("No layout for size: " + size + " isFloat=" + useFloat);
238     }
239 
240     public static Linker getSystemLinker() {
241         return switch (CABI.current()) {
242             case WIN_64 -> Windowsx64Linker.getInstance();
243             case SYS_V -> SysVx64Linker.getInstance();
244             case LINUX_AARCH_64 -> LinuxAArch64Linker.getInstance();
245             case MAC_OS_AARCH_64 -> MacOsAArch64Linker.getInstance();
246             case WIN_AARCH_64 -> WindowsAArch64Linker.getInstance();
247             case LINUX_PPC_64 -> LinuxPPC64Linker.getInstance();
248             case LINUX_PPC_64_LE -> LinuxPPC64leLinker.getInstance();
249             case LINUX_RISCV_64 -> LinuxRISCV64Linker.getInstance();
250             case LINUX_S390 -> LinuxS390Linker.getInstance();
251             case FALLBACK -> FallbackLinker.getInstance();
252             case UNSUPPORTED -> throw new UnsupportedOperationException("Platform does not support native linker");
253         };
254     }
255 
256     static Map<VMStorage, Integer> indexMap(Binding.Move[] moves) {
257         return IntStream.range(0, moves.length)
258                         .boxed()
259                         .collect(Collectors.toMap(i -> moves[i].storage(), i -> i));
260     }
261 
262     static MethodHandle mergeArguments(MethodHandle mh, int sourceIndex, int destIndex) {
263         MethodType oldType = mh.type();
264         Class<?> sourceType = oldType.parameterType(sourceIndex);
265         Class<?> destType = oldType.parameterType(destIndex);
266         if (sourceType != destType) {
267             // TODO meet?
268             throw new IllegalArgumentException("Parameter types differ: " + sourceType + " != " + destType);
269         }
270         MethodType newType = oldType.dropParameterTypes(destIndex, destIndex + 1);
271         int[] reorder = new int[oldType.parameterCount()];
272         if (destIndex < sourceIndex) {
273             sourceIndex--;
274         }
275         for (int i = 0, index = 0; i < reorder.length; i++) {
276             if (i != destIndex) {
277                 reorder[i] = index++;
278             } else {
279                 reorder[i] = sourceIndex;
280             }
281         }
282         return permuteArguments(mh, newType, reorder);
283     }
284 
285 
286     public static MethodHandle swapArguments(MethodHandle mh, int firstArg, int secondArg) {
287         MethodType mtype = mh.type();
288         int[] perms = new int[mtype.parameterCount()];
289         MethodType swappedType = MethodType.methodType(mtype.returnType());
290         for (int i = 0 ; i < perms.length ; i++) {
291             int dst = i;
292             if (i == firstArg) dst = secondArg;
293             if (i == secondArg) dst = firstArg;
294             perms[i] = dst;
295             swappedType = swappedType.appendParameterTypes(mtype.parameterType(dst));
296         }
297         return permuteArguments(mh, swappedType, perms);
298     }
299 
300     private static MethodHandle reachabilityFenceHandle(Class<?> type) {
301         return MH_REACHABILITY_FENCE.asType(MethodType.methodType(void.class, type));
302     }
303 
304     public static void handleUncaughtException(Throwable t) {
305         if (t != null) {
306             try {
307                 t.printStackTrace();
308                 System.err.println("Unrecoverable uncaught exception encountered. The VM will now exit");
309             } finally {
310                 JLA.exit(1);
311             }
312         }
313     }
314 
315     public static long unboxSegment(MemorySegment segment) {
316         if (!segment.isNative()) {
317             throw new IllegalArgumentException("Heap segment not allowed: " + segment);
318         }
319         return segment.address();
320     }
321 
322     public static void checkExceptions(MethodHandle target) {
323         Class<?>[] exceptions = JLIA.exceptionTypes(target);
324         if (exceptions != null && exceptions.length != 0) {
325             throw new IllegalArgumentException("Target handle may throw exceptions: " + Arrays.toString(exceptions));
326         }
327     }
328 
329     public static MethodHandle maybeInsertAllocator(FunctionDescriptor descriptor, MethodHandle handle) {
330         if (descriptor.returnLayout().isEmpty() || !(descriptor.returnLayout().get() instanceof GroupLayout)) {
331             // not returning segment, just insert a throwing allocator
332             handle = insertArguments(handle, 1, THROWING_ALLOCATOR);
333         }
334         return handle;
335     }
336 
337     public static MethodHandle maybeCheckCaptureSegment(MethodHandle handle, LinkerOptions options) {
338         if (options.hasCapturedCallState()) {
339             // (<target address>, SegmentAllocator, <capture segment>, ...) -> ...
340             handle = MethodHandles.filterArguments(handle, 2, MH_CHECK_CAPTURE_SEGMENT);
341         }
342         return handle;
343     }
344 
345     @ForceInline
346     public static MemorySegment checkCaptureSegment(MemorySegment captureSegment) {
347         Objects.requireNonNull(captureSegment);
348         if (captureSegment.equals(MemorySegment.NULL)) {
349             throw new IllegalArgumentException("Capture segment is NULL: " + captureSegment);
350         }
351         return captureSegment.asSlice(0, CapturableState.LAYOUT);
352     }
353 
354     @ForceInline
355     public static void checkSymbol(MemorySegment symbol) {
356         Objects.requireNonNull(symbol);
357         if (symbol.equals(MemorySegment.NULL))
358             throw new IllegalArgumentException("Symbol is NULL: " + symbol);
359     }
360 
361     static void checkType(Class<?> actualType, Class<?> expectedType) {
362         if (expectedType != actualType) {
363             throw new IllegalArgumentException(
364                     String.format("Invalid operand type: %s. %s expected", actualType, expectedType));
365         }
366     }
367 
368     public static boolean isPowerOfTwo(int width) {
369         return Integer.bitCount(width) == 1;
370     }
371 
372     static long pickChunkOffset(long chunkOffset, long byteWidth, int chunkWidth) {
373         return ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN
374                 ? byteWidth - chunkWidth - chunkOffset
375                 : chunkOffset;
376     }
377 
378     public static Arena newBoundedArena(long size) {
379         return new Arena() {
380             final Arena arena = Arena.ofConfined();
381             final SegmentAllocator slicingAllocator = SegmentAllocator.slicingAllocator(arena.allocate(size));
382 
383             @Override
384             public Scope scope() {
385                 return arena.scope();
386             }
387 
388             @Override
389             public void close() {
390                 arena.close();
391             }
392 
393             @Override
394             public MemorySegment allocate(long byteSize, long byteAlignment) {
395                 return slicingAllocator.allocate(byteSize, byteAlignment);
396             }
397         };
398     }
399 
400     public static Arena newEmptyArena() {
401         return new Arena() {
402             final Arena arena = Arena.ofConfined();
403 
404             @Override
405             public Scope scope() {
406                 return arena.scope();
407             }
408 
409             @Override
410             public void close() {
411                 arena.close();
412             }
413 
414             @Override
415             public MemorySegment allocate(long byteSize, long byteAlignment) {
416                 throw new UnsupportedOperationException();
417             }
418         };
419     }
420 
421     static void writeOverSized(MemorySegment ptr, Class<?> type, Object o) {
422         // use VH_LONG for integers to zero out the whole register in the process
423         if (type == long.class) {
424             ptr.set(JAVA_LONG_UNALIGNED, 0, (long) o);
425         } else if (type == int.class) {
426             ptr.set(JAVA_LONG_UNALIGNED, 0, (int) o);
427         } else if (type == short.class) {
428             ptr.set(JAVA_LONG_UNALIGNED, 0, (short) o);
429         } else if (type == char.class) {
430             ptr.set(JAVA_LONG_UNALIGNED, 0, (char) o);
431         } else if (type == byte.class) {
432             ptr.set(JAVA_LONG_UNALIGNED, 0, (byte) o);
433         } else if (type == float.class) {
434             ptr.set(JAVA_FLOAT_UNALIGNED, 0, (float) o);
435         } else if (type == double.class) {
436             ptr.set(JAVA_DOUBLE_UNALIGNED, 0, (double) o);
437         } else if (type == boolean.class) {
438             boolean b = (boolean)o;
439             ptr.set(JAVA_LONG_UNALIGNED, 0, b ? (long)1 : (long)0);
440         } else {
441             throw new IllegalArgumentException("Unsupported carrier: " + type);
442         }
443     }
444 
445     static void write(MemorySegment ptr, long offset, Class<?> type, Object o) {
446         if (type == long.class) {
447             ptr.set(JAVA_LONG_UNALIGNED, offset, (long) o);
448         } else if (type == int.class) {
449             ptr.set(JAVA_INT_UNALIGNED, offset, (int) o);
450         } else if (type == short.class) {
451             ptr.set(JAVA_SHORT_UNALIGNED, offset, (short) o);
452         } else if (type == char.class) {
453             ptr.set(JAVA_CHAR_UNALIGNED, offset, (char) o);
454         } else if (type == byte.class) {
455             ptr.set(JAVA_BYTE, offset, (byte) o);
456         } else if (type == float.class) {
457             ptr.set(JAVA_FLOAT_UNALIGNED, offset, (float) o);
458         } else if (type == double.class) {
459             ptr.set(JAVA_DOUBLE_UNALIGNED, offset, (double) o);
460         } else if (type == boolean.class) {
461             ptr.set(JAVA_BOOLEAN, offset, (boolean) o);
462         } else {
463             throw new IllegalArgumentException("Unsupported carrier: " + type);
464         }
465     }
466 
467     static Object read(MemorySegment ptr, long offset, Class<?> type) {
468         if (type == long.class) {
469             return ptr.get(JAVA_LONG_UNALIGNED, offset);
470         } else if (type == int.class) {
471             return ptr.get(JAVA_INT_UNALIGNED, offset);
472         } else if (type == short.class) {
473             return ptr.get(JAVA_SHORT_UNALIGNED, offset);
474         } else if (type == char.class) {
475             return ptr.get(JAVA_CHAR_UNALIGNED, offset);
476         } else if (type == byte.class) {
477             return ptr.get(JAVA_BYTE, offset);
478         } else if (type == float.class) {
479             return ptr.get(JAVA_FLOAT_UNALIGNED, offset);
480         } else if (type == double.class) {
481             return ptr.get(JAVA_DOUBLE_UNALIGNED, offset);
482         } else if (type == boolean.class) {
483             return ptr.get(JAVA_BOOLEAN, offset);
484         } else {
485             throw new IllegalArgumentException("Unsupported carrier: " + type);
486         }
487     }
488 
489     public static Map<String, MemoryLayout> canonicalLayouts(ValueLayout longLayout, ValueLayout sizetLayout, ValueLayout wchartLayout) {
490         return Map.ofEntries(
491                 // specified canonical layouts
492                 Map.entry("bool", ValueLayout.JAVA_BOOLEAN),
493                 Map.entry("char", ValueLayout.JAVA_BYTE),
494                 Map.entry("short", ValueLayout.JAVA_SHORT),
495                 Map.entry("int", ValueLayout.JAVA_INT),
496                 Map.entry("float", ValueLayout.JAVA_FLOAT),
497                 Map.entry("long", longLayout),
498                 Map.entry("long long", ValueLayout.JAVA_LONG),
499                 Map.entry("double", ValueLayout.JAVA_DOUBLE),
500                 Map.entry("void*", ValueLayout.ADDRESS),
501                 Map.entry("size_t", sizetLayout),
502                 Map.entry("wchar_t", wchartLayout),
503                 // unspecified size-dependent layouts
504                 Map.entry("int8_t", ValueLayout.JAVA_BYTE),
505                 Map.entry("int16_t", ValueLayout.JAVA_SHORT),
506                 Map.entry("int32_t", ValueLayout.JAVA_INT),
507                 Map.entry("int64_t", ValueLayout.JAVA_LONG),
508                 // unspecified JNI layouts
509                 Map.entry("jboolean", ValueLayout.JAVA_BOOLEAN),
510                 Map.entry("jchar", ValueLayout.JAVA_CHAR),
511                 Map.entry("jbyte", ValueLayout.JAVA_BYTE),
512                 Map.entry("jshort", ValueLayout.JAVA_SHORT),
513                 Map.entry("jint", ValueLayout.JAVA_INT),
514                 Map.entry("jlong", ValueLayout.JAVA_LONG),
515                 Map.entry("jfloat", ValueLayout.JAVA_FLOAT),
516                 Map.entry("jdouble", ValueLayout.JAVA_DOUBLE)
517         );
518     }
519 }