1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  *  This code is free software; you can redistribute it and/or modify it
  6  *  under the terms of the GNU General Public License version 2 only, as
  7  *  published by the Free Software Foundation.
  8  *
  9  *  This code is distributed in the hope that it will be useful, but WITHOUT
 10  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  *  version 2 for more details (a copy is included in the LICENSE file that
 13  *  accompanied this code).
 14  *
 15  *  You should have received a copy of the GNU General Public License version
 16  *  2 along with this work; if not, write to the Free Software Foundation,
 17  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  *  or visit www.oracle.com if you need additional information or have any
 21  *  questions.
 22  *
 23  */
 24 
 25 /*
 26  * @test
 27  * @modules java.base/jdk.internal.foreign
 28  * @run testng/othervm --enable-native-access=ALL-UNNAMED -Dgenerator.sample.factor=17 TestVarArgs
 29  */
 30 
 31 import java.lang.foreign.Arena;
 32 import java.lang.foreign.Linker;
 33 import java.lang.foreign.FunctionDescriptor;
 34 import java.lang.foreign.MemoryLayout;
 35 import java.lang.foreign.ValueLayout;
 36 import java.lang.foreign.MemorySegment;
 37 
 38 import org.testng.annotations.DataProvider;
 39 import org.testng.annotations.Test;
 40 
 41 import java.lang.foreign.ValueLayout;
 42 import java.lang.invoke.MethodHandle;
 43 import java.lang.invoke.MethodHandles;
 44 import java.lang.invoke.MethodType;
 45 import java.lang.invoke.VarHandle;
 46 import java.util.ArrayList;
 47 import java.util.List;
 48 
 49 import static java.lang.foreign.MemoryLayout.PathElement.*;
 50 
 51 public class TestVarArgs extends CallGeneratorHelper {
 52 
 53     static final MethodHandle MH_CHECK;
 54 
 55     static final Linker LINKER = Linker.nativeLinker();
 56     static {
 57         System.loadLibrary("VarArgs");
 58         try {
 59             MH_CHECK = MethodHandles.lookup().findStatic(TestVarArgs.class, "check",
 60                     MethodType.methodType(void.class, int.class, MemorySegment.class, List.class));
 61         } catch (ReflectiveOperationException e) {
 62             throw new ExceptionInInitializerError(e);
 63         }
 64     }
 65 
 66     static final MemorySegment VARARGS_ADDR = findNativeOrThrow("varargs");
 67 
 68     @Test(dataProvider = "variadicFunctions")
 69     public void testVarArgs(int count, String fName, Ret ret, // ignore this stuff
 70                             List<ParamType> paramTypes, List<StructFieldType> fields) throws Throwable {
 71         try (Arena arena = Arena.ofConfined()) {
 72             List<Arg> args = makeArgs(arena, paramTypes, fields);
 73             MethodHandle checker = MethodHandles.insertArguments(MH_CHECK, 2, args);
 74             MemorySegment writeBack = LINKER.upcallStub(checker, FunctionDescriptor.ofVoid(C_INT, C_POINTER), arena);
 75             MemorySegment callInfo = arena.allocate(CallInfo.LAYOUT);
 76             MemoryLayout layout = MemoryLayout.sequenceLayout(args.size(), C_INT);
 77             MemorySegment argIDs = arena.allocate(layout);
 78 
 79             CallInfo.writeback(callInfo, writeBack);
 80             CallInfo.argIDs(callInfo, argIDs);
 81 
 82             for (int i = 0; i < args.size(); i++) {
 83                 argIDs.setAtIndex(ValueLayout.JAVA_INT, i, args.get(i).id.ordinal());
 84             }
 85 
 86             List<MemoryLayout> argLayouts = new ArrayList<>();
 87             argLayouts.add(C_POINTER); // call info
 88             argLayouts.add(C_INT); // size
 89 
 90             FunctionDescriptor baseDesc = FunctionDescriptor.ofVoid(argLayouts.toArray(MemoryLayout[]::new));
 91             Linker.Option varargIndex = Linker.Option.firstVariadicArg(baseDesc.argumentLayouts().size());
 92             FunctionDescriptor desc = baseDesc.appendArgumentLayouts(args.stream().map(a -> a.layout).toArray(MemoryLayout[]::new));
 93 
 94             MethodHandle downcallHandle = LINKER.downcallHandle(VARARGS_ADDR, desc, varargIndex);
 95 
 96             List<Object> argValues = new ArrayList<>();
 97             argValues.add(callInfo); // call info
 98             argValues.add(args.size());  // size
 99             args.forEach(a -> argValues.add(a.value()));
100 
101             downcallHandle.invokeWithArguments(argValues);
102 
103             // args checked by upcall
104         }
105     }
106 
107     private static List<ParamType> createParameterTypesForStruct(int extraIntArgs) {
108         List<ParamType> paramTypes = new ArrayList<ParamType>();
109         for (int i = 0; i < extraIntArgs; i++) {
110             paramTypes.add(ParamType.INT);
111         }
112         paramTypes.add(ParamType.STRUCT);
113         return paramTypes;
114     }
115 
116     private static List<StructFieldType> createFieldsForStruct(int fieldCount, StructFieldType fieldType) {
117         List<StructFieldType> fields = new ArrayList<StructFieldType>();
118         for (int i = 0; i < fieldCount; i++) {
119             fields.add(fieldType);
120         }
121         return fields;
122     }
123 
124     @DataProvider(name = "variadicFunctions")
125     public static Object[][] variadicFunctions() {
126         List<Object[]> downcalls = new ArrayList<>();
127 
128         var functionsDowncalls = functions();
129         for (var array : functionsDowncalls) {
130             downcalls.add(array);
131         }
132 
133         // Test struct with 4 floats
134         int extraIntArgs = 0;
135         List<StructFieldType> fields = createFieldsForStruct(4, StructFieldType.FLOAT);
136         List<ParamType> paramTypes = createParameterTypesForStruct(extraIntArgs);
137         downcalls.add(new Object[] { 0, "", Ret.VOID, paramTypes, fields });
138 
139         // Test struct with 4 floats without enough registers for all fields
140         extraIntArgs = 6;
141         fields = createFieldsForStruct(4, StructFieldType.FLOAT);
142         paramTypes = createParameterTypesForStruct(extraIntArgs);
143         downcalls.add(new Object[] { 0, "", Ret.VOID, paramTypes, fields });
144 
145         // Test struct with 2 doubles without enough registers for all fields
146         extraIntArgs = 7;
147         fields = createFieldsForStruct(2, StructFieldType.DOUBLE);
148         paramTypes = createParameterTypesForStruct(extraIntArgs);
149         downcalls.add(new Object[] { 0, "", Ret.VOID, paramTypes, fields });
150 
151         // Test struct with 2 ints without enough registers for all fields
152         fields = createFieldsForStruct(2, StructFieldType.INT);
153         paramTypes = createParameterTypesForStruct(extraIntArgs);
154         downcalls.add(new Object[] { 0, "", Ret.VOID, paramTypes, fields });
155 
156         return downcalls.toArray(new Object[0][]);
157     }
158 
159     private static List<Arg> makeArgs(Arena arena, List<ParamType> paramTypes, List<StructFieldType> fields) {
160         List<Arg> args = new ArrayList<>();
161         for (ParamType pType : paramTypes) {
162             MemoryLayout layout = pType.layout(fields);
163             if (layout instanceof ValueLayout.OfFloat) {
164                 layout = C_DOUBLE; // promote to double, per C spec
165             }
166             TestValue testValue = genTestValue(layout, arena);
167             Arg.NativeType type = Arg.NativeType.of(pType.type(fields));
168             args.add(pType == ParamType.STRUCT
169                 ? Arg.structArg(type, layout, testValue)
170                 : Arg.primitiveArg(type, layout, testValue));
171         }
172         return args;
173     }
174 
175     private static void check(int index, MemorySegment ptr, List<Arg> args) {
176         Arg varArg = args.get(index);
177         MemoryLayout layout = varArg.layout;
178         MethodHandle getter = varArg.getter;
179         try (Arena arena = Arena.ofConfined()) {
180             MemorySegment seg = ptr.asSlice(0, layout)
181                     .reinterpret(arena, null);
182             Object obj = getter.invoke(seg);
183             varArg.check(obj);
184         } catch (Throwable e) {
185             throw new RuntimeException(e);
186         }
187     }
188 
189     private static class CallInfo {
190         static final MemoryLayout LAYOUT = MemoryLayout.structLayout(
191                 C_POINTER.withName("writeback"), // writeback
192                 C_POINTER.withName("argIDs")); // arg ids
193 
194         static final VarHandle VH_writeback = LAYOUT.varHandle(groupElement("writeback"));
195         static final VarHandle VH_argIDs = LAYOUT.varHandle(groupElement("argIDs"));
196 
197         static void writeback(MemorySegment seg, MemorySegment addr) {
198             VH_writeback.set(seg, 0L, addr);
199         }
200         static void argIDs(MemorySegment seg, MemorySegment addr) {
201             VH_argIDs.set(seg, 0L, addr);
202         }
203     }
204 
205     private static final class Arg {
206         private final TestValue value;
207 
208         final NativeType id;
209         final MemoryLayout layout;
210         final MethodHandle getter;
211 
212         private Arg(NativeType id, MemoryLayout layout, TestValue value, MethodHandle getter) {
213             this.id = id;
214             this.layout = layout;
215             this.value = value;
216             this.getter = getter;
217         }
218 
219         private static Arg primitiveArg(NativeType id, MemoryLayout layout, TestValue value) {
220             MethodHandle getterHandle = layout.varHandle().toMethodHandle(VarHandle.AccessMode.GET);
221             getterHandle = MethodHandles.insertArguments(getterHandle, 1, 0L); // align signature with getter for structs
222             return new Arg(id, layout, value, getterHandle);
223         }
224 
225         private static Arg structArg(NativeType id, MemoryLayout layout, TestValue value) {
226             return new Arg(id, layout, value, MethodHandles.identity(MemorySegment.class));
227         }
228 
229         public void check(Object actual) {
230             value.check().accept(actual);
231         }
232 
233         public Object value() {
234             return value.value();
235         }
236 
237         enum NativeType {
238             INT,
239             DOUBLE,
240             POINTER,
241             S_I,
242             S_F,
243             S_D,
244             S_P,
245             S_II,
246             S_IF,
247             S_ID,
248             S_IP,
249             S_FI,
250             S_FF,
251             S_FD,
252             S_FP,
253             S_DI,
254             S_DF,
255             S_DD,
256             S_DP,
257             S_PI,
258             S_PF,
259             S_PD,
260             S_PP,
261             S_III,
262             S_IIF,
263             S_IID,
264             S_IIP,
265             S_IFI,
266             S_IFF,
267             S_IFD,
268             S_IFP,
269             S_IDI,
270             S_IDF,
271             S_IDD,
272             S_IDP,
273             S_IPI,
274             S_IPF,
275             S_IPD,
276             S_IPP,
277             S_FII,
278             S_FIF,
279             S_FID,
280             S_FIP,
281             S_FFI,
282             S_FFF,
283             S_FFD,
284             S_FFP,
285             S_FDI,
286             S_FDF,
287             S_FDD,
288             S_FDP,
289             S_FPI,
290             S_FPF,
291             S_FPD,
292             S_FPP,
293             S_DII,
294             S_DIF,
295             S_DID,
296             S_DIP,
297             S_DFI,
298             S_DFF,
299             S_DFD,
300             S_DFP,
301             S_DDI,
302             S_DDF,
303             S_DDD,
304             S_DDP,
305             S_DPI,
306             S_DPF,
307             S_DPD,
308             S_DPP,
309             S_PII,
310             S_PIF,
311             S_PID,
312             S_PIP,
313             S_PFI,
314             S_PFF,
315             S_PFD,
316             S_PFP,
317             S_PDI,
318             S_PDF,
319             S_PDD,
320             S_PDP,
321             S_PPI,
322             S_PPF,
323             S_PPD,
324             S_PPP,
325             S_FFFF,
326             ;
327 
328             public static NativeType of(String type) {
329                 return NativeType.valueOf(switch (type) {
330                     case "int" -> "INT";
331                     case "float" -> "DOUBLE"; // promote
332                     case "double" -> "DOUBLE";
333                     case "void*" -> "POINTER";
334                     default -> type.substring("struct ".length());
335                 });
336             }
337         }
338     }
339 
340 }