1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @run testng/othervm --enable-native-access=ALL-UNNAMED StdLibTest
 27  */
 28 
 29 import java.lang.invoke.MethodHandle;
 30 import java.lang.invoke.MethodHandles;
 31 import java.lang.invoke.MethodType;
 32 import java.time.Instant;
 33 import java.time.LocalDateTime;
 34 import java.time.ZoneOffset;
 35 import java.time.ZonedDateTime;
 36 import java.util.ArrayList;
 37 import java.util.Arrays;
 38 import java.util.Collections;
 39 import java.util.LinkedHashSet;
 40 import java.util.List;
 41 import java.util.Set;
 42 import java.util.function.Function;
 43 import java.util.stream.Collectors;
 44 import java.util.stream.Stream;
 45 
 46 import java.lang.foreign.*;
 47 
 48 import org.testng.annotations.*;
 49 
 50 import static org.testng.Assert.*;
 51 
 52 public class StdLibTest extends NativeTestHelper {
 53 
 54     final static Linker abi = Linker.nativeLinker();
 55 
 56     private StdLibHelper stdLibHelper = new StdLibHelper();
 57 
 58     @Test(dataProvider = "stringPairs")
 59     void test_strcat(String s1, String s2) throws Throwable {
 60         assertEquals(stdLibHelper.strcat(s1, s2), s1 + s2);
 61     }
 62 
 63     @Test(dataProvider = "stringPairs")
 64     void test_strcmp(String s1, String s2) throws Throwable {
 65         assertEquals(Math.signum(stdLibHelper.strcmp(s1, s2)), Math.signum(s1.compareTo(s2)));
 66     }
 67 
 68     @Test(dataProvider = "strings")
 69     void test_puts(String s) throws Throwable {
 70         assertTrue(stdLibHelper.puts(s) >= 0);
 71     }
 72 
 73     @Test(dataProvider = "strings")
 74     void test_strlen(String s) throws Throwable {
 75         assertEquals(stdLibHelper.strlen(s), s.length());
 76     }
 77 
 78     @Test(dataProvider = "instants")
 79     void test_time(Instant instant) throws Throwable {
 80         StdLibHelper.Tm tm = stdLibHelper.gmtime(instant.getEpochSecond());
 81         LocalDateTime localTime = LocalDateTime.ofInstant(instant, ZoneOffset.UTC);
 82         assertEquals(tm.sec(), localTime.getSecond());
 83         assertEquals(tm.min(), localTime.getMinute());
 84         assertEquals(tm.hour(), localTime.getHour());
 85         //day pf year in Java has 1-offset
 86         assertEquals(tm.yday(), localTime.getDayOfYear() - 1);
 87         assertEquals(tm.mday(), localTime.getDayOfMonth());
 88         //days of week starts from Sunday in C, but on Monday in Java, also account for 1-offset
 89         assertEquals((tm.wday() + 6) % 7, localTime.getDayOfWeek().getValue() - 1);
 90         //month in Java has 1-offset
 91         assertEquals(tm.mon(), localTime.getMonth().getValue() - 1);
 92         assertEquals(tm.isdst(), ZoneOffset.UTC.getRules()
 93                 .isDaylightSavings(Instant.ofEpochMilli(instant.getEpochSecond() * 1000)));
 94     }
 95 
 96     @Test(dataProvider = "ints")
 97     void test_qsort(List<Integer> ints) throws Throwable {
 98         if (ints.size() > 0) {
 99             int[] input = ints.stream().mapToInt(i -> i).toArray();
100             int[] sorted = stdLibHelper.qsort(input);
101             Arrays.sort(input);
102             assertEquals(sorted, input);
103         }
104     }
105 
106     @Test
107     void test_rand() throws Throwable {
108         int val = stdLibHelper.rand();
109         for (int i = 0 ; i < 100 ; i++) {
110             int newVal = stdLibHelper.rand();
111             if (newVal != val) {
112                 return; //ok
113             }
114             val = newVal;
115         }
116         fail("All values are the same! " + val);
117     }
118 
119     @Test(dataProvider = "printfArgs")
120     void test_printf(List<PrintfArg> args) throws Throwable {
121         String javaFormatArgs = args.stream()
122                 .map(a -> a.javaFormat)
123                 .collect(Collectors.joining(","));
124         String nativeFormatArgs = args.stream()
125                 .map(a -> a.nativeFormat)
126                 .collect(Collectors.joining(","));
127 
128         String javaFormatString = "hello(" + javaFormatArgs + ")\n";
129         String nativeFormatString = "hello(" + nativeFormatArgs + ")\n";
130 
131         String expected = String.format(javaFormatString, args.stream()
132                 .map(a -> a.javaValue).toArray());
133 
134         int found = stdLibHelper.printf(nativeFormatString, args);
135         assertEquals(found, expected.length());
136     }
137 
138     @Test
139     void testSystemLibraryBadLookupName() {
140         assertTrue(LINKER.defaultLookup().find("strlen\u0000foobar").isEmpty());
141     }
142 
143     static class StdLibHelper {
144 
145         final static MethodHandle strcat = abi.downcallHandle(abi.defaultLookup().find("strcat").get(),
146                 FunctionDescriptor.of(C_POINTER, C_POINTER, C_POINTER));
147 
148         final static MethodHandle strcmp = abi.downcallHandle(abi.defaultLookup().find("strcmp").get(),
149                 FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER));
150 
151         final static MethodHandle puts = abi.downcallHandle(abi.defaultLookup().find("puts").get(),
152                 FunctionDescriptor.of(C_INT, C_POINTER));
153 
154         final static MethodHandle strlen = abi.downcallHandle(abi.defaultLookup().find("strlen").get(),
155                 FunctionDescriptor.of(C_INT, C_POINTER));
156 
157         final static MethodHandle gmtime = abi.downcallHandle(abi.defaultLookup().find("gmtime").get(),
158                 FunctionDescriptor.of(C_POINTER.withTargetLayout(Tm.LAYOUT), C_POINTER));
159 
160         // void qsort( void *ptr, size_t count, size_t size, int (*comp)(const void *, const void *) );
161         final static MethodHandle qsort = abi.downcallHandle(abi.defaultLookup().find("qsort").get(),
162                 FunctionDescriptor.ofVoid(C_POINTER, C_SIZE_T, C_SIZE_T, C_POINTER));
163 
164         final static FunctionDescriptor qsortComparFunction = FunctionDescriptor.of(C_INT,
165                 C_POINTER.withTargetLayout(C_INT), C_POINTER.withTargetLayout(C_INT));
166 
167         final static MethodHandle qsortCompar;
168 
169         final static MethodHandle rand = abi.downcallHandle(abi.defaultLookup().find("rand").get(),
170                 FunctionDescriptor.of(C_INT));
171 
172         final static MethodHandle vprintf = abi.downcallHandle(abi.defaultLookup().find("vprintf").get(),
173                 FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER));
174 
175         final static MemorySegment printfAddr = abi.defaultLookup().find("printf").get();
176 
177         final static FunctionDescriptor printfBase = FunctionDescriptor.of(C_INT, C_POINTER);
178 
179         static {
180             try {
181                 //qsort upcall handle
182                 qsortCompar = MethodHandles.lookup().findStatic(StdLibTest.StdLibHelper.class, "qsortCompare",
183                         qsortComparFunction.toMethodType());
184             } catch (ReflectiveOperationException ex) {
185                 throw new IllegalStateException(ex);
186             }
187         }
188 
189         String strcat(String s1, String s2) throws Throwable {
190             try (var arena = Arena.ofConfined()) {
191                 MemorySegment buf = arena.allocate(s1.length() + s2.length() + 1);
192                 buf.setString(0, s1);
193                 MemorySegment other = arena.allocateFrom(s2);
194                 return ((MemorySegment)strcat.invokeExact(buf, other)).getString(0);
195             }
196         }
197 
198         int strcmp(String s1, String s2) throws Throwable {
199             try (var arena = Arena.ofConfined()) {
200                 MemorySegment ns1 = arena.allocateFrom(s1);
201                 MemorySegment ns2 = arena.allocateFrom(s2);
202                 return (int)strcmp.invokeExact(ns1, ns2);
203             }
204         }
205 
206         int puts(String msg) throws Throwable {
207             try (var arena = Arena.ofConfined()) {
208                 MemorySegment s = arena.allocateFrom(msg);
209                 return (int)puts.invokeExact(s);
210             }
211         }
212 
213         int strlen(String msg) throws Throwable {
214             try (var arena = Arena.ofConfined()) {
215                 MemorySegment s = arena.allocateFrom(msg);
216                 return (int)strlen.invokeExact(s);
217             }
218         }
219 
220         Tm gmtime(long arg) throws Throwable {
221             try (var arena = Arena.ofConfined()) {
222                 MemorySegment time = arena.allocate(8);
223                 time.set(C_LONG_LONG, 0, arg);
224                 return new Tm((MemorySegment)gmtime.invokeExact(time));
225             }
226         }
227 
228         static class Tm {
229 
230             //Tm pointer should never be freed directly, as it points to shared memory
231             private final MemorySegment base;
232 
233             static final MemoryLayout LAYOUT = MemoryLayout.structLayout(
234                     C_INT.withName("sec"),
235                     C_INT.withName("min"),
236                     C_INT.withName("hour"),
237                     C_INT.withName("mday"),
238                     C_INT.withName("mon"),
239                     C_INT.withName("year"),
240                     C_INT.withName("wday"),
241                     C_INT.withName("yday"),
242                     C_BOOL.withName("isdst"),
243                     MemoryLayout.paddingLayout(3)
244             );
245 
246             Tm(MemorySegment addr) {
247                 this.base = addr;
248             }
249 
250             int sec() {
251                 return base.get(C_INT, 0);
252             }
253             int min() {
254                 return base.get(C_INT, 4);
255             }
256             int hour() {
257                 return base.get(C_INT, 8);
258             }
259             int mday() {
260                 return base.get(C_INT, 12);
261             }
262             int mon() {
263                 return base.get(C_INT, 16);
264             }
265             int year() {
266                 return base.get(C_INT, 20);
267             }
268             int wday() {
269                 return base.get(C_INT, 24);
270             }
271             int yday() {
272                 return base.get(C_INT, 28);
273             }
274             boolean isdst() {
275                 return base.get(C_BOOL, 32);
276             }
277         }
278 
279         int[] qsort(int[] arr) throws Throwable {
280             //init native array
281             try (var arena = Arena.ofConfined()) {
282                 MemorySegment nativeArr = arena.allocateFrom(C_INT, arr);
283 
284                 //call qsort
285                 MemorySegment qsortUpcallStub = abi.upcallStub(qsortCompar, qsortComparFunction, arena);
286 
287                 // both of these fit in an int
288                 // automatically widen them to long on x64
289                 int count = arr.length;
290                 int size = (int) C_INT.byteSize();
291                 qsort.invoke(nativeArr, count, size, qsortUpcallStub);
292 
293                 //convert back to Java array
294                 return nativeArr.toArray(C_INT);
295             }
296         }
297 
298         static int qsortCompare(MemorySegment addr1, MemorySegment addr2) {
299             return addr1.get(C_INT, 0) -
300                    addr2.get(C_INT, 0);
301         }
302 
303         int rand() throws Throwable {
304             return (int)rand.invokeExact();
305         }
306 
307         int printf(String format, List<PrintfArg> args) throws Throwable {
308             try (var arena = Arena.ofConfined()) {
309                 MemorySegment formatStr = arena.allocateFrom(format);
310                 return (int)specializedPrintf(args).invokeExact(formatStr,
311                         args.stream().map(a -> a.nativeValue(arena)).toArray());
312             }
313         }
314 
315         private MethodHandle specializedPrintf(List<PrintfArg> args) {
316             //method type
317             MethodType mt = MethodType.methodType(int.class, MemorySegment.class);
318             FunctionDescriptor fd = printfBase;
319             List<MemoryLayout> variadicLayouts = new ArrayList<>(args.size());
320             for (PrintfArg arg : args) {
321                 mt = mt.appendParameterTypes(arg.carrier);
322                 variadicLayouts.add(arg.layout);
323             }
324             Linker.Option varargIndex = Linker.Option.firstVariadicArg(fd.argumentLayouts().size());
325             MethodHandle mh = abi.downcallHandle(printfAddr,
326                     fd.appendArgumentLayouts(variadicLayouts.toArray(new MemoryLayout[args.size()])),
327                     varargIndex);
328             return mh.asSpreader(1, Object[].class, args.size());
329         }
330     }
331 
332     /*** data providers ***/
333 
334     @DataProvider
335     public static Object[][] ints() {
336         return perms(0, new Integer[] { 0, 1, 2, 3, 4 }).stream()
337                 .map(l -> new Object[] { l })
338                 .toArray(Object[][]::new);
339     }
340 
341     @DataProvider
342     public static Object[][] strings() {
343         return perms(0, new String[] { "a", "b", "c" }).stream()
344                 .map(l -> new Object[] { String.join("", l) })
345                 .toArray(Object[][]::new);
346     }
347 
348     @DataProvider
349     public static Object[][] stringPairs() {
350         Object[][] strings = strings();
351         Object[][] stringPairs = new Object[strings.length * strings.length][];
352         int pos = 0;
353         for (Object[] s1 : strings) {
354             for (Object[] s2 : strings) {
355                 stringPairs[pos++] = new Object[] { s1[0], s2[0] };
356             }
357         }
358         return stringPairs;
359     }
360 
361     @DataProvider
362     public static Object[][] instants() {
363         Instant start = ZonedDateTime.of(LocalDateTime.parse("2017-01-01T00:00:00"), ZoneOffset.UTC).toInstant();
364         Instant end = ZonedDateTime.of(LocalDateTime.parse("2017-12-31T00:00:00"), ZoneOffset.UTC).toInstant();
365         Object[][] instants = new Object[100][];
366         for (int i = 0 ; i < instants.length ; i++) {
367             Instant instant = start.plusSeconds((long)(Math.random() * (end.getEpochSecond() - start.getEpochSecond())));
368             instants[i] = new Object[] { instant };
369         }
370         return instants;
371     }
372 
373     @DataProvider
374     public static Object[][] printfArgs() {
375         ArrayList<List<PrintfArg>> res = new ArrayList<>();
376         List<List<PrintfArg>> perms = new ArrayList<>(perms(0, PrintfArg.values()));
377         for (int i = 0 ; i < 100 ; i++) {
378             Collections.shuffle(perms);
379             res.addAll(perms);
380         }
381         return res.stream()
382                 .map(l -> new Object[] { l })
383                 .toArray(Object[][]::new);
384     }
385 
386     enum PrintfArg {
387         INT(int.class, C_INT, "%d", "%d", arena -> 42, 42),
388         LONG(long.class, C_LONG_LONG, "%lld", "%d", arena -> 84L, 84L),
389         DOUBLE(double.class, C_DOUBLE, "%.4f", "%.4f", arena -> 1.2345d, 1.2345d),
390         STRING(MemorySegment.class, C_POINTER, "%s", "%s", arena -> arena.allocateFrom("str"), "str");
391 
392         final Class<?> carrier;
393         final ValueLayout layout;
394         final String nativeFormat;
395         final String javaFormat;
396         final Function<Arena, ?> nativeValueFactory;
397         final Object javaValue;
398 
399         <Z, L extends ValueLayout> PrintfArg(Class<?> carrier, L layout, String nativeFormat, String javaFormat,
400                                              Function<Arena, Z> nativeValueFactory, Object javaValue) {
401             this.carrier = carrier;
402             this.layout = layout;
403             this.nativeFormat = nativeFormat;
404             this.javaFormat = javaFormat;
405             this.nativeValueFactory = nativeValueFactory;
406             this.javaValue = javaValue;
407         }
408 
409         public Object nativeValue(Arena arena) {
410             return nativeValueFactory.apply(arena);
411         }
412     }
413 
414     static <Z> Set<List<Z>> perms(int count, Z[] arr) {
415         if (count == arr.length) {
416             return Set.of(List.of());
417         } else {
418             return Arrays.stream(arr)
419                     .flatMap(num -> {
420                         Set<List<Z>> perms = perms(count + 1, arr);
421                         return Stream.concat(
422                                 //take n
423                                 perms.stream().map(l -> {
424                                     List<Z> li = new ArrayList<>(l);
425                                     li.add(num);
426                                     return li;
427                                 }),
428                                 //drop n
429                                 perms.stream());
430                     }).collect(Collectors.toCollection(LinkedHashSet::new));
431         }
432     }
433 }