1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @enablePreview
 27  * @requires jdk.foreign.linker != "UNSUPPORTED"
 28  * @run testng/othervm --enable-native-access=ALL-UNNAMED StdLibTest
 29  */
 30 
 31 import java.lang.invoke.MethodHandle;
 32 import java.lang.invoke.MethodHandles;
 33 import java.lang.invoke.MethodType;
 34 import java.time.Instant;
 35 import java.time.LocalDateTime;
 36 import java.time.ZoneOffset;
 37 import java.time.ZonedDateTime;
 38 import java.util.ArrayList;
 39 import java.util.Arrays;
 40 import java.util.Collections;
 41 import java.util.LinkedHashSet;
 42 import java.util.List;
 43 import java.util.Set;
 44 import java.util.function.Function;
 45 import java.util.stream.Collectors;
 46 import java.util.stream.Stream;
 47 
 48 import java.lang.foreign.*;
 49 
 50 import org.testng.annotations.*;
 51 
 52 import static org.testng.Assert.*;
 53 
 54 @Test
 55 public class StdLibTest extends NativeTestHelper {
 56 
 57     final static Linker abi = Linker.nativeLinker();
 58 
 59     private StdLibHelper stdLibHelper = new StdLibHelper();
 60 
 61     @Test(dataProvider = "stringPairs")
 62     void test_strcat(String s1, String s2) throws Throwable {
 63         assertEquals(stdLibHelper.strcat(s1, s2), s1 + s2);
 64     }
 65 
 66     @Test(dataProvider = "stringPairs")
 67     void test_strcmp(String s1, String s2) throws Throwable {
 68         assertEquals(Math.signum(stdLibHelper.strcmp(s1, s2)), Math.signum(s1.compareTo(s2)));
 69     }
 70 
 71     @Test(dataProvider = "strings")
 72     void test_puts(String s) throws Throwable {
 73         assertTrue(stdLibHelper.puts(s) >= 0);
 74     }
 75 
 76     @Test(dataProvider = "strings")
 77     void test_strlen(String s) throws Throwable {
 78         assertEquals(stdLibHelper.strlen(s), s.length());
 79     }
 80 
 81     @Test(dataProvider = "instants")
 82     void test_time(Instant instant) throws Throwable {
 83         StdLibHelper.Tm tm = stdLibHelper.gmtime(instant.getEpochSecond());
 84         LocalDateTime localTime = LocalDateTime.ofInstant(instant, ZoneOffset.UTC);
 85         assertEquals(tm.sec(), localTime.getSecond());
 86         assertEquals(tm.min(), localTime.getMinute());
 87         assertEquals(tm.hour(), localTime.getHour());
 88         //day pf year in Java has 1-offset
 89         assertEquals(tm.yday(), localTime.getDayOfYear() - 1);
 90         assertEquals(tm.mday(), localTime.getDayOfMonth());
 91         //days of week starts from Sunday in C, but on Monday in Java, also account for 1-offset
 92         assertEquals((tm.wday() + 6) % 7, localTime.getDayOfWeek().getValue() - 1);
 93         //month in Java has 1-offset
 94         assertEquals(tm.mon(), localTime.getMonth().getValue() - 1);
 95         assertEquals(tm.isdst(), ZoneOffset.UTC.getRules()
 96                 .isDaylightSavings(Instant.ofEpochMilli(instant.getEpochSecond() * 1000)));
 97     }
 98 
 99     @Test(dataProvider = "ints")
100     void test_qsort(List<Integer> ints) throws Throwable {
101         if (ints.size() > 0) {
102             int[] input = ints.stream().mapToInt(i -> i).toArray();
103             int[] sorted = stdLibHelper.qsort(input);
104             Arrays.sort(input);
105             assertEquals(sorted, input);
106         }
107     }
108 
109     @Test
110     void test_rand() throws Throwable {
111         int val = stdLibHelper.rand();
112         for (int i = 0 ; i < 100 ; i++) {
113             int newVal = stdLibHelper.rand();
114             if (newVal != val) {
115                 return; //ok
116             }
117             val = newVal;
118         }
119         fail("All values are the same! " + val);
120     }
121 
122     @Test(dataProvider = "printfArgs")
123     void test_printf(List<PrintfArg> args) throws Throwable {
124         String formatArgs = args.stream()
125                 .map(a -> a.format)
126                 .collect(Collectors.joining(","));
127 
128         String formatString = "hello(" + formatArgs + ")\n";
129 
130         String expected = String.format(formatString, args.stream()
131                 .map(a -> a.javaValue).toArray());
132 
133         int found = stdLibHelper.printf(formatString, args);
134         assertEquals(found, expected.length());
135     }
136 
137     @Test
138     void testSystemLibraryBadLookupName() {
139         assertTrue(LINKER.defaultLookup().find("strlen\u0000foobar").isEmpty());
140     }
141 
142     static class StdLibHelper {
143 
144         final static MethodHandle strcat = abi.downcallHandle(abi.defaultLookup().find("strcat").get(),
145                 FunctionDescriptor.of(C_POINTER, C_POINTER, C_POINTER));
146 
147         final static MethodHandle strcmp = abi.downcallHandle(abi.defaultLookup().find("strcmp").get(),
148                 FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER));
149 
150         final static MethodHandle puts = abi.downcallHandle(abi.defaultLookup().find("puts").get(),
151                 FunctionDescriptor.of(C_INT, C_POINTER));
152 
153         final static MethodHandle strlen = abi.downcallHandle(abi.defaultLookup().find("strlen").get(),
154                 FunctionDescriptor.of(C_INT, C_POINTER));
155 
156         final static MethodHandle gmtime = abi.downcallHandle(abi.defaultLookup().find("gmtime").get(),
157                 FunctionDescriptor.of(C_POINTER.withTargetLayout(Tm.LAYOUT), C_POINTER));
158 
159         final static MethodHandle qsort = abi.downcallHandle(abi.defaultLookup().find("qsort").get(),
160                 FunctionDescriptor.ofVoid(C_POINTER, C_LONG_LONG, C_LONG_LONG, C_POINTER));
161 
162         final static FunctionDescriptor qsortComparFunction = FunctionDescriptor.of(C_INT,
163                 C_POINTER.withTargetLayout(C_INT), C_POINTER.withTargetLayout(C_INT));
164 
165         final static MethodHandle qsortCompar;
166 
167         final static MethodHandle rand = abi.downcallHandle(abi.defaultLookup().find("rand").get(),
168                 FunctionDescriptor.of(C_INT));
169 
170         final static MethodHandle vprintf = abi.downcallHandle(abi.defaultLookup().find("vprintf").get(),
171                 FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER));
172 
173         final static MemorySegment printfAddr = abi.defaultLookup().find("printf").get();
174 
175         final static FunctionDescriptor printfBase = FunctionDescriptor.of(C_INT, C_POINTER);
176 
177         static {
178             try {
179                 //qsort upcall handle
180                 qsortCompar = MethodHandles.lookup().findStatic(StdLibTest.StdLibHelper.class, "qsortCompare",
181                         qsortComparFunction.toMethodType());
182             } catch (ReflectiveOperationException ex) {
183                 throw new IllegalStateException(ex);
184             }
185         }
186 
187         String strcat(String s1, String s2) throws Throwable {
188             try (var arena = Arena.ofConfined()) {
189                 MemorySegment buf = arena.allocate(s1.length() + s2.length() + 1);
190                 buf.setUtf8String(0, s1);
191                 MemorySegment other = arena.allocateUtf8String(s2);
192                 return ((MemorySegment)strcat.invokeExact(buf, other)).getUtf8String(0);
193             }
194         }
195 
196         int strcmp(String s1, String s2) throws Throwable {
197             try (var arena = Arena.ofConfined()) {
198                 MemorySegment ns1 = arena.allocateUtf8String(s1);
199                 MemorySegment ns2 = arena.allocateUtf8String(s2);
200                 return (int)strcmp.invokeExact(ns1, ns2);
201             }
202         }
203 
204         int puts(String msg) throws Throwable {
205             try (var arena = Arena.ofConfined()) {
206                 MemorySegment s = arena.allocateUtf8String(msg);
207                 return (int)puts.invokeExact(s);
208             }
209         }
210 
211         int strlen(String msg) throws Throwable {
212             try (var arena = Arena.ofConfined()) {
213                 MemorySegment s = arena.allocateUtf8String(msg);
214                 return (int)strlen.invokeExact(s);
215             }
216         }
217 
218         Tm gmtime(long arg) throws Throwable {
219             try (var arena = Arena.ofConfined()) {
220                 MemorySegment time = arena.allocate(8);
221                 time.set(C_LONG_LONG, 0, arg);
222                 return new Tm((MemorySegment)gmtime.invokeExact(time));
223             }
224         }
225 
226         static class Tm {
227 
228             //Tm pointer should never be freed directly, as it points to shared memory
229             private final MemorySegment base;
230 
231             static final MemoryLayout LAYOUT = MemoryLayout.structLayout(
232                     C_INT.withName("sec"),
233                     C_INT.withName("min"),
234                     C_INT.withName("hour"),
235                     C_INT.withName("mday"),
236                     C_INT.withName("mon"),
237                     C_INT.withName("year"),
238                     C_INT.withName("wday"),
239                     C_INT.withName("yday"),
240                     C_BOOL.withName("isdst"),
241                     MemoryLayout.paddingLayout(3)
242             );
243 
244             Tm(MemorySegment addr) {
245                 this.base = addr;
246             }
247 
248             int sec() {
249                 return base.get(C_INT, 0);
250             }
251             int min() {
252                 return base.get(C_INT, 4);
253             }
254             int hour() {
255                 return base.get(C_INT, 8);
256             }
257             int mday() {
258                 return base.get(C_INT, 12);
259             }
260             int mon() {
261                 return base.get(C_INT, 16);
262             }
263             int year() {
264                 return base.get(C_INT, 20);
265             }
266             int wday() {
267                 return base.get(C_INT, 24);
268             }
269             int yday() {
270                 return base.get(C_INT, 28);
271             }
272             boolean isdst() {
273                 return base.get(C_BOOL, 32);
274             }
275         }
276 
277         int[] qsort(int[] arr) throws Throwable {
278             //init native array
279             try (var arena = Arena.ofConfined()) {
280                 MemorySegment nativeArr = arena.allocateArray(C_INT, arr);
281 
282                 //call qsort
283                 MemorySegment qsortUpcallStub = abi.upcallStub(qsortCompar, qsortComparFunction, arena);
284 
285                 qsort.invokeExact(nativeArr, (long)arr.length, C_INT.byteSize(), qsortUpcallStub);
286 
287                 //convert back to Java array
288                 return nativeArr.toArray(C_INT);
289             }
290         }
291 
292         static int qsortCompare(MemorySegment addr1, MemorySegment addr2) {
293             return addr1.get(C_INT, 0) -
294                    addr2.get(C_INT, 0);
295         }
296 
297         int rand() throws Throwable {
298             return (int)rand.invokeExact();
299         }
300 
301         int printf(String format, List<PrintfArg> args) throws Throwable {
302             try (var arena = Arena.ofConfined()) {
303                 MemorySegment formatStr = arena.allocateUtf8String(format);
304                 return (int)specializedPrintf(args).invokeExact(formatStr,
305                         args.stream().map(a -> a.nativeValue(arena)).toArray());
306             }
307         }
308 
309         private MethodHandle specializedPrintf(List<PrintfArg> args) {
310             //method type
311             MethodType mt = MethodType.methodType(int.class, MemorySegment.class);
312             FunctionDescriptor fd = printfBase;
313             List<MemoryLayout> variadicLayouts = new ArrayList<>(args.size());
314             for (PrintfArg arg : args) {
315                 mt = mt.appendParameterTypes(arg.carrier);
316                 variadicLayouts.add(arg.layout);
317             }
318             Linker.Option varargIndex = Linker.Option.firstVariadicArg(fd.argumentLayouts().size());
319             MethodHandle mh = abi.downcallHandle(printfAddr,
320                     fd.appendArgumentLayouts(variadicLayouts.toArray(new MemoryLayout[args.size()])),
321                     varargIndex);
322             return mh.asSpreader(1, Object[].class, args.size());
323         }
324     }
325 
326     /*** data providers ***/
327 
328     @DataProvider
329     public static Object[][] ints() {
330         return perms(0, new Integer[] { 0, 1, 2, 3, 4 }).stream()
331                 .map(l -> new Object[] { l })
332                 .toArray(Object[][]::new);
333     }
334 
335     @DataProvider
336     public static Object[][] strings() {
337         return perms(0, new String[] { "a", "b", "c" }).stream()
338                 .map(l -> new Object[] { String.join("", l) })
339                 .toArray(Object[][]::new);
340     }
341 
342     @DataProvider
343     public static Object[][] stringPairs() {
344         Object[][] strings = strings();
345         Object[][] stringPairs = new Object[strings.length * strings.length][];
346         int pos = 0;
347         for (Object[] s1 : strings) {
348             for (Object[] s2 : strings) {
349                 stringPairs[pos++] = new Object[] { s1[0], s2[0] };
350             }
351         }
352         return stringPairs;
353     }
354 
355     @DataProvider
356     public static Object[][] instants() {
357         Instant start = ZonedDateTime.of(LocalDateTime.parse("2017-01-01T00:00:00"), ZoneOffset.UTC).toInstant();
358         Instant end = ZonedDateTime.of(LocalDateTime.parse("2017-12-31T00:00:00"), ZoneOffset.UTC).toInstant();
359         Object[][] instants = new Object[100][];
360         for (int i = 0 ; i < instants.length ; i++) {
361             Instant instant = start.plusSeconds((long)(Math.random() * (end.getEpochSecond() - start.getEpochSecond())));
362             instants[i] = new Object[] { instant };
363         }
364         return instants;
365     }
366 
367     @DataProvider
368     public static Object[][] printfArgs() {
369         ArrayList<List<PrintfArg>> res = new ArrayList<>();
370         List<List<PrintfArg>> perms = new ArrayList<>(perms(0, PrintfArg.values()));
371         for (int i = 0 ; i < 100 ; i++) {
372             Collections.shuffle(perms);
373             res.addAll(perms);
374         }
375         return res.stream()
376                 .map(l -> new Object[] { l })
377                 .toArray(Object[][]::new);
378     }
379 
380     enum PrintfArg {
381         INT(int.class, C_INT, "%d", arena -> 42, 42),
382         LONG(long.class, C_LONG_LONG, "%d", arena -> 84L, 84L),
383         DOUBLE(double.class, C_DOUBLE, "%.4f", arena -> 1.2345d, 1.2345d),
384         STRING(MemorySegment.class, C_POINTER, "%s", arena -> arena.allocateUtf8String("str"), "str");
385 
386         final Class<?> carrier;
387         final ValueLayout layout;
388         final String format;
389         final Function<Arena, ?> nativeValueFactory;
390         final Object javaValue;
391 
392         <Z, L extends ValueLayout> PrintfArg(Class<?> carrier, L layout, String format, Function<Arena, Z> nativeValueFactory, Object javaValue) {
393             this.carrier = carrier;
394             this.layout = layout;
395             this.format = format;
396             this.nativeValueFactory = nativeValueFactory;
397             this.javaValue = javaValue;
398         }
399 
400         public Object nativeValue(Arena arena) {
401             return nativeValueFactory.apply(arena);
402         }
403     }
404 
405     static <Z> Set<List<Z>> perms(int count, Z[] arr) {
406         if (count == arr.length) {
407             return Set.of(List.of());
408         } else {
409             return Arrays.stream(arr)
410                     .flatMap(num -> {
411                         Set<List<Z>> perms = perms(count + 1, arr);
412                         return Stream.concat(
413                                 //take n
414                                 perms.stream().map(l -> {
415                                     List<Z> li = new ArrayList<>(l);
416                                     li.add(num);
417                                     return li;
418                                 }),
419                                 //drop n
420                                 perms.stream());
421                     }).collect(Collectors.toCollection(LinkedHashSet::new));
422         }
423     }
424 }