1 /*
  2  *  Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved.
  3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  *  This code is free software; you can redistribute it and/or modify it
  6  *  under the terms of the GNU General Public License version 2 only, as
  7  *  published by the Free Software Foundation.  Oracle designates this
  8  *  particular file as subject to the "Classpath" exception as provided
  9  *  by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  *  This code is distributed in the hope that it will be useful, but WITHOUT
 12  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  *  version 2 for more details (a copy is included in the LICENSE file that
 15  *  accompanied this code).
 16  *
 17  *  You should have received a copy of the GNU General Public License version
 18  *  2 along with this work; if not, write to the Free Software Foundation,
 19  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  *  or visit www.oracle.com if you need additional information or have any
 23  *  questions.
 24  *
 25  */
 26 package jdk.internal.foreign.abi.x64.sysv;
 27 
 28 import jdk.incubator.foreign.FunctionDescriptor;
 29 import jdk.incubator.foreign.GroupLayout;
 30 import jdk.incubator.foreign.MemoryAddress;
 31 import jdk.incubator.foreign.MemoryLayout;
 32 import jdk.incubator.foreign.MemorySegment;
 33 import jdk.internal.foreign.abi.CallingSequenceBuilder;
 34 import jdk.internal.foreign.abi.UpcallHandler;
 35 import jdk.internal.foreign.abi.ABIDescriptor;
 36 import jdk.internal.foreign.abi.Binding;
 37 import jdk.internal.foreign.abi.CallingSequence;
 38 import jdk.internal.foreign.abi.ProgrammableInvoker;
 39 import jdk.internal.foreign.abi.ProgrammableUpcallHandler;
 40 import jdk.internal.foreign.abi.VMStorage;
 41 import jdk.internal.foreign.abi.x64.X86_64Architecture;
 42 import jdk.internal.foreign.abi.SharedUtils;
 43 
 44 import java.lang.invoke.MethodHandle;
 45 import java.lang.invoke.MethodHandles;
 46 import java.lang.invoke.MethodType;
 47 import java.util.List;
 48 import java.util.Optional;
 49 
 50 import static jdk.internal.foreign.PlatformLayouts.*;
 51 import static jdk.internal.foreign.abi.Binding.*;
 52 import static jdk.internal.foreign.abi.x64.X86_64Architecture.*;
 53 import static jdk.internal.foreign.abi.x64.sysv.SysVx64Linker.MAX_INTEGER_ARGUMENT_REGISTERS;
 54 import static jdk.internal.foreign.abi.x64.sysv.SysVx64Linker.MAX_VECTOR_ARGUMENT_REGISTERS;
 55 
 56 /**
 * For the SysV x64 C ABI specifically, this class uses the ProgrammableInvoker API, namely CallingSequenceBuilder,
 * to translate a C FunctionDescriptor into a CallingSequence, which can then be turned into a MethodHandle.
 59  *
 60  * This includes taking care of synthetic arguments like pointers to return buffers for 'in-memory' returns.
 61  */
 62 public class CallArranger {
    /*
     * ABI descriptor for SysV x64, built by X86_64Architecture.abiFor. It is handed to
     * ProgrammableInvoker (downcalls) and ProgrammableUpcallHandler (upcalls) below, and
     * StorageCalculator picks registers from its inputStorage (arguments) and outputStorage
     * (returns) tables, indexed by storage class.
     *
     * NOTE(review): the argument order appears to be integer/vector argument registers,
     * integer/vector return registers, a count (2), two further register sets, 16 (presumably
     * the stack alignment) and the shadow space size -- confirm against X86_64Architecture.abiFor.
     * rax is listed with the first register set; getBindings stores the number of used vector
     * registers into rax for downcalls.
     */
    private static final ABIDescriptor CSysV = X86_64Architecture.abiFor(
        new VMStorage[] { rdi, rsi, rdx, rcx, r8, r9, rax },
        new VMStorage[] { xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7 },
        new VMStorage[] { rax, rdx },
        new VMStorage[] { xmm0, xmm1 },
        2,
        new VMStorage[] { r10, r11 },
        new VMStorage[] { xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15 },
        16,
        0 //no shadow space
    );
 74 
 75     // record
 76     public static class Bindings {
 77         public final CallingSequence callingSequence;
 78         public final boolean isInMemoryReturn;
 79         public final int nVectorArgs;
 80 
 81         Bindings(CallingSequence callingSequence, boolean isInMemoryReturn, int nVectorArgs) {
 82             this.callingSequence = callingSequence;
 83             this.isInMemoryReturn = isInMemoryReturn;
 84             this.nVectorArgs = nVectorArgs;
 85         }
 86     }
 87 
 88     public static Bindings getBindings(MethodType mt, FunctionDescriptor cDesc, boolean forUpcall) {
 89         SharedUtils.checkFunctionTypes(mt, cDesc, SysVx64Linker.ADDRESS_SIZE);
 90 
 91         CallingSequenceBuilder csb = new CallingSequenceBuilder(forUpcall);
 92 
 93         BindingCalculator argCalc = forUpcall ? new BoxBindingCalculator(true) : new UnboxBindingCalculator(true);
 94         BindingCalculator retCalc = forUpcall ? new UnboxBindingCalculator(false) : new BoxBindingCalculator(false);
 95 
 96         boolean returnInMemory = isInMemoryReturn(cDesc.returnLayout());
 97         if (returnInMemory) {
 98             Class<?> carrier = MemoryAddress.class;
 99             MemoryLayout layout = SysV.C_POINTER;
100             csb.addArgumentBindings(carrier, layout, argCalc.getBindings(carrier, layout));
101         } else if (cDesc.returnLayout().isPresent()) {
102             Class<?> carrier = mt.returnType();
103             MemoryLayout layout = cDesc.returnLayout().get();
104             csb.setReturnBindings(carrier, layout, retCalc.getBindings(carrier, layout));
105         }
106 
107         for (int i = 0; i < mt.parameterCount(); i++) {
108             Class<?> carrier = mt.parameterType(i);
109             MemoryLayout layout = cDesc.argumentLayouts().get(i);
110             csb.addArgumentBindings(carrier, layout, argCalc.getBindings(carrier, layout));
111         }
112 
113         if (!forUpcall) {
114             //add extra binding for number of used vector registers (used for variadic calls)
115             csb.addArgumentBindings(long.class, SysV.C_LONG,
116                     List.of(vmStore(rax, long.class)));
117         }
118 
119         csb.setTrivial(SharedUtils.isTrivial(cDesc));
120 
121         return new Bindings(csb.build(), returnInMemory, argCalc.storageCalculator.nVectorReg);
122     }
123 
124     public static MethodHandle arrangeDowncall(MethodType mt, FunctionDescriptor cDesc) {
125         Bindings bindings = getBindings(mt, cDesc, false);
126 
127         MethodHandle handle = new ProgrammableInvoker(CSysV, bindings.callingSequence).getBoundMethodHandle();
128         handle = MethodHandles.insertArguments(handle, handle.type().parameterCount() - 1, bindings.nVectorArgs);
129 
130         if (bindings.isInMemoryReturn) {
131             handle = SharedUtils.adaptDowncallForIMR(handle, cDesc);
132         }
133 
134         return handle;
135     }
136 
137     public static UpcallHandler arrangeUpcall(MethodHandle target, MethodType mt, FunctionDescriptor cDesc) {
138         Bindings bindings = getBindings(mt, cDesc, true);
139 
140         if (bindings.isInMemoryReturn) {
141             target = SharedUtils.adaptUpcallForIMR(target, true /* drop return, since we don't have bindings for it */);
142         }
143 
144         return ProgrammableUpcallHandler.make(CSysV, target, bindings.callingSequence);
145     }
146 
147     private static boolean isInMemoryReturn(Optional<MemoryLayout> returnLayout) {
148         return returnLayout
149                 .filter(GroupLayout.class::isInstance)
150                 .filter(g -> TypeClass.classifyLayout(g).inMemory())
151                 .isPresent();
152     }
153 
154     static class StorageCalculator {
155         private final boolean forArguments;
156 
157         private int nVectorReg = 0;
158         private int nIntegerReg = 0;
159         private long stackOffset = 0;
160 
161         public StorageCalculator(boolean forArguments) {
162             this.forArguments = forArguments;
163         }
164 
165         private int maxRegisterArguments(int type) {
166             return type == StorageClasses.INTEGER ?
167                     MAX_INTEGER_ARGUMENT_REGISTERS :
168                     SysVx64Linker.MAX_VECTOR_ARGUMENT_REGISTERS;
169         }
170 
171         VMStorage stackAlloc() {
172             assert forArguments : "no stack returns";
173             VMStorage storage = X86_64Architecture.stackStorage((int)stackOffset);
174             stackOffset++;
175             return storage;
176         }
177 
178         VMStorage nextStorage(int type) {
179             int registerCount = registerCount(type);
180             if (registerCount < maxRegisterArguments(type)) {
181                 VMStorage[] source =
182                     (forArguments ? CSysV.inputStorage : CSysV.outputStorage)[type];
183                 incrementRegisterCount(type);
184                 return source[registerCount];
185             } else {
186                 return stackAlloc();
187             }
188         }
189 
190         VMStorage[] structStorages(TypeClass typeClass) {
191             if (typeClass.inMemory()) {
192                 return typeClass.classes.stream().map(c -> stackAlloc()).toArray(VMStorage[]::new);
193             }
194             long nIntegerReg = typeClass.nIntegerRegs();
195 
196             if (this.nIntegerReg + nIntegerReg > MAX_INTEGER_ARGUMENT_REGISTERS) {
197                 //not enough registers - pass on stack
198                 return typeClass.classes.stream().map(c -> stackAlloc()).toArray(VMStorage[]::new);
199             }
200 
201             long nVectorReg = typeClass.nVectorRegs();
202 
203             if (this.nVectorReg + nVectorReg > MAX_VECTOR_ARGUMENT_REGISTERS) {
204                 //not enough registers - pass on stack
205                 return typeClass.classes.stream().map(c -> stackAlloc()).toArray(VMStorage[]::new);
206             }
207 
208             //ok, let's pass pass on registers
209             VMStorage[] storage = new VMStorage[(int)(nIntegerReg + nVectorReg)];
210             for (int i = 0 ; i < typeClass.classes.size() ; i++) {
211                 boolean sse = typeClass.classes.get(i) == ArgumentClassImpl.SSE;
212                 storage[i] = nextStorage(sse ? StorageClasses.VECTOR : StorageClasses.INTEGER);
213             }
214             return storage;
215         }
216 
217         int registerCount(int type) {
218             switch (type) {
219                 case StorageClasses.INTEGER:
220                     return nIntegerReg;
221                 case StorageClasses.VECTOR:
222                     return nVectorReg;
223                 default:
224                     throw new IllegalStateException();
225             }
226         }
227 
228         void incrementRegisterCount(int type) {
229             switch (type) {
230                 case StorageClasses.INTEGER:
231                     nIntegerReg++;
232                     break;
233                 case StorageClasses.VECTOR:
234                     nVectorReg++;
235                     break;
236                 default:
237                     throw new IllegalStateException();
238             }
239         }
240     }
241 
242     static abstract class BindingCalculator {
243         protected final StorageCalculator storageCalculator;
244 
245         protected BindingCalculator(boolean forArguments) {
246             this.storageCalculator = new StorageCalculator(forArguments);
247         }
248 
249         abstract List<Binding> getBindings(Class<?> carrier, MemoryLayout layout);
250     }
251 
252     static class UnboxBindingCalculator extends BindingCalculator {
253 
254         UnboxBindingCalculator(boolean forArguments) {
255             super(forArguments);
256         }
257 
258         @Override
259         List<Binding> getBindings(Class<?> carrier, MemoryLayout layout) {
260             TypeClass argumentClass = TypeClass.classifyLayout(layout);
261             Binding.Builder bindings = Binding.builder();
262             switch (argumentClass.kind()) {
263                 case STRUCT: {
264                     assert carrier == MemorySegment.class;
265                     VMStorage[] regs = storageCalculator.structStorages(argumentClass);
266                     int regIndex = 0;
267                     long offset = 0;
268                     while (offset < layout.byteSize()) {
269                         final long copy = Math.min(layout.byteSize() - offset, 8);
270                         VMStorage storage = regs[regIndex++];
271                         if (offset + copy < layout.byteSize()) {
272                             bindings.dup();
273                         }
274                         boolean useFloat = storage.type() == StorageClasses.VECTOR;
275                         Class<?> type = SharedUtils.primitiveCarrierForSize(copy, useFloat);
276                         bindings.bufferLoad(offset, type)
277                                 .vmStore(storage, type);
278                         offset += copy;
279                     }
280                     break;
281                 }
282                 case POINTER: {
283                     bindings.unboxAddress();
284                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.INTEGER);
285                     bindings.vmStore(storage, long.class);
286                     break;
287                 }
288                 case INTEGER: {
289                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.INTEGER);
290                     bindings.vmStore(storage, carrier);
291                     break;
292                 }
293                 case FLOAT: {
294                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.VECTOR);
295                     bindings.vmStore(storage, carrier);
296                     break;
297                 }
298                 default:
299                     throw new UnsupportedOperationException("Unhandled class " + argumentClass);
300             }
301             return bindings.build();
302         }
303     }
304 
305     static class BoxBindingCalculator extends BindingCalculator {
306 
307         BoxBindingCalculator(boolean forArguments) {
308             super(forArguments);
309         }
310 
311         @Override
312         List<Binding> getBindings(Class<?> carrier, MemoryLayout layout) {
313             TypeClass argumentClass = TypeClass.classifyLayout(layout);
314             Binding.Builder bindings = Binding.builder();
315             switch (argumentClass.kind()) {
316                 case STRUCT: {
317                     assert carrier == MemorySegment.class;
318                     bindings.allocate(layout);
319                     VMStorage[] regs = storageCalculator.structStorages(argumentClass);
320                     int regIndex = 0;
321                     long offset = 0;
322                     while (offset < layout.byteSize()) {
323                         final long copy = Math.min(layout.byteSize() - offset, 8);
324                         VMStorage storage = regs[regIndex++];
325                         bindings.dup();
326                         boolean useFloat = storage.type() == StorageClasses.VECTOR;
327                         Class<?> type = SharedUtils.primitiveCarrierForSize(copy, useFloat);
328                         bindings.vmLoad(storage, type)
329                                 .bufferStore(offset, type);
330                         offset += copy;
331                     }
332                     break;
333                 }
334                 case POINTER: {
335                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.INTEGER);
336                     bindings.vmLoad(storage, long.class)
337                             .boxAddress();
338                     break;
339                 }
340                 case INTEGER: {
341                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.INTEGER);
342                     bindings.vmLoad(storage, carrier);
343                     break;
344                 }
345                 case FLOAT: {
346                     VMStorage storage = storageCalculator.nextStorage(StorageClasses.VECTOR);
347                     bindings.vmLoad(storage, carrier);
348                     break;
349                 }
350                 default:
351                     throw new UnsupportedOperationException("Unhandled class " + argumentClass);
352             }
353             return bindings.build();
354         }
355     }
356 
357 }