1 /*
  2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
  3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  *  This code is free software; you can redistribute it and/or modify it
  6  *  under the terms of the GNU General Public License version 2 only, as
  7  *  published by the Free Software Foundation.
  8  *
  9  *  This code is distributed in the hope that it will be useful, but WITHOUT
 10  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  *  version 2 for more details (a copy is included in the LICENSE file that
 13  *  accompanied this code).
 14  *
 15  *  You should have received a copy of the GNU General Public License version
 16  *  2 along with this work; if not, write to the Free Software Foundation,
 17  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  *  or visit www.oracle.com if you need additional information or have any
 21  *  questions.
 22  *
 23  */
 24 
 25 import jdk.incubator.foreign.Addressable;
 26 import jdk.incubator.foreign.CLinker;
 27 import jdk.incubator.foreign.FunctionDescriptor;
 28 import jdk.incubator.foreign.GroupLayout;
 29 import jdk.incubator.foreign.MemoryAddress;
 30 import jdk.incubator.foreign.MemoryLayout;
 31 import jdk.incubator.foreign.MemorySegment;
 32 import jdk.incubator.foreign.NativeSymbol;
 33 import jdk.incubator.foreign.ResourceScope;
 34 import jdk.incubator.foreign.SegmentAllocator;
 35 import jdk.incubator.foreign.ValueLayout;
 36 
 37 import java.lang.invoke.MethodHandle;
 38 import java.lang.invoke.VarHandle;
 39 import java.util.ArrayList;
 40 import java.util.List;
 41 import java.util.Stack;
 42 import java.util.function.Consumer;
 43 import java.util.stream.Collectors;
 44 import java.util.stream.IntStream;
 45 import java.util.stream.Stream;
 46 
 47 import org.testng.annotations.*;
 48 
 49 import static org.testng.Assert.*;
 50 
 51 public class CallGeneratorHelper extends NativeTestHelper {
 52 
 53     static final List<MemoryLayout> STACK_PREFIX_LAYOUTS = Stream.concat(
 54             Stream.generate(() -> (MemoryLayout) C_LONG_LONG).limit(8),
 55             Stream.generate(() -> (MemoryLayout)  C_DOUBLE).limit(8)
 56         ).toList();
 57 
 58     static SegmentAllocator THROWING_ALLOCATOR = (size, align) -> {
 59         throw new UnsupportedOperationException();
 60     };
 61 
 62     static final int SAMPLE_FACTOR = Integer.parseInt((String)System.getProperties().getOrDefault("generator.sample.factor", "-1"));
 63 
 64     static final int MAX_FIELDS = 3;
 65     static final int MAX_PARAMS = 3;
 66     static final int CHUNK_SIZE = 600;
 67 
 68     public static void assertStructEquals(MemorySegment actual, MemorySegment expected, MemoryLayout layout) {
 69         assertEquals(actual.byteSize(), expected.byteSize());
 70         GroupLayout g = (GroupLayout) layout;
 71         for (MemoryLayout field : g.memberLayouts()) {
 72             if (field instanceof ValueLayout) {
 73                 VarHandle vh = g.varHandle(MemoryLayout.PathElement.groupElement(field.name().orElseThrow()));
 74                 assertEquals(vh.get(actual), vh.get(expected));
 75             }
 76         }
 77     }
 78 
 79     private static Class<?> vhCarrier(MemoryLayout layout) {
 80         if (layout instanceof ValueLayout) {
 81             if (isIntegral(layout)) {
 82                 if (layout.bitSize() == 64) {
 83                     return long.class;
 84                 }
 85                 return int.class;
 86             } else if (layout.bitSize() == 32) {
 87                 return float.class;
 88             }
 89             return double.class;
 90         } else {
 91             throw new IllegalStateException("Unexpected layout: " + layout);
 92         }
 93     }
 94 
 95     enum Ret {
 96         VOID,
 97         NON_VOID
 98     }
 99 
100     enum StructFieldType {
101         INT("int", C_INT),
102         FLOAT("float", C_FLOAT),
103         DOUBLE("double", C_DOUBLE),
104         POINTER("void*", C_POINTER);
105 
106         final String typeStr;
107         final MemoryLayout layout;
108 
109         StructFieldType(String typeStr, MemoryLayout layout) {
110             this.typeStr = typeStr;
111             this.layout = layout;
112         }
113 
114         MemoryLayout layout() {
115             return layout;
116         }
117 
118         @SuppressWarnings("unchecked")
119         static List<List<StructFieldType>>[] perms = new List[10];
120 
121         static List<List<StructFieldType>> perms(int i) {
122             if (perms[i] == null) {
123                 perms[i] = generateTest(i, values());
124             }
125             return perms[i];
126         }
127     }
128 
129     enum ParamType {
130         INT("int", C_INT),
131         FLOAT("float", C_FLOAT),
132         DOUBLE("double", C_DOUBLE),
133         POINTER("void*", C_POINTER),
134         STRUCT("struct S", null);
135 
136         private final String typeStr;
137         private final MemoryLayout layout;
138 
139         ParamType(String typeStr, MemoryLayout layout) {
140             this.typeStr = typeStr;
141             this.layout = layout;
142         }
143 
144         String type(List<StructFieldType> fields) {
145             return this == STRUCT ?
146                     typeStr + "_" + sigCode(fields) :
147                     typeStr;
148         }
149 
150         MemoryLayout layout(List<StructFieldType> fields) {
151             if (this == STRUCT) {
152                 long offset = 0L;
153                 List<MemoryLayout> layouts = new ArrayList<>();
154                 for (StructFieldType field : fields) {
155                     MemoryLayout l = field.layout();
156                     long padding = offset % l.bitSize();
157                     if (padding != 0) {
158                         layouts.add(MemoryLayout.paddingLayout(padding));
159                         offset += padding;
160                     }
161                     layouts.add(l.withName("field" + offset));
162                     offset += l.bitSize();
163                 }
164                 return MemoryLayout.structLayout(layouts.toArray(new MemoryLayout[0]));
165             } else {
166                 return layout;
167             }
168         }
169 
170         @SuppressWarnings("unchecked")
171         static List<List<ParamType>>[] perms = new List[10];
172 
173         static List<List<ParamType>> perms(int i) {
174             if (perms[i] == null) {
175                 perms[i] = generateTest(i, values());
176             }
177             return perms[i];
178         }
179     }
180 
181     static <Z> List<List<Z>> generateTest(int i, Z[] elems) {
182         List<List<Z>> res = new ArrayList<>();
183         generateTest(i, new Stack<>(), elems, res);
184         return res;
185     }
186 
187     static <Z> void generateTest(int i, Stack<Z> combo, Z[] elems, List<List<Z>> results) {
188         if (i == 0) {
189             results.add(new ArrayList<>(combo));
190         } else {
191             for (Z z : elems) {
192                 combo.push(z);
193                 generateTest(i - 1, combo, elems, results);
194                 combo.pop();
195             }
196         }
197     }
198 
199     @DataProvider(name = "functions")
200     public static Object[][] functions() {
201         int functions = 0;
202         List<Object[]> downcalls = new ArrayList<>();
203         for (Ret r : Ret.values()) {
204             for (int i = 0; i <= MAX_PARAMS; i++) {
205                 if (r != Ret.VOID && i == 0) continue;
206                 for (List<ParamType> ptypes : ParamType.perms(i)) {
207                     String retCode = r == Ret.VOID ? "V" : ptypes.get(0).name().charAt(0) + "";
208                     String sigCode = sigCode(ptypes);
209                     if (ptypes.contains(ParamType.STRUCT)) {
210                         for (int j = 1; j <= MAX_FIELDS; j++) {
211                             for (List<StructFieldType> fields : StructFieldType.perms(j)) {
212                                 String structCode = sigCode(fields);
213                                 int count = functions;
214                                 int fCode = functions++ / CHUNK_SIZE;
215                                 String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode);
216                                 if (SAMPLE_FACTOR == -1 || (count % SAMPLE_FACTOR) == 0) {
217                                     downcalls.add(new Object[]{count, fName, r, ptypes, fields});
218                                 }
219                             }
220                         }
221                     } else {
222                         String structCode = sigCode(List.<StructFieldType>of());
223                         int count = functions;
224                         int fCode = functions++ / CHUNK_SIZE;
225                         String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode);
226                         if (SAMPLE_FACTOR == -1 || (count % SAMPLE_FACTOR) == 0) {
227                             downcalls.add(new Object[]{count, fName, r, ptypes, List.of()});
228                         }
229                     }
230                 }
231             }
232         }
233         return downcalls.toArray(new Object[0][]);
234     }
235 
236     static <Z extends Enum<Z>> String sigCode(List<Z> elems) {
237         return elems.stream().map(p -> p.name().charAt(0) + "").collect(Collectors.joining());
238     }
239 
240     static void generateStructDecl(List<StructFieldType> fields) {
241         String structCode = sigCode(fields);
242         List<String> fieldDecls = new ArrayList<>();
243         for (int i = 0 ; i < fields.size() ; i++) {
244             fieldDecls.add(String.format("%s p%d;", fields.get(i).typeStr, i));
245         }
246         String res = String.format("struct S_%s { %s };", structCode,
247                 fieldDecls.stream().collect(Collectors.joining(" ")));
248         System.out.println(res);
249     }
250 
251     /* this can be used to generate the test header/implementation */
252     public static void main(String[] args) {
253         boolean header = args.length > 0 && args[0].equals("header");
254         boolean upcall = args.length > 1 && args[1].equals("upcall");
255         if (upcall) {
256             generateUpcalls(header);
257         } else {
258             generateDowncalls(header);
259         }
260     }
261 
262     static void generateDowncalls(boolean header) {
263         if (header) {
264             System.out.println(
265                 "#ifdef _WIN64\n" +
266                 "#define EXPORT __declspec(dllexport)\n" +
267                 "#else\n" +
268                 "#define EXPORT\n" +
269                 "#endif\n"
270             );
271 
272             for (int j = 1; j <= MAX_FIELDS; j++) {
273                 for (List<StructFieldType> fields : StructFieldType.perms(j)) {
274                     generateStructDecl(fields);
275                 }
276             }
277         } else {
278             System.out.println(
279                 "#include \"libh\"\n" +
280                 "#ifdef __clang__\n" +
281                 "#pragma clang optimize off\n" +
282                 "#elif defined __GNUC__\n" +
283                 "#pragma GCC optimize (\"O0\")\n" +
284                 "#elif defined _MSC_BUILD\n" +
285                 "#pragma optimize( \"\", off )\n" +
286                 "#endif\n"
287             );
288         }
289 
290         for (Object[] downcall : functions()) {
291             String fName = (String)downcall[0];
292             Ret r = (Ret)downcall[1];
293             @SuppressWarnings("unchecked")
294             List<ParamType> ptypes = (List<ParamType>)downcall[2];
295             @SuppressWarnings("unchecked")
296             List<StructFieldType> fields = (List<StructFieldType>)downcall[3];
297             generateDowncallFunction(fName, r, ptypes, fields, header);
298         }
299     }
300 
301     static void generateDowncallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) {
302         String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields);
303         List<String> paramDecls = new ArrayList<>();
304         for (int i = 0 ; i < params.size() ; i++) {
305             paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i));
306         }
307         String sig = paramDecls.isEmpty() ?
308                 "void" :
309                 paramDecls.stream().collect(Collectors.joining(", "));
310         String body = ret == Ret.VOID ? "{ }" : "{ return p0; }";
311         String res = String.format("EXPORT %s f%s(%s) %s", retType, fName,
312                 sig, declOnly ? ";" : body);
313         System.out.println(res);
314     }
315 
316     static void generateUpcalls(boolean header) {
317         if (header) {
318             System.out.println(
319                 "#ifdef _WIN64\n" +
320                 "#define EXPORT __declspec(dllexport)\n" +
321                 "#else\n" +
322                 "#define EXPORT\n" +
323                 "#endif\n"
324             );
325 
326             for (int j = 1; j <= MAX_FIELDS; j++) {
327                 for (List<StructFieldType> fields : StructFieldType.perms(j)) {
328                     generateStructDecl(fields);
329                 }
330             }
331         } else {
332             System.out.println(
333                 "#include \"libh\"\n" +
334                 "#ifdef __clang__\n" +
335                 "#pragma clang optimize off\n" +
336                 "#elif defined __GNUC__\n" +
337                 "#pragma GCC optimize (\"O0\")\n" +
338                 "#elif defined _MSC_BUILD\n" +
339                 "#pragma optimize( \"\", off )\n" +
340                 "#endif\n"
341             );
342         }
343 
344         for (Object[] downcall : functions()) {
345             String fName = (String)downcall[0];
346             Ret r = (Ret)downcall[1];
347             @SuppressWarnings("unchecked")
348             List<ParamType> ptypes = (List<ParamType>)downcall[2];
349             @SuppressWarnings("unchecked")
350             List<StructFieldType> fields = (List<StructFieldType>)downcall[3];
351             generateUpcallFunction(fName, r, ptypes, fields, header);
352         }
353     }
354 
355     static void generateUpcallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) {
356         String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields);
357         List<String> paramDecls = new ArrayList<>();
358         for (int i = 0 ; i < params.size() ; i++) {
359             paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i));
360         }
361         String paramNames = IntStream.range(0, params.size())
362                 .mapToObj(i -> "p" + i)
363                 .collect(Collectors.joining(","));
364         String sig = paramDecls.isEmpty() ?
365                 "" :
366                 paramDecls.stream().collect(Collectors.joining(", ")) + ", ";
367         String body = String.format(ret == Ret.VOID ? "{ cb(%s); }" : "{ return cb(%s); }", paramNames);
368         List<String> paramTypes = params.stream().map(p -> p.type(fields)).collect(Collectors.toList());
369         String cbSig = paramTypes.isEmpty() ?
370                 "void" :
371                 paramTypes.stream().collect(Collectors.joining(", "));
372         String cbParam = String.format("%s (*cb)(%s)",
373                 retType, cbSig);
374 
375         String res = String.format("EXPORT %s %s(%s %s) %s", retType, fName,
376                 sig, cbParam, declOnly ? ";" : body);
377         System.out.println(res);
378     }
379 
380     //helper methods
381 
382     @SuppressWarnings("unchecked")
383     static Object makeArg(MemoryLayout layout, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {
384         if (layout instanceof GroupLayout) {
385             MemorySegment segment = MemorySegment.allocateNative(layout, ResourceScope.newImplicitScope());
386             initStruct(segment, (GroupLayout)layout, checks, check);
387             return segment;
388         } else if (isPointer(layout)) {
389             MemorySegment segment = MemorySegment.allocateNative(1, ResourceScope.newImplicitScope());
390             if (check) {
391                 checks.add(o -> {
392                     try {
393                         assertEquals(o, segment.address());
394                     } catch (Throwable ex) {
395                         throw new IllegalStateException(ex);
396                     }
397                 });
398             }
399             return segment.address();
400         } else if (layout instanceof ValueLayout) {
401             if (isIntegral(layout)) {
402                 if (check) {
403                     checks.add(o -> assertEquals(o, 42));
404                 }
405                 return 42;
406             } else if (layout.bitSize() == 32) {
407                 if (check) {
408                     checks.add(o -> assertEquals(o, 12f));
409                 }
410                 return 12f;
411             } else {
412                 if (check) {
413                     checks.add(o -> assertEquals(o, 24d));
414                 }
415                 return 24d;
416             }
417         } else {
418             throw new IllegalStateException("Unexpected layout: " + layout);
419         }
420     }
421 
422     static void initStruct(MemorySegment str, GroupLayout g, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {
423         for (MemoryLayout l : g.memberLayouts()) {
424             if (l.isPadding()) continue;
425             VarHandle accessor = g.varHandle(MemoryLayout.PathElement.groupElement(l.name().get()));
426             List<Consumer<Object>> fieldsCheck = new ArrayList<>();
427             Object value = makeArg(l, fieldsCheck, check);
428             //set value
429             accessor.set(str, value);
430             //add check
431             if (check) {
432                 assertTrue(fieldsCheck.size() == 1);
433                 checks.add(o -> {
434                     MemorySegment actual = (MemorySegment)o;
435                     try {
436                         fieldsCheck.get(0).accept(accessor.get(actual));
437                     } catch (Throwable ex) {
438                         throw new IllegalStateException(ex);
439                     }
440                 });
441             }
442         }
443     }
444 
445     static Class<?> carrier(MemoryLayout layout, boolean param) {
446         if (layout instanceof GroupLayout) {
447             return MemorySegment.class;
448         } if (isPointer(layout)) {
449             return param ? Addressable.class : MemoryAddress.class;
450         } else if (layout instanceof ValueLayout valueLayout) {
451             return valueLayout.carrier();
452         } else {
453             throw new IllegalStateException("Unexpected layout: " + layout);
454         }
455     }
456 
457     MethodHandle downcallHandle(CLinker abi, NativeSymbol symbol, SegmentAllocator allocator, FunctionDescriptor descriptor) {
458         MethodHandle mh = abi.downcallHandle(symbol, descriptor);
459         if (descriptor.returnLayout().isPresent() && descriptor.returnLayout().get() instanceof GroupLayout) {
460             mh = mh.bindTo(allocator);
461         }
462         return mh;
463     }
464 }