1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  *  This code is free software; you can redistribute it and/or modify it
  6  *  under the terms of the GNU General Public License version 2 only, as
  7  *  published by the Free Software Foundation.
  8  *
  9  *  This code is distributed in the hope that it will be useful, but WITHOUT
 10  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  *  version 2 for more details (a copy is included in the LICENSE file that
 13  *  accompanied this code).
 14  *
 15  *  You should have received a copy of the GNU General Public License version
 16  *  2 along with this work; if not, write to the Free Software Foundation,
 17  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  *  or visit www.oracle.com if you need additional information or have any
 21  *  questions.
 22  *
 23  */
 24 
 25 /*
 26  * @test
 27  * @enablePreview
 28  * @requires jdk.foreign.linker != "UNSUPPORTED"
 29  * @modules java.base/jdk.internal.foreign
 30  * @run testng/othervm --enable-native-access=ALL-UNNAMED -Dgenerator.sample.factor=17 TestVarArgs
 31  */
 32 
 33 import java.lang.foreign.Arena;
 34 import java.lang.foreign.Linker;
 35 import java.lang.foreign.FunctionDescriptor;
 36 import java.lang.foreign.MemoryLayout;
 37 import java.lang.foreign.ValueLayout;
 38 import java.lang.foreign.MemorySegment;
 39 
 40 import org.testng.annotations.DataProvider;
 41 import org.testng.annotations.Test;
 42 
 43 import java.lang.invoke.MethodHandle;
 44 import java.lang.invoke.MethodHandles;
 45 import java.lang.invoke.MethodType;
 46 import java.lang.invoke.VarHandle;
 47 import java.util.ArrayList;
 48 import java.util.List;
 49 
 50 import static java.lang.foreign.MemoryLayout.PathElement.*;
 51 
 52 public class TestVarArgs extends CallGeneratorHelper {
 53 
 54     static final VarHandle VH_IntArray = C_INT.arrayElementVarHandle();
 55     static final MethodHandle MH_CHECK;
 56 
 57     static final Linker LINKER = Linker.nativeLinker();
 58     static {
 59         System.loadLibrary("VarArgs");
 60         try {
 61             MH_CHECK = MethodHandles.lookup().findStatic(TestVarArgs.class, "check",
 62                     MethodType.methodType(void.class, int.class, MemorySegment.class, List.class));
 63         } catch (ReflectiveOperationException e) {
 64             throw new ExceptionInInitializerError(e);
 65         }
 66     }
 67 
 68     static final MemorySegment VARARGS_ADDR = findNativeOrThrow("varargs");
 69 
 70     @Test(dataProvider = "variadicFunctions")
 71     public void testVarArgs(int count, String fName, Ret ret, // ignore this stuff
 72                             List<ParamType> paramTypes, List<StructFieldType> fields) throws Throwable {
 73         try (Arena arena = Arena.ofConfined()) {
 74             List<Arg> args = makeArgs(arena, paramTypes, fields);
 75             MethodHandle checker = MethodHandles.insertArguments(MH_CHECK, 2, args);
 76             MemorySegment writeBack = LINKER.upcallStub(checker, FunctionDescriptor.ofVoid(C_INT, C_POINTER), arena);
 77             MemorySegment callInfo = arena.allocate(CallInfo.LAYOUT);
 78             MemoryLayout layout = MemoryLayout.sequenceLayout(args.size(), C_INT);
 79             MemorySegment argIDs = arena.allocate(layout);
 80 
 81             CallInfo.writeback(callInfo, writeBack);
 82             CallInfo.argIDs(callInfo, argIDs);
 83 
 84             for (int i = 0; i < args.size(); i++) {
 85                 VH_IntArray.set(argIDs, (long) i, args.get(i).id.ordinal());
 86             }
 87 
 88             List<MemoryLayout> argLayouts = new ArrayList<>();
 89             argLayouts.add(C_POINTER); // call info
 90             argLayouts.add(C_INT); // size
 91 
 92             FunctionDescriptor baseDesc = FunctionDescriptor.ofVoid(argLayouts.toArray(MemoryLayout[]::new));
 93             Linker.Option varargIndex = Linker.Option.firstVariadicArg(baseDesc.argumentLayouts().size());
 94             FunctionDescriptor desc = baseDesc.appendArgumentLayouts(args.stream().map(a -> a.layout).toArray(MemoryLayout[]::new));
 95 
 96             MethodHandle downcallHandle = LINKER.downcallHandle(VARARGS_ADDR, desc, varargIndex);
 97 
 98             List<Object> argValues = new ArrayList<>();
 99             argValues.add(callInfo); // call info
100             argValues.add(args.size());  // size
101             args.forEach(a -> argValues.add(a.value()));
102 
103             downcallHandle.invokeWithArguments(argValues);
104 
105             // args checked by upcall
106         }
107     }
108 
109     private static List<ParamType> createParameterTypesForStruct(int extraIntArgs) {
110         List<ParamType> paramTypes = new ArrayList<ParamType>();
111         for (int i = 0; i < extraIntArgs; i++) {
112             paramTypes.add(ParamType.INT);
113         }
114         paramTypes.add(ParamType.STRUCT);
115         return paramTypes;
116     }
117 
118     private static List<StructFieldType> createFieldsForStruct(int fieldCount, StructFieldType fieldType) {
119         List<StructFieldType> fields = new ArrayList<StructFieldType>();
120         for (int i = 0; i < fieldCount; i++) {
121             fields.add(fieldType);
122         }
123         return fields;
124     }
125 
126     @DataProvider(name = "variadicFunctions")
127     public static Object[][] variadicFunctions() {
128         List<Object[]> downcalls = new ArrayList<>();
129 
130         var functionsDowncalls = functions();
131         for (var array : functionsDowncalls) {
132             downcalls.add(array);
133         }
134 
135         // Test struct with 4 floats
136         int extraIntArgs = 0;
137         List<StructFieldType> fields = createFieldsForStruct(4, StructFieldType.FLOAT);
138         List<ParamType> paramTypes = createParameterTypesForStruct(extraIntArgs);
139         downcalls.add(new Object[] { 0, "", Ret.VOID, paramTypes, fields });
140 
141         // Test struct with 4 floats without enough registers for all fields
142         extraIntArgs = 6;
143         fields = createFieldsForStruct(4, StructFieldType.FLOAT);
144         paramTypes = createParameterTypesForStruct(extraIntArgs);
145         downcalls.add(new Object[] { 0, "", Ret.VOID, paramTypes, fields });
146 
147         // Test struct with 2 doubles without enough registers for all fields
148         extraIntArgs = 7;
149         fields = createFieldsForStruct(2, StructFieldType.DOUBLE);
150         paramTypes = createParameterTypesForStruct(extraIntArgs);
151         downcalls.add(new Object[] { 0, "", Ret.VOID, paramTypes, fields });
152 
153         // Test struct with 2 ints without enough registers for all fields
154         fields = createFieldsForStruct(2, StructFieldType.INT);
155         paramTypes = createParameterTypesForStruct(extraIntArgs);
156         downcalls.add(new Object[] { 0, "", Ret.VOID, paramTypes, fields });
157 
158         return downcalls.toArray(new Object[0][]);
159     }
160 
161     private static List<Arg> makeArgs(Arena arena, List<ParamType> paramTypes, List<StructFieldType> fields) {
162         List<Arg> args = new ArrayList<>();
163         for (ParamType pType : paramTypes) {
164             MemoryLayout layout = pType.layout(fields);
165             if (layout instanceof ValueLayout.OfFloat) {
166                 layout = C_DOUBLE; // promote to double, per C spec
167             }
168             TestValue testValue = genTestValue(layout, arena);
169             Arg.NativeType type = Arg.NativeType.of(pType.type(fields));
170             args.add(pType == ParamType.STRUCT
171                 ? Arg.structArg(type, layout, testValue)
172                 : Arg.primitiveArg(type, layout, testValue));
173         }
174         return args;
175     }
176 
177     private static void check(int index, MemorySegment ptr, List<Arg> args) {
178         Arg varArg = args.get(index);
179         MemoryLayout layout = varArg.layout;
180         MethodHandle getter = varArg.getter;
181         try (Arena arena = Arena.ofConfined()) {
182             MemorySegment seg = ptr.asSlice(0, layout)
183                     .reinterpret(arena, null);
184             Object obj = getter.invoke(seg);
185             varArg.check(obj);
186         } catch (Throwable e) {
187             throw new RuntimeException(e);
188         }
189     }
190 
191     private static class CallInfo {
192         static final MemoryLayout LAYOUT = MemoryLayout.structLayout(
193                 C_POINTER.withName("writeback"), // writeback
194                 C_POINTER.withName("argIDs")); // arg ids
195 
196         static final VarHandle VH_writeback = LAYOUT.varHandle(groupElement("writeback"));
197         static final VarHandle VH_argIDs = LAYOUT.varHandle(groupElement("argIDs"));
198 
199         static void writeback(MemorySegment seg, MemorySegment addr) {
200             VH_writeback.set(seg, addr);
201         }
202         static void argIDs(MemorySegment seg, MemorySegment addr) {
203             VH_argIDs.set(seg, addr);
204         }
205     }
206 
207     private static final class Arg {
208         private final TestValue value;
209 
210         final NativeType id;
211         final MemoryLayout layout;
212         final MethodHandle getter;
213 
214         private Arg(NativeType id, MemoryLayout layout, TestValue value, MethodHandle getter) {
215             this.id = id;
216             this.layout = layout;
217             this.value = value;
218             this.getter = getter;
219         }
220 
221         private static Arg primitiveArg(NativeType id, MemoryLayout layout, TestValue value) {
222             return new Arg(id, layout, value, layout.varHandle().toMethodHandle(VarHandle.AccessMode.GET));
223         }
224 
225         private static Arg structArg(NativeType id, MemoryLayout layout, TestValue value) {
226             return new Arg(id, layout, value, MethodHandles.identity(MemorySegment.class));
227         }
228 
229         public void check(Object actual) {
230             value.check().accept(actual);
231         }
232 
233         public Object value() {
234             return value.value();
235         }
236 
237         enum NativeType {
238             INT,
239             DOUBLE,
240             POINTER,
241             S_I,
242             S_F,
243             S_D,
244             S_P,
245             S_II,
246             S_IF,
247             S_ID,
248             S_IP,
249             S_FI,
250             S_FF,
251             S_FD,
252             S_FP,
253             S_DI,
254             S_DF,
255             S_DD,
256             S_DP,
257             S_PI,
258             S_PF,
259             S_PD,
260             S_PP,
261             S_III,
262             S_IIF,
263             S_IID,
264             S_IIP,
265             S_IFI,
266             S_IFF,
267             S_IFD,
268             S_IFP,
269             S_IDI,
270             S_IDF,
271             S_IDD,
272             S_IDP,
273             S_IPI,
274             S_IPF,
275             S_IPD,
276             S_IPP,
277             S_FII,
278             S_FIF,
279             S_FID,
280             S_FIP,
281             S_FFI,
282             S_FFF,
283             S_FFD,
284             S_FFP,
285             S_FDI,
286             S_FDF,
287             S_FDD,
288             S_FDP,
289             S_FPI,
290             S_FPF,
291             S_FPD,
292             S_FPP,
293             S_DII,
294             S_DIF,
295             S_DID,
296             S_DIP,
297             S_DFI,
298             S_DFF,
299             S_DFD,
300             S_DFP,
301             S_DDI,
302             S_DDF,
303             S_DDD,
304             S_DDP,
305             S_DPI,
306             S_DPF,
307             S_DPD,
308             S_DPP,
309             S_PII,
310             S_PIF,
311             S_PID,
312             S_PIP,
313             S_PFI,
314             S_PFF,
315             S_PFD,
316             S_PFP,
317             S_PDI,
318             S_PDF,
319             S_PDD,
320             S_PDP,
321             S_PPI,
322             S_PPF,
323             S_PPD,
324             S_PPP,
325             S_FFFF,
326             ;
327 
328             public static NativeType of(String type) {
329                 return NativeType.valueOf(switch (type) {
330                     case "int" -> "INT";
331                     case "float" -> "DOUBLE"; // promote
332                     case "double" -> "DOUBLE";
333                     case "void*" -> "POINTER";
334                     default -> type.substring("struct ".length());
335                 });
336             }
337         }
338     }
339 
340 }