1 /* 2 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 * 23 */ 24 25 import jdk.incubator.foreign.Addressable; 26 import jdk.incubator.foreign.CLinker; 27 import jdk.incubator.foreign.FunctionDescriptor; 28 import jdk.incubator.foreign.GroupLayout; 29 import jdk.incubator.foreign.MemoryAddress; 30 import jdk.incubator.foreign.MemoryLayout; 31 import jdk.incubator.foreign.MemorySegment; 32 import jdk.incubator.foreign.NativeSymbol; 33 import jdk.incubator.foreign.ResourceScope; 34 import jdk.incubator.foreign.SegmentAllocator; 35 import jdk.incubator.foreign.ValueLayout; 36 37 import java.lang.invoke.MethodHandle; 38 import java.lang.invoke.VarHandle; 39 import java.util.ArrayList; 40 import java.util.List; 41 import java.util.Stack; 42 import java.util.function.Consumer; 43 import java.util.stream.Collectors; 44 import java.util.stream.IntStream; 45 import java.util.stream.Stream; 46 47 import org.testng.annotations.*; 48 49 import static org.testng.Assert.*; 50 51 public class CallGeneratorHelper extends NativeTestHelper { 52 53 static final List<MemoryLayout> STACK_PREFIX_LAYOUTS = Stream.concat( 54 Stream.generate(() -> (MemoryLayout) C_LONG_LONG).limit(8), 55 Stream.generate(() -> (MemoryLayout) C_DOUBLE).limit(8) 56 ).toList(); 57 58 static SegmentAllocator THROWING_ALLOCATOR = (size, align) -> { 59 throw new UnsupportedOperationException(); 60 }; 61 62 static final int SAMPLE_FACTOR = Integer.parseInt((String)System.getProperties().getOrDefault("generator.sample.factor", "-1")); 63 64 static final int MAX_FIELDS = 3; 65 static final int MAX_PARAMS = 3; 66 static final int CHUNK_SIZE = 600; 67 68 public static void assertStructEquals(MemorySegment actual, MemorySegment expected, MemoryLayout layout) { 69 assertEquals(actual.byteSize(), expected.byteSize()); 70 GroupLayout g = (GroupLayout) layout; 71 for (MemoryLayout field : g.memberLayouts()) { 72 if (field instanceof ValueLayout) { 73 VarHandle vh = g.varHandle(MemoryLayout.PathElement.groupElement(field.name().orElseThrow())); 74 assertEquals(vh.get(actual), vh.get(expected)); 75 } 76 } 77 } 78 79 private static Class<?> vhCarrier(MemoryLayout layout) { 80 if (layout instanceof ValueLayout) { 81 if (isIntegral(layout)) { 82 if (layout.bitSize() == 64) { 83 return long.class; 84 } 85 return int.class; 86 } else if (layout.bitSize() == 32) { 87 return float.class; 88 } 89 return double.class; 90 } else { 91 throw new IllegalStateException("Unexpected layout: " + layout); 92 } 93 } 94 95 enum Ret { 96 VOID, 97 NON_VOID 98 } 99 100 enum StructFieldType { 101 INT("int", C_INT), 102 FLOAT("float", C_FLOAT), 103 DOUBLE("double", C_DOUBLE), 104 POINTER("void*", C_POINTER); 105 106 final String typeStr; 107 final MemoryLayout layout; 108 109 StructFieldType(String typeStr, MemoryLayout layout) { 110 this.typeStr = typeStr; 111 this.layout = layout; 112 } 113 114 MemoryLayout layout() { 115 return layout; 116 } 117 118 @SuppressWarnings("unchecked") 119 static List<List<StructFieldType>>[] perms = new List[10]; 120 121 static List<List<StructFieldType>> perms(int i) { 122 if (perms[i] == null) { 123 perms[i] = generateTest(i, values()); 124 } 125 return perms[i]; 126 } 127 } 128 129 enum ParamType { 130 INT("int", C_INT), 131 FLOAT("float", C_FLOAT), 132 DOUBLE("double", C_DOUBLE), 133 POINTER("void*", C_POINTER), 134 STRUCT("struct S", null); 135 136 private final String typeStr; 137 private final MemoryLayout layout; 138 139 ParamType(String typeStr, MemoryLayout layout) { 140 this.typeStr = typeStr; 141 this.layout = layout; 142 } 143 144 String type(List<StructFieldType> fields) { 145 return this == STRUCT ? 146 typeStr + "_" + sigCode(fields) : 147 typeStr; 148 } 149 150 MemoryLayout layout(List<StructFieldType> fields) { 151 if (this == STRUCT) { 152 long offset = 0L; 153 List<MemoryLayout> layouts = new ArrayList<>(); 154 for (StructFieldType field : fields) { 155 MemoryLayout l = field.layout(); 156 long padding = offset % l.bitSize(); 157 if (padding != 0) { 158 layouts.add(MemoryLayout.paddingLayout(padding)); 159 offset += padding; 160 } 161 layouts.add(l.withName("field" + offset)); 162 offset += l.bitSize(); 163 } 164 return MemoryLayout.structLayout(layouts.toArray(new MemoryLayout[0])); 165 } else { 166 return layout; 167 } 168 } 169 170 @SuppressWarnings("unchecked") 171 static List<List<ParamType>>[] perms = new List[10]; 172 173 static List<List<ParamType>> perms(int i) { 174 if (perms[i] == null) { 175 perms[i] = generateTest(i, values()); 176 } 177 return perms[i]; 178 } 179 } 180 181 static <Z> List<List<Z>> generateTest(int i, Z[] elems) { 182 List<List<Z>> res = new ArrayList<>(); 183 generateTest(i, new Stack<>(), elems, res); 184 return res; 185 } 186 187 static <Z> void generateTest(int i, Stack<Z> combo, Z[] elems, List<List<Z>> results) { 188 if (i == 0) { 189 results.add(new ArrayList<>(combo)); 190 } else { 191 for (Z z : elems) { 192 combo.push(z); 193 generateTest(i - 1, combo, elems, results); 194 combo.pop(); 195 } 196 } 197 } 198 199 @DataProvider(name = "functions") 200 public static Object[][] functions() { 201 int functions = 0; 202 List<Object[]> downcalls = new ArrayList<>(); 203 for (Ret r : Ret.values()) { 204 for (int i = 0; i <= MAX_PARAMS; i++) { 205 if (r != Ret.VOID && i == 0) continue; 206 for (List<ParamType> ptypes : ParamType.perms(i)) { 207 String retCode = r == Ret.VOID ? "V" : ptypes.get(0).name().charAt(0) + ""; 208 String sigCode = sigCode(ptypes); 209 if (ptypes.contains(ParamType.STRUCT)) { 210 for (int j = 1; j <= MAX_FIELDS; j++) { 211 for (List<StructFieldType> fields : StructFieldType.perms(j)) { 212 String structCode = sigCode(fields); 213 int count = functions; 214 int fCode = functions++ / CHUNK_SIZE; 215 String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode); 216 if (SAMPLE_FACTOR == -1 || (count % SAMPLE_FACTOR) == 0) { 217 downcalls.add(new Object[]{count, fName, r, ptypes, fields}); 218 } 219 } 220 } 221 } else { 222 String structCode = sigCode(List.<StructFieldType>of()); 223 int count = functions; 224 int fCode = functions++ / CHUNK_SIZE; 225 String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode); 226 if (SAMPLE_FACTOR == -1 || (count % SAMPLE_FACTOR) == 0) { 227 downcalls.add(new Object[]{count, fName, r, ptypes, List.of()}); 228 } 229 } 230 } 231 } 232 } 233 return downcalls.toArray(new Object[0][]); 234 } 235 236 static <Z extends Enum<Z>> String sigCode(List<Z> elems) { 237 return elems.stream().map(p -> p.name().charAt(0) + "").collect(Collectors.joining()); 238 } 239 240 static void generateStructDecl(List<StructFieldType> fields) { 241 String structCode = sigCode(fields); 242 List<String> fieldDecls = new ArrayList<>(); 243 for (int i = 0 ; i < fields.size() ; i++) { 244 fieldDecls.add(String.format("%s p%d;", fields.get(i).typeStr, i)); 245 } 246 String res = String.format("struct S_%s { %s };", structCode, 247 fieldDecls.stream().collect(Collectors.joining(" "))); 248 System.out.println(res); 249 } 250 251 /* this can be used to generate the test header/implementation */ 252 public static void main(String[] args) { 253 boolean header = args.length > 0 && args[0].equals("header"); 254 boolean upcall = args.length > 1 && args[1].equals("upcall"); 255 if (upcall) { 256 generateUpcalls(header); 257 } else { 258 generateDowncalls(header); 259 } 260 } 261 262 static void generateDowncalls(boolean header) { 263 if (header) { 264 System.out.println( 265 "#ifdef _WIN64\n" + 266 "#define EXPORT __declspec(dllexport)\n" + 267 "#else\n" + 268 "#define EXPORT\n" + 269 "#endif\n" 270 ); 271 272 for (int j = 1; j <= MAX_FIELDS; j++) { 273 for (List<StructFieldType> fields : StructFieldType.perms(j)) { 274 generateStructDecl(fields); 275 } 276 } 277 } else { 278 System.out.println( 279 "#include \"libh\"\n" + 280 "#ifdef __clang__\n" + 281 "#pragma clang optimize off\n" + 282 "#elif defined __GNUC__\n" + 283 "#pragma GCC optimize (\"O0\")\n" + 284 "#elif defined _MSC_BUILD\n" + 285 "#pragma optimize( \"\", off )\n" + 286 "#endif\n" 287 ); 288 } 289 290 for (Object[] downcall : functions()) { 291 String fName = (String)downcall[0]; 292 Ret r = (Ret)downcall[1]; 293 @SuppressWarnings("unchecked") 294 List<ParamType> ptypes = (List<ParamType>)downcall[2]; 295 @SuppressWarnings("unchecked") 296 List<StructFieldType> fields = (List<StructFieldType>)downcall[3]; 297 generateDowncallFunction(fName, r, ptypes, fields, header); 298 } 299 } 300 301 static void generateDowncallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) { 302 String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields); 303 List<String> paramDecls = new ArrayList<>(); 304 for (int i = 0 ; i < params.size() ; i++) { 305 paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i)); 306 } 307 String sig = paramDecls.isEmpty() ? 308 "void" : 309 paramDecls.stream().collect(Collectors.joining(", ")); 310 String body = ret == Ret.VOID ? "{ }" : "{ return p0; }"; 311 String res = String.format("EXPORT %s f%s(%s) %s", retType, fName, 312 sig, declOnly ? ";" : body); 313 System.out.println(res); 314 } 315 316 static void generateUpcalls(boolean header) { 317 if (header) { 318 System.out.println( 319 "#ifdef _WIN64\n" + 320 "#define EXPORT __declspec(dllexport)\n" + 321 "#else\n" + 322 "#define EXPORT\n" + 323 "#endif\n" 324 ); 325 326 for (int j = 1; j <= MAX_FIELDS; j++) { 327 for (List<StructFieldType> fields : StructFieldType.perms(j)) { 328 generateStructDecl(fields); 329 } 330 } 331 } else { 332 System.out.println( 333 "#include \"libh\"\n" + 334 "#ifdef __clang__\n" + 335 "#pragma clang optimize off\n" + 336 "#elif defined __GNUC__\n" + 337 "#pragma GCC optimize (\"O0\")\n" + 338 "#elif defined _MSC_BUILD\n" + 339 "#pragma optimize( \"\", off )\n" + 340 "#endif\n" 341 ); 342 } 343 344 for (Object[] downcall : functions()) { 345 String fName = (String)downcall[0]; 346 Ret r = (Ret)downcall[1]; 347 @SuppressWarnings("unchecked") 348 List<ParamType> ptypes = (List<ParamType>)downcall[2]; 349 @SuppressWarnings("unchecked") 350 List<StructFieldType> fields = (List<StructFieldType>)downcall[3]; 351 generateUpcallFunction(fName, r, ptypes, fields, header); 352 } 353 } 354 355 static void generateUpcallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) { 356 String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields); 357 List<String> paramDecls = new ArrayList<>(); 358 for (int i = 0 ; i < params.size() ; i++) { 359 paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i)); 360 } 361 String paramNames = IntStream.range(0, params.size()) 362 .mapToObj(i -> "p" + i) 363 .collect(Collectors.joining(",")); 364 String sig = paramDecls.isEmpty() ? 365 "" : 366 paramDecls.stream().collect(Collectors.joining(", ")) + ", "; 367 String body = String.format(ret == Ret.VOID ? "{ cb(%s); }" : "{ return cb(%s); }", paramNames); 368 List<String> paramTypes = params.stream().map(p -> p.type(fields)).collect(Collectors.toList()); 369 String cbSig = paramTypes.isEmpty() ? 370 "void" : 371 paramTypes.stream().collect(Collectors.joining(", ")); 372 String cbParam = String.format("%s (*cb)(%s)", 373 retType, cbSig); 374 375 String res = String.format("EXPORT %s %s(%s %s) %s", retType, fName, 376 sig, cbParam, declOnly ? ";" : body); 377 System.out.println(res); 378 } 379 380 //helper methods 381 382 @SuppressWarnings("unchecked") 383 static Object makeArg(MemoryLayout layout, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException { 384 if (layout instanceof GroupLayout) { 385 MemorySegment segment = MemorySegment.allocateNative(layout, ResourceScope.newImplicitScope()); 386 initStruct(segment, (GroupLayout)layout, checks, check); 387 return segment; 388 } else if (isPointer(layout)) { 389 MemorySegment segment = MemorySegment.allocateNative(1, ResourceScope.newImplicitScope()); 390 if (check) { 391 checks.add(o -> { 392 try { 393 assertEquals(o, segment.address()); 394 } catch (Throwable ex) { 395 throw new IllegalStateException(ex); 396 } 397 }); 398 } 399 return segment.address(); 400 } else if (layout instanceof ValueLayout) { 401 if (isIntegral(layout)) { 402 if (check) { 403 checks.add(o -> assertEquals(o, 42)); 404 } 405 return 42; 406 } else if (layout.bitSize() == 32) { 407 if (check) { 408 checks.add(o -> assertEquals(o, 12f)); 409 } 410 return 12f; 411 } else { 412 if (check) { 413 checks.add(o -> assertEquals(o, 24d)); 414 } 415 return 24d; 416 } 417 } else { 418 throw new IllegalStateException("Unexpected layout: " + layout); 419 } 420 } 421 422 static void initStruct(MemorySegment str, GroupLayout g, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException { 423 for (MemoryLayout l : g.memberLayouts()) { 424 if (l.isPadding()) continue; 425 VarHandle accessor = g.varHandle(MemoryLayout.PathElement.groupElement(l.name().get())); 426 List<Consumer<Object>> fieldsCheck = new ArrayList<>(); 427 Object value = makeArg(l, fieldsCheck, check); 428 //set value 429 accessor.set(str, value); 430 //add check 431 if (check) { 432 assertTrue(fieldsCheck.size() == 1); 433 checks.add(o -> { 434 MemorySegment actual = (MemorySegment)o; 435 try { 436 fieldsCheck.get(0).accept(accessor.get(actual)); 437 } catch (Throwable ex) { 438 throw new IllegalStateException(ex); 439 } 440 }); 441 } 442 } 443 } 444 445 static Class<?> carrier(MemoryLayout layout, boolean param) { 446 if (layout instanceof GroupLayout) { 447 return MemorySegment.class; 448 } if (isPointer(layout)) { 449 return param ? Addressable.class : MemoryAddress.class; 450 } else if (layout instanceof ValueLayout valueLayout) { 451 return valueLayout.carrier(); 452 } else { 453 throw new IllegalStateException("Unexpected layout: " + layout); 454 } 455 } 456 457 MethodHandle downcallHandle(CLinker abi, NativeSymbol symbol, SegmentAllocator allocator, FunctionDescriptor descriptor) { 458 MethodHandle mh = abi.downcallHandle(symbol, descriptor); 459 if (descriptor.returnLayout().isPresent() && descriptor.returnLayout().get() instanceof GroupLayout) { 460 mh = mh.bindTo(allocator); 461 } 462 return mh; 463 } 464 }