/*
 * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
22 */ 23 24 /* 25 * @test 26 * @run testng/othervm --enable-native-access=ALL-UNNAMED StdLibTest 27 */ 28 29 import java.lang.invoke.MethodHandle; 30 import java.lang.invoke.MethodHandles; 31 import java.lang.invoke.MethodType; 32 import java.time.Instant; 33 import java.time.LocalDateTime; 34 import java.time.ZoneOffset; 35 import java.time.ZonedDateTime; 36 import java.util.ArrayList; 37 import java.util.Arrays; 38 import java.util.Collections; 39 import java.util.LinkedHashSet; 40 import java.util.List; 41 import java.util.Set; 42 import java.util.function.Function; 43 import java.util.stream.Collectors; 44 import java.util.stream.Stream; 45 46 import java.lang.foreign.*; 47 48 import org.testng.annotations.*; 49 50 import static org.testng.Assert.*; 51 52 public class StdLibTest extends NativeTestHelper { 53 54 final static Linker abi = Linker.nativeLinker(); 55 56 private StdLibHelper stdLibHelper = new StdLibHelper(); 57 58 @Test(dataProvider = "stringPairs") 59 void test_strcat(String s1, String s2) throws Throwable { 60 assertEquals(stdLibHelper.strcat(s1, s2), s1 + s2); 61 } 62 63 @Test(dataProvider = "stringPairs") 64 void test_strcmp(String s1, String s2) throws Throwable { 65 assertEquals(Math.signum(stdLibHelper.strcmp(s1, s2)), Math.signum(s1.compareTo(s2))); 66 } 67 68 @Test(dataProvider = "strings") 69 void test_puts(String s) throws Throwable { 70 assertTrue(stdLibHelper.puts(s) >= 0); 71 } 72 73 @Test(dataProvider = "strings") 74 void test_strlen(String s) throws Throwable { 75 assertEquals(stdLibHelper.strlen(s), s.length()); 76 } 77 78 @Test(dataProvider = "instants") 79 void test_time(Instant instant) throws Throwable { 80 StdLibHelper.Tm tm = stdLibHelper.gmtime(instant.getEpochSecond()); 81 LocalDateTime localTime = LocalDateTime.ofInstant(instant, ZoneOffset.UTC); 82 assertEquals(tm.sec(), localTime.getSecond()); 83 assertEquals(tm.min(), localTime.getMinute()); 84 assertEquals(tm.hour(), localTime.getHour()); 85 //day pf year in Java has 
1-offset 86 assertEquals(tm.yday(), localTime.getDayOfYear() - 1); 87 assertEquals(tm.mday(), localTime.getDayOfMonth()); 88 //days of week starts from Sunday in C, but on Monday in Java, also account for 1-offset 89 assertEquals((tm.wday() + 6) % 7, localTime.getDayOfWeek().getValue() - 1); 90 //month in Java has 1-offset 91 assertEquals(tm.mon(), localTime.getMonth().getValue() - 1); 92 assertEquals(tm.isdst(), ZoneOffset.UTC.getRules() 93 .isDaylightSavings(Instant.ofEpochMilli(instant.getEpochSecond() * 1000))); 94 } 95 96 @Test(dataProvider = "ints") 97 void test_qsort(List<Integer> ints) throws Throwable { 98 if (ints.size() > 0) { 99 int[] input = ints.stream().mapToInt(i -> i).toArray(); 100 int[] sorted = stdLibHelper.qsort(input); 101 Arrays.sort(input); 102 assertEquals(sorted, input); 103 } 104 } 105 106 @Test 107 void test_rand() throws Throwable { 108 int val = stdLibHelper.rand(); 109 for (int i = 0 ; i < 100 ; i++) { 110 int newVal = stdLibHelper.rand(); 111 if (newVal != val) { 112 return; //ok 113 } 114 val = newVal; 115 } 116 fail("All values are the same! 
" + val); 117 } 118 119 @Test(dataProvider = "printfArgs") 120 void test_printf(List<PrintfArg> args) throws Throwable { 121 String javaFormatArgs = args.stream() 122 .map(a -> a.javaFormat) 123 .collect(Collectors.joining(",")); 124 String nativeFormatArgs = args.stream() 125 .map(a -> a.nativeFormat) 126 .collect(Collectors.joining(",")); 127 128 String javaFormatString = "hello(" + javaFormatArgs + ")\n"; 129 String nativeFormatString = "hello(" + nativeFormatArgs + ")\n"; 130 131 String expected = String.format(javaFormatString, args.stream() 132 .map(a -> a.javaValue).toArray()); 133 134 int found = stdLibHelper.printf(nativeFormatString, args); 135 assertEquals(found, expected.length()); 136 } 137 138 @Test 139 void testSystemLibraryBadLookupName() { 140 assertTrue(LINKER.defaultLookup().find("strlen\u0000foobar").isEmpty()); 141 } 142 143 static class StdLibHelper { 144 145 final static MethodHandle strcat = abi.downcallHandle(abi.defaultLookup().find("strcat").get(), 146 FunctionDescriptor.of(C_POINTER, C_POINTER, C_POINTER)); 147 148 final static MethodHandle strcmp = abi.downcallHandle(abi.defaultLookup().find("strcmp").get(), 149 FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER)); 150 151 final static MethodHandle puts = abi.downcallHandle(abi.defaultLookup().find("puts").get(), 152 FunctionDescriptor.of(C_INT, C_POINTER)); 153 154 final static MethodHandle strlen = abi.downcallHandle(abi.defaultLookup().find("strlen").get(), 155 FunctionDescriptor.of(C_INT, C_POINTER)); 156 157 final static MethodHandle gmtime = abi.downcallHandle(abi.defaultLookup().find("gmtime").get(), 158 FunctionDescriptor.of(C_POINTER.withTargetLayout(Tm.LAYOUT), C_POINTER)); 159 160 // void qsort( void *ptr, size_t count, size_t size, int (*comp)(const void *, const void *) ); 161 final static MethodHandle qsort = abi.downcallHandle(abi.defaultLookup().find("qsort").get(), 162 FunctionDescriptor.ofVoid(C_POINTER, C_SIZE_T, C_SIZE_T, C_POINTER)); 163 164 final static 
FunctionDescriptor qsortComparFunction = FunctionDescriptor.of(C_INT, 165 C_POINTER.withTargetLayout(C_INT), C_POINTER.withTargetLayout(C_INT)); 166 167 final static MethodHandle qsortCompar; 168 169 final static MethodHandle rand = abi.downcallHandle(abi.defaultLookup().find("rand").get(), 170 FunctionDescriptor.of(C_INT)); 171 172 final static MethodHandle vprintf = abi.downcallHandle(abi.defaultLookup().find("vprintf").get(), 173 FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER)); 174 175 final static MemorySegment printfAddr = abi.defaultLookup().find("printf").get(); 176 177 final static FunctionDescriptor printfBase = FunctionDescriptor.of(C_INT, C_POINTER); 178 179 static { 180 try { 181 //qsort upcall handle 182 qsortCompar = MethodHandles.lookup().findStatic(StdLibTest.StdLibHelper.class, "qsortCompare", 183 qsortComparFunction.toMethodType()); 184 } catch (ReflectiveOperationException ex) { 185 throw new IllegalStateException(ex); 186 } 187 } 188 189 String strcat(String s1, String s2) throws Throwable { 190 try (var arena = Arena.ofConfined()) { 191 MemorySegment buf = arena.allocate(s1.length() + s2.length() + 1); 192 buf.setString(0, s1); 193 MemorySegment other = arena.allocateFrom(s2); 194 return ((MemorySegment)strcat.invokeExact(buf, other)).getString(0); 195 } 196 } 197 198 int strcmp(String s1, String s2) throws Throwable { 199 try (var arena = Arena.ofConfined()) { 200 MemorySegment ns1 = arena.allocateFrom(s1); 201 MemorySegment ns2 = arena.allocateFrom(s2); 202 return (int)strcmp.invokeExact(ns1, ns2); 203 } 204 } 205 206 int puts(String msg) throws Throwable { 207 try (var arena = Arena.ofConfined()) { 208 MemorySegment s = arena.allocateFrom(msg); 209 return (int)puts.invokeExact(s); 210 } 211 } 212 213 int strlen(String msg) throws Throwable { 214 try (var arena = Arena.ofConfined()) { 215 MemorySegment s = arena.allocateFrom(msg); 216 return (int)strlen.invokeExact(s); 217 } 218 } 219 220 Tm gmtime(long arg) throws Throwable { 221 try 
(var arena = Arena.ofConfined()) { 222 MemorySegment time = arena.allocate(8); 223 time.set(C_LONG_LONG, 0, arg); 224 return new Tm((MemorySegment)gmtime.invokeExact(time)); 225 } 226 } 227 228 static class Tm { 229 230 //Tm pointer should never be freed directly, as it points to shared memory 231 private final MemorySegment base; 232 233 static final MemoryLayout LAYOUT = MemoryLayout.structLayout( 234 C_INT.withName("sec"), 235 C_INT.withName("min"), 236 C_INT.withName("hour"), 237 C_INT.withName("mday"), 238 C_INT.withName("mon"), 239 C_INT.withName("year"), 240 C_INT.withName("wday"), 241 C_INT.withName("yday"), 242 C_BOOL.withName("isdst"), 243 MemoryLayout.paddingLayout(3) 244 ); 245 246 Tm(MemorySegment addr) { 247 this.base = addr; 248 } 249 250 int sec() { 251 return base.get(C_INT, 0); 252 } 253 int min() { 254 return base.get(C_INT, 4); 255 } 256 int hour() { 257 return base.get(C_INT, 8); 258 } 259 int mday() { 260 return base.get(C_INT, 12); 261 } 262 int mon() { 263 return base.get(C_INT, 16); 264 } 265 int year() { 266 return base.get(C_INT, 20); 267 } 268 int wday() { 269 return base.get(C_INT, 24); 270 } 271 int yday() { 272 return base.get(C_INT, 28); 273 } 274 boolean isdst() { 275 return base.get(C_BOOL, 32); 276 } 277 } 278 279 int[] qsort(int[] arr) throws Throwable { 280 //init native array 281 try (var arena = Arena.ofConfined()) { 282 MemorySegment nativeArr = arena.allocateFrom(C_INT, arr); 283 284 //call qsort 285 MemorySegment qsortUpcallStub = abi.upcallStub(qsortCompar, qsortComparFunction, arena); 286 287 // both of these fit in an int 288 // automatically widen them to long on x64 289 int count = arr.length; 290 int size = (int) C_INT.byteSize(); 291 qsort.invoke(nativeArr, count, size, qsortUpcallStub); 292 293 //convert back to Java array 294 return nativeArr.toArray(C_INT); 295 } 296 } 297 298 static int qsortCompare(MemorySegment addr1, MemorySegment addr2) { 299 return addr1.get(C_INT, 0) - 300 addr2.get(C_INT, 0); 301 } 302 303 
int rand() throws Throwable { 304 return (int)rand.invokeExact(); 305 } 306 307 int printf(String format, List<PrintfArg> args) throws Throwable { 308 try (var arena = Arena.ofConfined()) { 309 MemorySegment formatStr = arena.allocateFrom(format); 310 return (int)specializedPrintf(args).invokeExact(formatStr, 311 args.stream().map(a -> a.nativeValue(arena)).toArray()); 312 } 313 } 314 315 private MethodHandle specializedPrintf(List<PrintfArg> args) { 316 //method type 317 MethodType mt = MethodType.methodType(int.class, MemorySegment.class); 318 FunctionDescriptor fd = printfBase; 319 List<MemoryLayout> variadicLayouts = new ArrayList<>(args.size()); 320 for (PrintfArg arg : args) { 321 mt = mt.appendParameterTypes(arg.carrier); 322 variadicLayouts.add(arg.layout); 323 } 324 Linker.Option varargIndex = Linker.Option.firstVariadicArg(fd.argumentLayouts().size()); 325 MethodHandle mh = abi.downcallHandle(printfAddr, 326 fd.appendArgumentLayouts(variadicLayouts.toArray(new MemoryLayout[args.size()])), 327 varargIndex); 328 return mh.asSpreader(1, Object[].class, args.size()); 329 } 330 } 331 332 /*** data providers ***/ 333 334 @DataProvider 335 public static Object[][] ints() { 336 return perms(0, new Integer[] { 0, 1, 2, 3, 4 }).stream() 337 .map(l -> new Object[] { l }) 338 .toArray(Object[][]::new); 339 } 340 341 @DataProvider 342 public static Object[][] strings() { 343 return perms(0, new String[] { "a", "b", "c" }).stream() 344 .map(l -> new Object[] { String.join("", l) }) 345 .toArray(Object[][]::new); 346 } 347 348 @DataProvider 349 public static Object[][] stringPairs() { 350 Object[][] strings = strings(); 351 Object[][] stringPairs = new Object[strings.length * strings.length][]; 352 int pos = 0; 353 for (Object[] s1 : strings) { 354 for (Object[] s2 : strings) { 355 stringPairs[pos++] = new Object[] { s1[0], s2[0] }; 356 } 357 } 358 return stringPairs; 359 } 360 361 @DataProvider 362 public static Object[][] instants() { 363 Instant start = 
ZonedDateTime.of(LocalDateTime.parse("2017-01-01T00:00:00"), ZoneOffset.UTC).toInstant(); 364 Instant end = ZonedDateTime.of(LocalDateTime.parse("2017-12-31T00:00:00"), ZoneOffset.UTC).toInstant(); 365 Object[][] instants = new Object[100][]; 366 for (int i = 0 ; i < instants.length ; i++) { 367 Instant instant = start.plusSeconds((long)(Math.random() * (end.getEpochSecond() - start.getEpochSecond()))); 368 instants[i] = new Object[] { instant }; 369 } 370 return instants; 371 } 372 373 @DataProvider 374 public static Object[][] printfArgs() { 375 ArrayList<List<PrintfArg>> res = new ArrayList<>(); 376 List<List<PrintfArg>> perms = new ArrayList<>(perms(0, PrintfArg.values())); 377 for (int i = 0 ; i < 100 ; i++) { 378 Collections.shuffle(perms); 379 res.addAll(perms); 380 } 381 return res.stream() 382 .map(l -> new Object[] { l }) 383 .toArray(Object[][]::new); 384 } 385 386 enum PrintfArg { 387 INT(int.class, C_INT, "%d", "%d", arena -> 42, 42), 388 LONG(long.class, C_LONG_LONG, "%lld", "%d", arena -> 84L, 84L), 389 DOUBLE(double.class, C_DOUBLE, "%.4f", "%.4f", arena -> 1.2345d, 1.2345d), 390 STRING(MemorySegment.class, C_POINTER, "%s", "%s", arena -> arena.allocateFrom("str"), "str"); 391 392 final Class<?> carrier; 393 final ValueLayout layout; 394 final String nativeFormat; 395 final String javaFormat; 396 final Function<Arena, ?> nativeValueFactory; 397 final Object javaValue; 398 399 <Z, L extends ValueLayout> PrintfArg(Class<?> carrier, L layout, String nativeFormat, String javaFormat, 400 Function<Arena, Z> nativeValueFactory, Object javaValue) { 401 this.carrier = carrier; 402 this.layout = layout; 403 this.nativeFormat = nativeFormat; 404 this.javaFormat = javaFormat; 405 this.nativeValueFactory = nativeValueFactory; 406 this.javaValue = javaValue; 407 } 408 409 public Object nativeValue(Arena arena) { 410 return nativeValueFactory.apply(arena); 411 } 412 } 413 414 static <Z> Set<List<Z>> perms(int count, Z[] arr) { 415 if (count == arr.length) { 416 
return Set.of(List.of()); 417 } else { 418 return Arrays.stream(arr) 419 .flatMap(num -> { 420 Set<List<Z>> perms = perms(count + 1, arr); 421 return Stream.concat( 422 //take n 423 perms.stream().map(l -> { 424 List<Z> li = new ArrayList<>(l); 425 li.add(num); 426 return li; 427 }), 428 //drop n 429 perms.stream()); 430 }).collect(Collectors.toCollection(LinkedHashSet::new)); 431 } 432 } 433 }