1 /*
  2  *  Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved.
  3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  *  This code is free software; you can redistribute it and/or modify it
  6  *  under the terms of the GNU General Public License version 2 only, as
  7  *  published by the Free Software Foundation.  Oracle designates this
  8  *  particular file as subject to the "Classpath" exception as provided
  9  *  by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  *  This code is distributed in the hope that it will be useful, but WITHOUT
 12  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  *  version 2 for more details (a copy is included in the LICENSE file that
 15  *  accompanied this code).
 16  *
 17  *  You should have received a copy of the GNU General Public License version
 18  *  2 along with this work; if not, write to the Free Software Foundation,
 19  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  *  or visit www.oracle.com if you need additional information or have any
 23  *  questions.
 24  *
 25  */
 26 package jdk.internal.foreign.abi.x64.sysv;
 27 
 28 import jdk.incubator.foreign.FunctionDescriptor;
 29 import jdk.incubator.foreign.GroupLayout;
 30 import jdk.incubator.foreign.MemoryAddress;
 31 import jdk.incubator.foreign.MemoryLayout;
 32 import jdk.incubator.foreign.MemorySegment;
 33 import jdk.incubator.foreign.NativeSymbol;
 34 import jdk.incubator.foreign.ResourceScope;
 35 import jdk.internal.foreign.abi.CallingSequenceBuilder;
 36 import jdk.internal.foreign.abi.ABIDescriptor;
 37 import jdk.internal.foreign.abi.Binding;
 38 import jdk.internal.foreign.abi.CallingSequence;
 39 import jdk.internal.foreign.abi.ProgrammableInvoker;
 40 import jdk.internal.foreign.abi.ProgrammableUpcallHandler;
 41 import jdk.internal.foreign.abi.VMStorage;
 42 import jdk.internal.foreign.abi.x64.X86_64Architecture;
 43 import jdk.internal.foreign.abi.SharedUtils;
 44 
 45 import java.lang.invoke.MethodHandle;
 46 import java.lang.invoke.MethodHandles;
 47 import java.lang.invoke.MethodType;
 48 import java.util.List;
 49 import java.util.Optional;
 50 
 51 import static jdk.internal.foreign.PlatformLayouts.*;
 52 import static jdk.internal.foreign.abi.Binding.*;
 53 import static jdk.internal.foreign.abi.x64.X86_64Architecture.*;
 54 import static jdk.internal.foreign.abi.x64.sysv.SysVx64Linker.MAX_INTEGER_ARGUMENT_REGISTERS;
 55 import static jdk.internal.foreign.abi.x64.sysv.SysVx64Linker.MAX_VECTOR_ARGUMENT_REGISTERS;
 56 
 57 /**
 * For the SysV x64 C ABI specifically, this class uses the ProgrammableInvoker API, namely CallingSequenceBuilder,
 59  * to translate a C FunctionDescriptor into a CallingSequence, which can then be turned into a MethodHandle.
 60  *
 61  * This includes taking care of synthetic arguments like pointers to return buffers for 'in-memory' returns.
 62  */
 63 public class CallArranger {
 64     private static final ABIDescriptor CSysV = X86_64Architecture.abiFor(
 65         new VMStorage[] { rdi, rsi, rdx, rcx, r8, r9, rax },
 66         new VMStorage[] { xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7 },
 67         new VMStorage[] { rax, rdx },
 68         new VMStorage[] { xmm0, xmm1 },
 69         2,
 70         new VMStorage[] { r10, r11 },
 71         new VMStorage[] { xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15 },
 72         16,
 73         0 //no shadow space
 74     );
 75 
 76     // record
 77     public static class Bindings {
 78         public final CallingSequence callingSequence;
 79         public final boolean isInMemoryReturn;
 80         public final int nVectorArgs;
 81 
 82         Bindings(CallingSequence callingSequence, boolean isInMemoryReturn, int nVectorArgs) {
 83             this.callingSequence = callingSequence;
 84             this.isInMemoryReturn = isInMemoryReturn;
 85             this.nVectorArgs = nVectorArgs;
 86         }
 87     }
 88 
 89     public static Bindings getBindings(MethodType mt, FunctionDescriptor cDesc, boolean forUpcall) {
 90         CallingSequenceBuilder csb = new CallingSequenceBuilder(forUpcall);
 91 
 92         BindingCalculator argCalc = forUpcall ? new BoxBindingCalculator(true) : new UnboxBindingCalculator(true);
 93         BindingCalculator retCalc = forUpcall ? new UnboxBindingCalculator(false) : new BoxBindingCalculator(false);
 94 
 95         boolean returnInMemory = isInMemoryReturn(cDesc.returnLayout());
 96         if (returnInMemory) {
 97             Class<?> carrier = MemoryAddress.class;
 98             MemoryLayout layout = SysV.C_POINTER;
 99             csb.addArgumentBindings(carrier, layout, argCalc.getBindings(carrier, layout));
100         } else if (cDesc.returnLayout().isPresent()) {
101             Class<?> carrier = mt.returnType();
102             MemoryLayout layout = cDesc.returnLayout().get();
103             csb.setReturnBindings(carrier, layout, retCalc.getBindings(carrier, layout));
104         }
105 
106         for (int i = 0; i < mt.parameterCount(); i++) {
107             Class<?> carrier = mt.parameterType(i);
108             MemoryLayout layout = cDesc.argumentLayouts().get(i);
109             csb.addArgumentBindings(carrier, layout, argCalc.getBindings(carrier, layout));
110         }
111 
112         if (!forUpcall) {
113             //add extra binding for number of used vector registers (used for variadic calls)
114             csb.addArgumentBindings(long.class, SysV.C_LONG,
115                     List.of(vmStore(rax, long.class)));
116         }
117 
118         csb.setTrivial(SharedUtils.isTrivial(cDesc));
119 
120         return new Bindings(csb.build(), returnInMemory, argCalc.storageCalculator.nVectorReg);
121     }
122 
123     public static MethodHandle arrangeDowncall(MethodType mt, FunctionDescriptor cDesc) {
124         Bindings bindings = getBindings(mt, cDesc, false);
125 
126         MethodHandle handle = new ProgrammableInvoker(CSysV, bindings.callingSequence).getBoundMethodHandle();
127         handle = MethodHandles.insertArguments(handle, handle.type().parameterCount() - 1, bindings.nVectorArgs);
128 
129         if (bindings.isInMemoryReturn) {
130             handle = SharedUtils.adaptDowncallForIMR(handle, cDesc);
131         }
132 
133         return handle;
134     }
135 
136     public static NativeSymbol arrangeUpcall(MethodHandle target, MethodType mt, FunctionDescriptor cDesc, ResourceScope scope) {
137         Bindings bindings = getBindings(mt, cDesc, true);
138 
139         if (bindings.isInMemoryReturn) {
140             target = SharedUtils.adaptUpcallForIMR(target, true /* drop return, since we don't have bindings for it */);
141         }
142 
143         return ProgrammableUpcallHandler.make(CSysV, target, bindings.callingSequence, scope);
144     }
145 
146     private static boolean isInMemoryReturn(Optional<MemoryLayout> returnLayout) {
147         return returnLayout
148                 .filter(GroupLayout.class::isInstance)
149                 .filter(g -> TypeClass.classifyLayout(g).inMemory())
150                 .isPresent();
151     }
152 
153     static class StorageCalculator {
154         private final boolean forArguments;
155 
156         private int nVectorReg = 0;
157         private int nIntegerReg = 0;
158         private long stackOffset = 0;
159 
160         public StorageCalculator(boolean forArguments) {
161             this.forArguments = forArguments;
162         }
163 
164         private int maxRegisterArguments(int type) {
165             return type == StorageClasses.INTEGER ?
166                     MAX_INTEGER_ARGUMENT_REGISTERS :
167                     SysVx64Linker.MAX_VECTOR_ARGUMENT_REGISTERS;
168         }
169 
170         VMStorage stackAlloc() {
171             assert forArguments : "no stack returns";
172             VMStorage storage = X86_64Architecture.stackStorage((int)stackOffset);
173             stackOffset++;
174             return storage;
175         }
176 
177         VMStorage nextStorage(int type) {
178             int registerCount = registerCount(type);
179             if (registerCount < maxRegisterArguments(type)) {
180                 VMStorage[] source =
181                     (forArguments ? CSysV.inputStorage : CSysV.outputStorage)[type];
182                 incrementRegisterCount(type);
183                 return source[registerCount];
184             } else {
185                 return stackAlloc();
186             }
187         }
188 
189         VMStorage[] structStorages(TypeClass typeClass) {
190             if (typeClass.inMemory()) {
191                 return typeClass.classes.stream().map(c -> stackAlloc()).toArray(VMStorage[]::new);
192             }
193             long nIntegerReg = typeClass.nIntegerRegs();
194 
195             if (this.nIntegerReg + nIntegerReg > MAX_INTEGER_ARGUMENT_REGISTERS) {
196                 //not enough registers - pass on stack
197                 return typeClass.classes.stream().map(c -> stackAlloc()).toArray(VMStorage[]::new);
198             }
199 
200             long nVectorReg = typeClass.nVectorRegs();
201 
202             if (this.nVectorReg + nVectorReg > MAX_VECTOR_ARGUMENT_REGISTERS) {
203                 //not enough registers - pass on stack
204                 return typeClass.classes.stream().map(c -> stackAlloc()).toArray(VMStorage[]::new);
205             }
206 
207             //ok, let's pass pass on registers
208             VMStorage[] storage = new VMStorage[(int)(nIntegerReg + nVectorReg)];
209             for (int i = 0 ; i < typeClass.classes.size() ; i++) {
210                 boolean sse = typeClass.classes.get(i) == ArgumentClassImpl.SSE;
211                 storage[i] = nextStorage(sse ? StorageClasses.VECTOR : StorageClasses.INTEGER);
212             }
213             return storage;
214         }
215 
216         int registerCount(int type) {
217             switch (type) {
218                 case StorageClasses.INTEGER:
219                     return nIntegerReg;
220                 case StorageClasses.VECTOR:
221                     return nVectorReg;
222                 default:
223                     throw new IllegalStateException();
224             }
225         }
226 
227         void incrementRegisterCount(int type) {
228             switch (type) {
229                 case StorageClasses.INTEGER:
230                     nIntegerReg++;
231                     break;
232                 case StorageClasses.VECTOR:
233                     nVectorReg++;
234                     break;
235                 default:
236                     throw new IllegalStateException();
237             }
238         }
239     }
240 
241     static abstract class BindingCalculator {
242         protected final StorageCalculator storageCalculator;
243 
244         protected BindingCalculator(boolean forArguments) {
245             this.storageCalculator = new StorageCalculator(forArguments);
246         }
247 
248         abstract List<Binding> getBindings(Class<?> carrier, MemoryLayout layout);
249     }
250 
251     static class UnboxBindingCalculator extends BindingCalculator {
252 
253         UnboxBindingCalculator(boolean forArguments) {
254             super(forArguments);
255         }
256 
257         @Override
258         List<Binding> getBindings(Class<?> carrier, MemoryLayout layout) {
259             TypeClass argumentClass = TypeClass.classifyLayout(layout);
260             Binding.Builder bindings = Binding.builder();
261             switch (argumentClass.kind()) {
262                 case STRUCT: {
263                     assert carrier == MemorySegment.class;
264                     VMStorage[] regs = storageCalculator.structStorages(argumentClass);
265                     int regIndex = 0;
266                     long offset = 0;
267                     while (offset < layout.byteSize()) {
268                         final long copy = Math.min(layout.byteSize() - offset, 8);
269                         VMStorage storage = regs[regIndex++];
270                         if (offset + copy < layout.byteSize()) {
271                             bindings.dup();
272                         }
273                         boolean useFloat = storage.type() == StorageClasses.VECTOR;
274                         Class<?> type = SharedUtils.primitiveCarrierForSize(copy, useFloat);
275                         bindings.bufferLoad(offset, type)
276                                 .vmStore(storage, type);
277                         offset += copy;
278                     }
279                     break;
280                 }
281                 case POINTER: {
282                     bindings.unboxAddress(carrier);
283                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.INTEGER);
284                     bindings.vmStore(storage, long.class);
285                     break;
286                 }
287                 case INTEGER: {
288                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.INTEGER);
289                     bindings.vmStore(storage, carrier);
290                     break;
291                 }
292                 case FLOAT: {
293                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.VECTOR);
294                     bindings.vmStore(storage, carrier);
295                     break;
296                 }
297                 default:
298                     throw new UnsupportedOperationException("Unhandled class " + argumentClass);
299             }
300             return bindings.build();
301         }
302     }
303 
304     static class BoxBindingCalculator extends BindingCalculator {
305 
306         BoxBindingCalculator(boolean forArguments) {
307             super(forArguments);
308         }
309 
310         @Override
311         List<Binding> getBindings(Class<?> carrier, MemoryLayout layout) {
312             TypeClass argumentClass = TypeClass.classifyLayout(layout);
313             Binding.Builder bindings = Binding.builder();
314             switch (argumentClass.kind()) {
315                 case STRUCT: {
316                     assert carrier == MemorySegment.class;
317                     bindings.allocate(layout);
318                     VMStorage[] regs = storageCalculator.structStorages(argumentClass);
319                     int regIndex = 0;
320                     long offset = 0;
321                     while (offset < layout.byteSize()) {
322                         final long copy = Math.min(layout.byteSize() - offset, 8);
323                         VMStorage storage = regs[regIndex++];
324                         bindings.dup();
325                         boolean useFloat = storage.type() == StorageClasses.VECTOR;
326                         Class<?> type = SharedUtils.primitiveCarrierForSize(copy, useFloat);
327                         bindings.vmLoad(storage, type)
328                                 .bufferStore(offset, type);
329                         offset += copy;
330                     }
331                     break;
332                 }
333                 case POINTER: {
334                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.INTEGER);
335                     bindings.vmLoad(storage, long.class)
336                             .boxAddress();
337                     break;
338                 }
339                 case INTEGER: {
340                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.INTEGER);
341                     bindings.vmLoad(storage, carrier);
342                     break;
343                 }
344                 case FLOAT: {
345                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.VECTOR);
346                     bindings.vmLoad(storage, carrier);
347                     break;
348                 }
349                 default:
350                     throw new UnsupportedOperationException("Unhandled class " + argumentClass);
351             }
352             return bindings.build();
353         }
354     }
355 
356 }