1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
import java.lang.foreign.AddressLayout;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.GroupLayout;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.PaddingLayout;
import java.lang.foreign.SegmentAllocator;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.StructLayout;
import java.lang.foreign.SymbolLookup;
import java.lang.foreign.UnionLayout;
import java.lang.foreign.ValueLayout;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.VarHandle;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.UnaryOperator;
import java.util.random.RandomGenerator;

import static java.lang.foreign.MemoryLayout.PathElement.groupElement;
import static java.lang.foreign.MemoryLayout.PathElement.sequenceElement;
 54 
 55 public class NativeTestHelper {
 56 
 57     public static final boolean IS_WINDOWS = System.getProperty("os.name").startsWith("Windows");
 58 
 59     private static final MethodHandle MH_SAVER;
 60     private static final RandomGenerator DEFAULT_RANDOM;
 61 
 62     static {
 63         int seed = Integer.getInteger("NativeTestHelper.DEFAULT_RANDOM.seed", ThreadLocalRandom.current().nextInt());
 64         System.out.println("NativeTestHelper::DEFAULT_RANDOM.seed = " + seed);
 65         System.out.println("Re-run with '-DNativeTestHelper.DEFAULT_RANDOM.seed=" + seed + "' to reproduce");
 66         DEFAULT_RANDOM = new Random(seed);
 67 
 68         try {
 69             MH_SAVER = MethodHandles.lookup().findStatic(NativeTestHelper.class, "saver",
 70                     MethodType.methodType(Object.class, Object[].class, List.class, AtomicReference.class, SegmentAllocator.class, int.class));
 71         } catch (ReflectiveOperationException e) {
 72             throw new ExceptionInInitializerError(e);
 73         }
 74     }
 75 
 76     public static boolean isIntegral(MemoryLayout layout) {
 77         return layout instanceof ValueLayout valueLayout && isIntegral(valueLayout.carrier());
 78     }
 79 
 80     static boolean isIntegral(Class<?> clazz) {
 81         return clazz == byte.class || clazz == char.class || clazz == short.class
 82                 || clazz == int.class || clazz == long.class;
 83     }
 84 
 85     public static boolean isPointer(MemoryLayout layout) {
 86         return layout instanceof ValueLayout valueLayout && valueLayout.carrier() == MemorySegment.class;
 87     }
 88 
 89     // the constants below are useful aliases for C types. The type/carrier association is only valid for 64-bit platforms.
 90 
 91     /**
 92      * The layout for the {@code bool} C type
 93      */
 94     public static final ValueLayout.OfBoolean C_BOOL = ValueLayout.JAVA_BOOLEAN;
 95     /**
 96      * The layout for the {@code char} C type
 97      */
 98     public static final ValueLayout.OfByte C_CHAR = ValueLayout.JAVA_BYTE;
 99     /**
100      * The layout for the {@code short} C type
101      */
102     public static final ValueLayout.OfShort C_SHORT = ValueLayout.JAVA_SHORT;
103     /**
104      * The layout for the {@code int} C type
105      */
106     public static final ValueLayout.OfInt C_INT = ValueLayout.JAVA_INT;
107 
108     /**
109      * The layout for the {@code long long} C type.
110      */
111     public static final ValueLayout.OfLong C_LONG_LONG = ValueLayout.JAVA_LONG;
112     /**
113      * The layout for the {@code float} C type
114      */
115     public static final ValueLayout.OfFloat C_FLOAT = ValueLayout.JAVA_FLOAT;
116     /**
117      * The layout for the {@code double} C type
118      */
119     public static final ValueLayout.OfDouble C_DOUBLE = ValueLayout.JAVA_DOUBLE;
120     /**
121      * The {@code T*} native type.
122      */
123     public static final AddressLayout C_POINTER = ValueLayout.ADDRESS
124             .withTargetLayout(MemoryLayout.sequenceLayout(C_CHAR));
125 
126     public static final Linker LINKER = Linker.nativeLinker();
127 
128     private static final MethodHandle FREE = LINKER.downcallHandle(
129             LINKER.defaultLookup().find("free").get(), FunctionDescriptor.ofVoid(C_POINTER));
130 
131     private static final MethodHandle MALLOC = LINKER.downcallHandle(
132             LINKER.defaultLookup().find("malloc").get(), FunctionDescriptor.of(C_POINTER, C_LONG_LONG));
133 
134     public static void freeMemory(MemorySegment address) {
135         try {
136             FREE.invokeExact(address);
137         } catch (Throwable ex) {
138             throw new IllegalStateException(ex);
139         }
140     }
141 
142     public static MemorySegment allocateMemory(long size) {
143         try {
144             return (MemorySegment) MALLOC.invokeExact(size);
145         } catch (Throwable ex) {
146             throw new IllegalStateException(ex);
147         }
148     }
149 
150     public static MemorySegment findNativeOrThrow(String name) {
151         return SymbolLookup.loaderLookup().find(name).orElseThrow();
152     }
153 
154     public static MethodHandle downcallHandle(String symbol, FunctionDescriptor desc, Linker.Option... options) {
155         return LINKER.downcallHandle(findNativeOrThrow(symbol), desc, options);
156     }
157 
158     public static MemorySegment upcallStub(Class<?> holder, String name, FunctionDescriptor descriptor) {
159         try {
160             MethodHandle target = MethodHandles.lookup().findStatic(holder, name, descriptor.toMethodType());
161             return LINKER.upcallStub(target, descriptor, Arena.ofAuto());
162         } catch (ReflectiveOperationException e) {
163             throw new RuntimeException(e);
164         }
165     }
166 
167     public static TestValue[] genTestArgs(FunctionDescriptor descriptor, SegmentAllocator allocator) {
168         return genTestArgs(DEFAULT_RANDOM, descriptor, allocator);
169     }
170 
171     public static TestValue[] genTestArgs(RandomGenerator random, FunctionDescriptor descriptor, SegmentAllocator allocator) {
172         TestValue[] result = new TestValue[descriptor.argumentLayouts().size()];
173         for (int i = 0; i < result.length; i++) {
174             result[i] = genTestValue(random, descriptor.argumentLayouts().get(i), allocator);
175         }
176         return result;
177     }
178 
179     public record TestValue (Object value, Consumer<Object> check) {}
180 
181     public static TestValue genTestValue(MemoryLayout layout, SegmentAllocator allocator) {
182         return genTestValue(DEFAULT_RANDOM, layout, allocator);
183     }
184 
185     public static TestValue genTestValue(RandomGenerator random, MemoryLayout layout, SegmentAllocator allocator) {
186         if (layout instanceof StructLayout struct) {
187             MemorySegment segment = allocator.allocate(struct);
188             List<Consumer<Object>> fieldChecks = new ArrayList<>();
189             for (MemoryLayout fieldLayout : struct.memberLayouts()) {
190                 if (fieldLayout instanceof PaddingLayout) continue;
191                 MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow());
192                 fieldChecks.add(initField(random, segment, struct, fieldLayout, fieldPath, allocator));
193             }
194             return new TestValue(segment, actual -> fieldChecks.forEach(check -> check.accept(actual)));
195         } else if (layout instanceof UnionLayout union) {
196             MemorySegment segment = allocator.allocate(union);
197             List<MemoryLayout> filteredFields = union.memberLayouts().stream()
198                                                                      .filter(l -> !(l instanceof PaddingLayout))
199                                                                      .toList();
200             int fieldIdx = random.nextInt(filteredFields.size());
201             MemoryLayout fieldLayout = filteredFields.get(fieldIdx);
202             MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow());
203             Consumer<Object> check = initField(random, segment, union, fieldLayout, fieldPath, allocator);
204             return new TestValue(segment, check);
205         } else if (layout instanceof SequenceLayout array) {
206             MemorySegment segment = allocator.allocate(array);
207             List<Consumer<Object>> elementChecks = new ArrayList<>();
208             for (int i = 0; i < array.elementCount(); i++) {
209                 elementChecks.add(initField(random, segment, array, array.elementLayout(), sequenceElement(i), allocator));
210             }
211             return new TestValue(segment, actual -> elementChecks.forEach(check -> check.accept(actual)));
212         } else if (layout instanceof AddressLayout) {
213             MemorySegment value = MemorySegment.ofAddress(random.nextLong());
214             return new TestValue(value, actual -> assertEquals(actual, value));
215         }else if (layout instanceof ValueLayout.OfByte) {
216             byte value = (byte) random.nextInt();
217             return new TestValue(value, actual -> assertEquals(actual, value));
218         } else if (layout instanceof ValueLayout.OfShort) {
219             short value = (short) random.nextInt();
220             return new TestValue(value, actual -> assertEquals(actual, value));
221         } else if (layout instanceof ValueLayout.OfInt) {
222             int value = random.nextInt();
223             return new TestValue(value, actual -> assertEquals(actual, value));
224         } else if (layout instanceof ValueLayout.OfLong) {
225             long value = random.nextLong();
226             return new TestValue(value, actual -> assertEquals(actual, value));
227         } else if (layout instanceof ValueLayout.OfFloat) {
228             float value = random.nextFloat();
229             return new TestValue(value, actual -> assertEquals(actual, value));
230         } else if (layout instanceof ValueLayout.OfDouble) {
231             double value = random.nextDouble();
232             return new TestValue(value, actual -> assertEquals(actual, value));
233         }
234 
235         throw new IllegalStateException("Unexpected layout: " + layout);
236     }
237 
238     private static Consumer<Object> initField(RandomGenerator random, MemorySegment container, MemoryLayout containerLayout,
239                                               MemoryLayout fieldLayout, MemoryLayout.PathElement fieldPath,
240                                               SegmentAllocator allocator) {
241         TestValue fieldValue = genTestValue(random, fieldLayout, allocator);
242         Consumer<Object> fieldCheck = fieldValue.check();
243         if (fieldLayout instanceof GroupLayout || fieldLayout instanceof SequenceLayout) {
244             UnaryOperator<MemorySegment> slicer = slicer(containerLayout, fieldPath);
245             MemorySegment slice = slicer.apply(container);
246             slice.copyFrom((MemorySegment) fieldValue.value());
247             return actual -> fieldCheck.accept(slicer.apply((MemorySegment) actual));
248         } else {
249             VarHandle accessor = containerLayout.varHandle(fieldPath);
250             //set value
251             accessor.set(container, fieldValue.value());
252             return actual -> fieldCheck.accept(accessor.get((MemorySegment) actual));
253         }
254     }
255 
256     private static UnaryOperator<MemorySegment> slicer(MemoryLayout containerLayout, MemoryLayout.PathElement fieldPath) {
257         MethodHandle slicer = containerLayout.sliceHandle(fieldPath);
258         return container -> {
259               try {
260                 return (MemorySegment) slicer.invokeExact(container);
261             } catch (Throwable e) {
262                 throw new IllegalStateException(e);
263             }
264         };
265     }
266 
267     private static void assertEquals(Object actual, Object expected) {
268         if (actual.getClass() != expected.getClass()) {
269             throw new AssertionError("Type mismatch: " + actual.getClass() + " != " + expected.getClass());
270         }
271         if (!actual.equals(expected)) {
272             throw new AssertionError("Not equal: " + actual + " != " + expected);
273         }
274     }
275 
276     /**
277      * Make an upcall stub that saves its arguments into the given 'ref' array
278      *
279      * @param fd function descriptor for the upcall stub
280      * @param capturedArgs box to save arguments in
281      * @param arena allocator for making copies of by-value structs
282      * @param retIdx the index of the argument to return
283      * @return return the upcall stub
284      */
285     public static MemorySegment makeArgSaverCB(FunctionDescriptor fd, Arena arena,
286                                                AtomicReference<Object[]> capturedArgs, int retIdx) {
287         MethodHandle target = MethodHandles.insertArguments(MH_SAVER, 1, fd.argumentLayouts(), capturedArgs, arena, retIdx);
288         target = target.asCollector(Object[].class, fd.argumentLayouts().size());
289         target = target.asType(fd.toMethodType());
290         return LINKER.upcallStub(target, fd, arena);
291     }
292 
293     private static Object saver(Object[] o, List<MemoryLayout> argLayouts, AtomicReference<Object[]> ref, SegmentAllocator allocator, int retArg) {
294         for (int i = 0; i < o.length; i++) {
295             if (argLayouts.get(i) instanceof GroupLayout gl) {
296                 MemorySegment ms = (MemorySegment) o[i];
297                 MemorySegment copy = allocator.allocate(gl);
298                 copy.copyFrom(ms);
299                 o[i] = copy;
300             }
301         }
302         ref.set(o);
303         return retArg != -1 ? o[retArg] : null;
304     }
305 }