1 /*
  2  * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.internal.foreign.abi;
 26 
 27 import jdk.internal.foreign.SystemLookup;
 28 import jdk.internal.foreign.Utils;
 29 import jdk.internal.foreign.abi.aarch64.linux.LinuxAArch64Linker;
 30 import jdk.internal.foreign.abi.aarch64.macos.MacOsAArch64Linker;
 31 import jdk.internal.foreign.abi.aarch64.windows.WindowsAArch64Linker;
 32 import jdk.internal.foreign.abi.fallback.FallbackLinker;
 33 import jdk.internal.foreign.abi.ppc64.linux.LinuxPPC64Linker;
 34 import jdk.internal.foreign.abi.ppc64.linux.LinuxPPC64leLinker;
 35 import jdk.internal.foreign.abi.riscv64.linux.LinuxRISCV64Linker;
 36 import jdk.internal.foreign.abi.s390.linux.LinuxS390Linker;
 37 import jdk.internal.foreign.abi.x64.sysv.SysVx64Linker;
 38 import jdk.internal.foreign.abi.x64.windows.Windowsx64Linker;
 39 import jdk.internal.foreign.layout.AbstractLayout;
 40 import jdk.internal.reflect.CallerSensitive;
 41 import jdk.internal.reflect.Reflection;
 42 
 43 import java.lang.foreign.AddressLayout;
 44 import java.lang.foreign.GroupLayout;
 45 import java.lang.foreign.MemoryLayout;
 46 import java.lang.foreign.Arena;
 47 import java.lang.foreign.FunctionDescriptor;
 48 import java.lang.foreign.Linker;
 49 import java.lang.foreign.MemorySegment;
 50 import java.lang.foreign.PaddingLayout;
 51 import java.lang.foreign.SequenceLayout;
 52 import java.lang.foreign.StructLayout;
 53 import java.lang.foreign.UnionLayout;
 54 import java.lang.foreign.ValueLayout;
 55 import java.lang.invoke.MethodHandle;
 56 import java.lang.invoke.MethodType;
 57 import java.util.HashSet;
 58 import java.util.List;
 59 import java.nio.ByteOrder;
 60 import java.util.Objects;
 61 import java.util.Set;
 62 
 63 public abstract sealed class AbstractLinker implements Linker permits LinuxAArch64Linker, MacOsAArch64Linker,
 64                                                                       SysVx64Linker, WindowsAArch64Linker,
 65                                                                       Windowsx64Linker,
 66                                                                       LinuxPPC64Linker, LinuxPPC64leLinker,
 67                                                                       LinuxRISCV64Linker, LinuxS390Linker,
 68                                                                       FallbackLinker {
 69 
 70     public interface UpcallStubFactory {
 71         MemorySegment makeStub(MethodHandle target, Arena arena);
 72     }
 73 
 74     private record LinkRequest(FunctionDescriptor descriptor, LinkerOptions options) {}
 75     private final SoftReferenceCache<LinkRequest, MethodHandle> DOWNCALL_CACHE = new SoftReferenceCache<>();
 76     private final SoftReferenceCache<LinkRequest, UpcallStubFactory> UPCALL_CACHE = new SoftReferenceCache<>();
 77     private final Set<MemoryLayout> CANONICAL_LAYOUTS_CACHE = new HashSet<>(canonicalLayouts().values());
 78 
 79     @Override
 80     @CallerSensitive
 81     public final MethodHandle downcallHandle(MemorySegment symbol, FunctionDescriptor function, Option... options) {
 82         Reflection.ensureNativeAccess(Reflection.getCallerClass(), Linker.class, "downcallHandle");
 83         SharedUtils.checkSymbol(symbol);
 84         return downcallHandle0(function, options).bindTo(symbol);
 85     }
 86 
 87     @Override
 88     @CallerSensitive
 89     public final MethodHandle downcallHandle(FunctionDescriptor function, Option... options) {
 90         Reflection.ensureNativeAccess(Reflection.getCallerClass(), Linker.class, "downcallHandle");
 91         return downcallHandle0(function, options);
 92     }
 93 
 94     private MethodHandle downcallHandle0(FunctionDescriptor function, Option... options) {
 95         Objects.requireNonNull(function);
 96         Objects.requireNonNull(options);
 97         checkLayouts(function);
 98         function = stripNames(function);
 99         LinkerOptions optionSet = LinkerOptions.forDowncall(function, options);
100         validateVariadicLayouts(function, optionSet);
101 
102         return DOWNCALL_CACHE.get(new LinkRequest(function, optionSet), linkRequest ->  {
103             FunctionDescriptor fd = linkRequest.descriptor();
104             MethodType type = fd.toMethodType();
105             MethodHandle handle = arrangeDowncall(type, fd, linkRequest.options());
106             handle = SharedUtils.maybeCheckCaptureSegment(handle, linkRequest.options());
107             handle = SharedUtils.maybeInsertAllocator(fd, handle);
108             return handle;
109         });
110     }
111 
112     protected abstract MethodHandle arrangeDowncall(MethodType inferredMethodType, FunctionDescriptor function, LinkerOptions options);
113 
114     @Override
115     @CallerSensitive
116     public final MemorySegment upcallStub(MethodHandle target, FunctionDescriptor function, Arena arena, Linker.Option... options) {
117         Reflection.ensureNativeAccess(Reflection.getCallerClass(), Linker.class, "upcallStub");
118         Objects.requireNonNull(arena);
119         Objects.requireNonNull(target);
120         Objects.requireNonNull(function);
121         checkLayouts(function);
122         SharedUtils.checkExceptions(target);
123         function = stripNames(function);
124         LinkerOptions optionSet = LinkerOptions.forUpcall(function, options);
125 
126         MethodType type = function.toMethodType();
127         if (!type.equals(target.type())) {
128             throw new IllegalArgumentException("Wrong method handle type: " + target.type());
129         }
130 
131         UpcallStubFactory factory = UPCALL_CACHE.get(new LinkRequest(function, optionSet), linkRequest ->
132             arrangeUpcall(type, linkRequest.descriptor(), linkRequest.options()));
133         return factory.makeStub(target, arena);
134     }
135 
136     protected abstract UpcallStubFactory arrangeUpcall(MethodType targetType, FunctionDescriptor function, LinkerOptions options);
137 
138     @Override
139     public SystemLookup defaultLookup() {
140         return SystemLookup.getInstance();
141     }
142 
143     /** {@return byte order used by this linker} */
144     protected abstract ByteOrder linkerByteOrder();
145 
146     // C spec mandates that variadic arguments smaller than int are promoted to int,
147     // and float is promoted to double
148     // See: https://en.cppreference.com/w/c/language/conversion#Default_argument_promotions
149     // We reject the corresponding layouts here, to avoid issues where unsigned values
150     // are sign extended when promoted. (as we don't have a way to unambiguously represent signed-ness atm).
151     private void validateVariadicLayouts(FunctionDescriptor function, LinkerOptions optionSet) {
152         if (optionSet.isVariadicFunction()) {
153             List<MemoryLayout> argumentLayouts = function.argumentLayouts();
154             List<MemoryLayout> variadicLayouts = argumentLayouts.subList(optionSet.firstVariadicArgIndex(), argumentLayouts.size());
155 
156             for (MemoryLayout variadicLayout : variadicLayouts) {
157                 if (variadicLayout.equals(ValueLayout.JAVA_BOOLEAN)
158                     || variadicLayout.equals(ValueLayout.JAVA_BYTE)
159                     || variadicLayout.equals(ValueLayout.JAVA_CHAR)
160                     || variadicLayout.equals(ValueLayout.JAVA_SHORT)
161                     || variadicLayout.equals(ValueLayout.JAVA_FLOAT)) {
162                     throw new IllegalArgumentException("Invalid variadic argument layout: " + variadicLayout);
163                 }
164             }
165         }
166     }
167 
168     private void checkLayouts(FunctionDescriptor descriptor) {
169         descriptor.returnLayout().ifPresent(this::checkLayout);
170         descriptor.argumentLayouts().forEach(this::checkLayout);
171     }
172 
173     private void checkLayout(MemoryLayout layout) {
174         // Note: we should not worry about padding layouts, as they cannot be present in a function descriptor
175         if (layout instanceof SequenceLayout) {
176             throw new IllegalArgumentException("Unsupported layout: " + layout);
177         } else {
178             checkLayoutRecursive(layout);
179         }
180     }
181 
182     private void checkLayoutRecursive(MemoryLayout layout) {
183         if (layout instanceof ValueLayout vl) {
184             checkSupported(vl);
185         } else if (layout instanceof StructLayout sl) {
186             checkHasNaturalAlignment(layout);
187             long offset = 0;
188             long lastUnpaddedOffset = 0;
189             for (MemoryLayout member : sl.memberLayouts()) {
190                 // check element offset before recursing so that an error points at the
191                 // outermost layout first
192                 checkMemberOffset(sl, member, lastUnpaddedOffset, offset);
193                 checkLayoutRecursive(member);
194 
195                 offset += member.byteSize();
196                 if (!(member instanceof PaddingLayout)) {
197                     lastUnpaddedOffset = offset;
198                 }
199             }
200             checkGroupSize(sl, lastUnpaddedOffset);
201         } else if (layout instanceof UnionLayout ul) {
202             checkHasNaturalAlignment(layout);
203             long maxUnpaddedLayout = 0;
204             for (MemoryLayout member : ul.memberLayouts()) {
205                 checkLayoutRecursive(member);
206                 if (!(member instanceof PaddingLayout)) {
207                     maxUnpaddedLayout = Long.max(maxUnpaddedLayout, member.byteSize());
208                 }
209             }
210             checkGroupSize(ul, maxUnpaddedLayout);
211         } else if (layout instanceof SequenceLayout sl) {
212             checkHasNaturalAlignment(layout);
213             checkLayoutRecursive(sl.elementLayout());
214         }
215     }
216 
217     // check for trailing padding
218     private void checkGroupSize(GroupLayout gl, long maxUnpaddedOffset) {
219         long expectedSize = Utils.alignUp(maxUnpaddedOffset, gl.byteAlignment());
220         if (gl.byteSize() != expectedSize) {
221             throw new IllegalArgumentException("Layout '" + gl + "' has unexpected size: "
222                     + gl.byteSize() + " != " + expectedSize);
223         }
224     }
225 
226     // checks both that there is no excess padding between 'memberLayout' and
227     // the previous layout
228     private void checkMemberOffset(StructLayout parent, MemoryLayout memberLayout,
229                                           long lastUnpaddedOffset, long offset) {
230         long expectedOffset = Utils.alignUp(lastUnpaddedOffset, memberLayout.byteAlignment());
231         if (expectedOffset != offset) {
232             throw new IllegalArgumentException("Member layout '" + memberLayout + "', of '" + parent + "'" +
233                     " found at unexpected offset: " + offset + " != " + expectedOffset);
234         }
235     }
236 
237     private void checkSupported(ValueLayout valueLayout) {
238         valueLayout = valueLayout.withoutName();
239         if (valueLayout instanceof AddressLayout addressLayout) {
240             valueLayout = addressLayout.withoutTargetLayout();
241         }
242         if (!CANONICAL_LAYOUTS_CACHE.contains(valueLayout.withoutName())) {
243             throw new IllegalArgumentException("Unsupported layout: " + valueLayout);
244         }
245     }
246 
247     private void checkHasNaturalAlignment(MemoryLayout layout) {
248         if (!((AbstractLayout<?>) layout).hasNaturalAlignment()) {
249             throw new IllegalArgumentException("Layout alignment must be natural alignment: " + layout);
250         }
251     }
252 
253     private static MemoryLayout stripNames(MemoryLayout ml) {
254         // we don't care about transferring alignment and byte order here
255         // since the linker already restricts those such that they will always be the same
256         return switch (ml) {
257             case StructLayout sl -> MemoryLayout.structLayout(stripNames(sl.memberLayouts()));
258             case UnionLayout ul -> MemoryLayout.unionLayout(stripNames(ul.memberLayouts()));
259             case SequenceLayout sl -> MemoryLayout.sequenceLayout(sl.elementCount(), stripNames(sl.elementLayout()));
260             case AddressLayout al -> al.targetLayout()
261                     .map(tl -> al.withoutName().withTargetLayout(stripNames(tl)))
262                     .orElseGet(al::withoutName);
263             default -> ml.withoutName(); // ValueLayout and PaddingLayout
264         };
265     }
266 
267     private static MemoryLayout[] stripNames(List<MemoryLayout> layouts) {
268         return layouts.stream()
269                 .map(AbstractLinker::stripNames)
270                 .toArray(MemoryLayout[]::new);
271     }
272 
273     private static FunctionDescriptor stripNames(FunctionDescriptor function) {
274         return function.returnLayout()
275                 .map(rl -> FunctionDescriptor.of(stripNames(rl), stripNames(function.argumentLayouts())))
276                 .orElseGet(() -> FunctionDescriptor.ofVoid(stripNames(function.argumentLayouts())));
277     }
278 }