1 /*
  2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
  3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  *  This code is free software; you can redistribute it and/or modify it
  6  *  under the terms of the GNU General Public License version 2 only, as
  7  *  published by the Free Software Foundation.
  8  *
  9  *  This code is distributed in the hope that it will be useful, but WITHOUT
 10  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  *  version 2 for more details (a copy is included in the LICENSE file that
 13  *  accompanied this code).
 14  *
 15  *  You should have received a copy of the GNU General Public License version
 16  *  2 along with this work; if not, write to the Free Software Foundation,
 17  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  *  or visit www.oracle.com if you need additional information or have any
 21  *  questions.
 22  *
 23  */
 24 
 25 import jdk.incubator.foreign.GroupLayout;
 26 import jdk.incubator.foreign.MemoryAddress;
 27 import jdk.incubator.foreign.MemoryLayout;
 28 import jdk.incubator.foreign.MemorySegment;
 29 import jdk.incubator.foreign.ResourceScope;
 30 import jdk.incubator.foreign.SegmentAllocator;
 31 import jdk.incubator.foreign.ValueLayout;
 32 
 33 import java.lang.invoke.VarHandle;
 34 import java.util.ArrayList;
 35 import java.util.List;
 36 import java.util.Stack;
 37 import java.util.function.Consumer;
 38 import java.util.stream.Collectors;
 39 import java.util.stream.IntStream;
 40 
 41 import org.testng.annotations.*;
 42 
 43 import static jdk.incubator.foreign.CLinker.*;
 44 import static org.testng.Assert.*;
 45 
 46 public class CallGeneratorHelper extends NativeTestHelper {
 47 
 48     static SegmentAllocator IMPLICIT_ALLOCATOR = (size, align) -> MemorySegment.allocateNative(size, align, ResourceScope.newImplicitScope());
 49 
 50     static final int SAMPLE_FACTOR = Integer.parseInt((String)System.getProperties().getOrDefault("generator.sample.factor", "-1"));
 51 
 52     static final int MAX_FIELDS = 3;
 53     static final int MAX_PARAMS = 3;
 54     static final int CHUNK_SIZE = 600;
 55 
 56     public static void assertStructEquals(MemorySegment actual, MemorySegment expected, MemoryLayout layout) {
 57         assertEquals(actual.byteSize(), expected.byteSize());
 58         GroupLayout g = (GroupLayout) layout;
 59         for (MemoryLayout field : g.memberLayouts()) {
 60             if (field instanceof ValueLayout) {
 61                 VarHandle vh = g.varHandle(vhCarrier(field), MemoryLayout.PathElement.groupElement(field.name().orElseThrow()));
 62                 assertEquals(vh.get(actual), vh.get(expected));
 63             }
 64         }
 65     }
 66 
 67     private static Class<?> vhCarrier(MemoryLayout layout) {
 68         if (layout instanceof ValueLayout) {
 69             if (isIntegral(layout)) {
 70                 if (layout.bitSize() == 64) {
 71                     return long.class;
 72                 }
 73                 return int.class;
 74             } else if (layout.bitSize() == 32) {
 75                 return float.class;
 76             }
 77             return double.class;
 78         } else {
 79             throw new IllegalStateException("Unexpected layout: " + layout);
 80         }
 81     }
 82 
 83     enum Ret {
 84         VOID,
 85         NON_VOID
 86     }
 87 
 88     enum StructFieldType {
 89         INT("int", C_INT),
 90         FLOAT("float", C_FLOAT),
 91         DOUBLE("double", C_DOUBLE),
 92         POINTER("void*", C_POINTER);
 93 
 94         final String typeStr;
 95         final MemoryLayout layout;
 96 
 97         StructFieldType(String typeStr, MemoryLayout layout) {
 98             this.typeStr = typeStr;
 99             this.layout = layout;
100         }
101 
102         MemoryLayout layout() {
103             return layout;
104         }
105 
106         @SuppressWarnings("unchecked")
107         static List<List<StructFieldType>>[] perms = new List[10];
108 
109         static List<List<StructFieldType>> perms(int i) {
110             if (perms[i] == null) {
111                 perms[i] = generateTest(i, values());
112             }
113             return perms[i];
114         }
115     }
116 
117     enum ParamType {
118         INT("int", C_INT),
119         FLOAT("float", C_FLOAT),
120         DOUBLE("double", C_DOUBLE),
121         POINTER("void*", C_POINTER),
122         STRUCT("struct S", null);
123 
124         private final String typeStr;
125         private final MemoryLayout layout;
126 
127         ParamType(String typeStr, MemoryLayout layout) {
128             this.typeStr = typeStr;
129             this.layout = layout;
130         }
131 
132         String type(List<StructFieldType> fields) {
133             return this == STRUCT ?
134                     typeStr + "_" + sigCode(fields) :
135                     typeStr;
136         }
137 
138         MemoryLayout layout(List<StructFieldType> fields) {
139             if (this == STRUCT) {
140                 long offset = 0L;
141                 List<MemoryLayout> layouts = new ArrayList<>();
142                 for (StructFieldType field : fields) {
143                     MemoryLayout l = field.layout();
144                     long padding = offset % l.bitSize();
145                     if (padding != 0) {
146                         layouts.add(MemoryLayout.paddingLayout(padding));
147                         offset += padding;
148                     }
149                     layouts.add(l.withName("field" + offset));
150                     offset += l.bitSize();
151                 }
152                 return MemoryLayout.structLayout(layouts.toArray(new MemoryLayout[0]));
153             } else {
154                 return layout;
155             }
156         }
157 
158         @SuppressWarnings("unchecked")
159         static List<List<ParamType>>[] perms = new List[10];
160 
161         static List<List<ParamType>> perms(int i) {
162             if (perms[i] == null) {
163                 perms[i] = generateTest(i, values());
164             }
165             return perms[i];
166         }
167     }
168 
169     static <Z> List<List<Z>> generateTest(int i, Z[] elems) {
170         List<List<Z>> res = new ArrayList<>();
171         generateTest(i, new Stack<>(), elems, res);
172         return res;
173     }
174 
175     static <Z> void generateTest(int i, Stack<Z> combo, Z[] elems, List<List<Z>> results) {
176         if (i == 0) {
177             results.add(new ArrayList<>(combo));
178         } else {
179             for (Z z : elems) {
180                 combo.push(z);
181                 generateTest(i - 1, combo, elems, results);
182                 combo.pop();
183             }
184         }
185     }
186 
187     @DataProvider(name = "functions")
188     public static Object[][] functions() {
189         int functions = 0;
190         List<Object[]> downcalls = new ArrayList<>();
191         for (Ret r : Ret.values()) {
192             for (int i = 0; i <= MAX_PARAMS; i++) {
193                 if (r != Ret.VOID && i == 0) continue;
194                 for (List<ParamType> ptypes : ParamType.perms(i)) {
195                     String retCode = r == Ret.VOID ? "V" : ptypes.get(0).name().charAt(0) + "";
196                     String sigCode = sigCode(ptypes);
197                     if (ptypes.contains(ParamType.STRUCT)) {
198                         for (int j = 1; j <= MAX_FIELDS; j++) {
199                             for (List<StructFieldType> fields : StructFieldType.perms(j)) {
200                                 String structCode = sigCode(fields);
201                                 int count = functions;
202                                 int fCode = functions++ / CHUNK_SIZE;
203                                 String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode);
204                                 if (SAMPLE_FACTOR == -1 || (count % SAMPLE_FACTOR) == 0) {
205                                     downcalls.add(new Object[]{count, fName, r, ptypes, fields});
206                                 }
207                             }
208                         }
209                     } else {
210                         String structCode = sigCode(List.<StructFieldType>of());
211                         int count = functions;
212                         int fCode = functions++ / CHUNK_SIZE;
213                         String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode);
214                         if (SAMPLE_FACTOR == -1 || (count % SAMPLE_FACTOR) == 0) {
215                             downcalls.add(new Object[]{count, fName, r, ptypes, List.of()});
216                         }
217                     }
218                 }
219             }
220         }
221         return downcalls.toArray(new Object[0][]);
222     }
223 
224     static <Z extends Enum<Z>> String sigCode(List<Z> elems) {
225         return elems.stream().map(p -> p.name().charAt(0) + "").collect(Collectors.joining());
226     }
227 
228     static void generateStructDecl(List<StructFieldType> fields) {
229         String structCode = sigCode(fields);
230         List<String> fieldDecls = new ArrayList<>();
231         for (int i = 0 ; i < fields.size() ; i++) {
232             fieldDecls.add(String.format("%s p%d;", fields.get(i).typeStr, i));
233         }
234         String res = String.format("struct S_%s { %s };", structCode,
235                 fieldDecls.stream().collect(Collectors.joining(" ")));
236         System.out.println(res);
237     }
238 
239     /* this can be used to generate the test header/implementation */
240     public static void main(String[] args) {
241         boolean header = args.length > 0 && args[0].equals("header");
242         boolean upcall = args.length > 1 && args[1].equals("upcall");
243         if (upcall) {
244             generateUpcalls(header);
245         } else {
246             generateDowncalls(header);
247         }
248     }
249 
250     static void generateDowncalls(boolean header) {
251         if (header) {
252             System.out.println(
253                 "#ifdef _WIN64\n" +
254                 "#define EXPORT __declspec(dllexport)\n" +
255                 "#else\n" +
256                 "#define EXPORT\n" +
257                 "#endif\n"
258             );
259 
260             for (int j = 1; j <= MAX_FIELDS; j++) {
261                 for (List<StructFieldType> fields : StructFieldType.perms(j)) {
262                     generateStructDecl(fields);
263                 }
264             }
265         } else {
266             System.out.println(
267                 "#include \"libh\"\n" +
268                 "#ifdef __clang__\n" +
269                 "#pragma clang optimize off\n" +
270                 "#elif defined __GNUC__\n" +
271                 "#pragma GCC optimize (\"O0\")\n" +
272                 "#elif defined _MSC_BUILD\n" +
273                 "#pragma optimize( \"\", off )\n" +
274                 "#endif\n"
275             );
276         }
277 
278         for (Object[] downcall : functions()) {
279             String fName = (String)downcall[0];
280             Ret r = (Ret)downcall[1];
281             @SuppressWarnings("unchecked")
282             List<ParamType> ptypes = (List<ParamType>)downcall[2];
283             @SuppressWarnings("unchecked")
284             List<StructFieldType> fields = (List<StructFieldType>)downcall[3];
285             generateDowncallFunction(fName, r, ptypes, fields, header);
286         }
287     }
288 
289     static void generateDowncallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) {
290         String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields);
291         List<String> paramDecls = new ArrayList<>();
292         for (int i = 0 ; i < params.size() ; i++) {
293             paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i));
294         }
295         String sig = paramDecls.isEmpty() ?
296                 "void" :
297                 paramDecls.stream().collect(Collectors.joining(", "));
298         String body = ret == Ret.VOID ? "{ }" : "{ return p0; }";
299         String res = String.format("EXPORT %s f%s(%s) %s", retType, fName,
300                 sig, declOnly ? ";" : body);
301         System.out.println(res);
302     }
303 
304     static void generateUpcalls(boolean header) {
305         if (header) {
306             System.out.println(
307                 "#ifdef _WIN64\n" +
308                 "#define EXPORT __declspec(dllexport)\n" +
309                 "#else\n" +
310                 "#define EXPORT\n" +
311                 "#endif\n"
312             );
313 
314             for (int j = 1; j <= MAX_FIELDS; j++) {
315                 for (List<StructFieldType> fields : StructFieldType.perms(j)) {
316                     generateStructDecl(fields);
317                 }
318             }
319         } else {
320             System.out.println(
321                 "#include \"libh\"\n" +
322                 "#ifdef __clang__\n" +
323                 "#pragma clang optimize off\n" +
324                 "#elif defined __GNUC__\n" +
325                 "#pragma GCC optimize (\"O0\")\n" +
326                 "#elif defined _MSC_BUILD\n" +
327                 "#pragma optimize( \"\", off )\n" +
328                 "#endif\n"
329             );
330         }
331 
332         for (Object[] downcall : functions()) {
333             String fName = (String)downcall[0];
334             Ret r = (Ret)downcall[1];
335             @SuppressWarnings("unchecked")
336             List<ParamType> ptypes = (List<ParamType>)downcall[2];
337             @SuppressWarnings("unchecked")
338             List<StructFieldType> fields = (List<StructFieldType>)downcall[3];
339             generateUpcallFunction(fName, r, ptypes, fields, header);
340         }
341     }
342 
343     static void generateUpcallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) {
344         String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields);
345         List<String> paramDecls = new ArrayList<>();
346         for (int i = 0 ; i < params.size() ; i++) {
347             paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i));
348         }
349         String paramNames = IntStream.range(0, params.size())
350                 .mapToObj(i -> "p" + i)
351                 .collect(Collectors.joining(","));
352         String sig = paramDecls.isEmpty() ?
353                 "" :
354                 paramDecls.stream().collect(Collectors.joining(", ")) + ", ";
355         String body = String.format(ret == Ret.VOID ? "{ cb(%s); }" : "{ return cb(%s); }", paramNames);
356         List<String> paramTypes = params.stream().map(p -> p.type(fields)).collect(Collectors.toList());
357         String cbSig = paramTypes.isEmpty() ?
358                 "void" :
359                 paramTypes.stream().collect(Collectors.joining(", "));
360         String cbParam = String.format("%s (*cb)(%s)",
361                 retType, cbSig);
362 
363         String res = String.format("EXPORT %s %s(%s %s) %s", retType, fName,
364                 sig, cbParam, declOnly ? ";" : body);
365         System.out.println(res);
366     }
367 
368     //helper methods
369 
370     @SuppressWarnings("unchecked")
371     static Object makeArg(MemoryLayout layout, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {
372         if (layout instanceof GroupLayout) {
373             MemorySegment segment = MemorySegment.allocateNative(layout, ResourceScope.newImplicitScope());
374             initStruct(segment, (GroupLayout)layout, checks, check);
375             return segment;
376         } else if (isPointer(layout)) {
377             MemorySegment segment = MemorySegment.allocateNative(1, ResourceScope.newImplicitScope());
378             if (check) {
379                 checks.add(o -> {
380                     try {
381                         assertEquals(o, segment.address());
382                     } catch (Throwable ex) {
383                         throw new IllegalStateException(ex);
384                     }
385                 });
386             }
387             return segment.address();
388         } else if (layout instanceof ValueLayout) {
389             if (isIntegral(layout)) {
390                 if (check) {
391                     checks.add(o -> assertEquals(o, 42));
392                 }
393                 return 42;
394             } else if (layout.bitSize() == 32) {
395                 if (check) {
396                     checks.add(o -> assertEquals(o, 12f));
397                 }
398                 return 12f;
399             } else {
400                 if (check) {
401                     checks.add(o -> assertEquals(o, 24d));
402                 }
403                 return 24d;
404             }
405         } else {
406             throw new IllegalStateException("Unexpected layout: " + layout);
407         }
408     }
409 
410     static void initStruct(MemorySegment str, GroupLayout g, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {
411         for (MemoryLayout l : g.memberLayouts()) {
412             if (l.isPadding()) continue;
413             VarHandle accessor = g.varHandle(structFieldCarrier(l), MemoryLayout.PathElement.groupElement(l.name().get()));
414             List<Consumer<Object>> fieldsCheck = new ArrayList<>();
415             Object value = makeArg(l, fieldsCheck, check);
416             if (isPointer(l)) {
417                 value = ((MemoryAddress)value).toRawLongValue();
418             }
419             //set value
420             accessor.set(str, value);
421             //add check
422             if (check) {
423                 assertTrue(fieldsCheck.size() == 1);
424                 checks.add(o -> {
425                     MemorySegment actual = (MemorySegment)o;
426                     try {
427                         if (isPointer(l)) {
428                             fieldsCheck.get(0).accept(MemoryAddress.ofLong((long)accessor.get(actual)));
429                         } else {
430                             fieldsCheck.get(0).accept(accessor.get(actual));
431                         }
432                     } catch (Throwable ex) {
433                         throw new IllegalStateException(ex);
434                     }
435                 });
436             }
437         }
438     }
439 
440     static Class<?> structFieldCarrier(MemoryLayout layout) {
441         if (isPointer(layout)) {
442             return long.class;
443         } else if (layout instanceof ValueLayout) {
444             if (isIntegral(layout)) {
445                 return int.class;
446             } else if (layout.bitSize() == 32) {
447                 return float.class;
448             } else {
449                 return double.class;
450             }
451         } else {
452             throw new IllegalStateException("Unexpected layout: " + layout);
453         }
454     }
455 
456     static Class<?> paramCarrier(MemoryLayout layout) {
457         if (layout instanceof GroupLayout) {
458             return MemorySegment.class;
459         } if (isPointer(layout)) {
460             return MemoryAddress.class;
461         } else if (layout instanceof ValueLayout) {
462             if (isIntegral(layout)) {
463                 return int.class;
464             } else if (layout.bitSize() == 32) {
465                 return float.class;
466             } else {
467                 return double.class;
468             }
469         } else {
470             throw new IllegalStateException("Unexpected layout: " + layout);
471         }
472     }
473 }