1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import java.lang.foreign.AddressLayout;
 25 import java.lang.foreign.Arena;
 26 import java.lang.foreign.FunctionDescriptor;
 27 import java.lang.foreign.GroupLayout;
 28 import java.lang.foreign.Linker;
 29 import java.lang.foreign.MemoryLayout;
 30 import java.lang.foreign.MemorySegment;
 31 import java.lang.foreign.PaddingLayout;
 32 import java.lang.foreign.SegmentAllocator;
 33 import java.lang.foreign.SequenceLayout;
 34 import java.lang.foreign.StructLayout;
 35 import java.lang.foreign.SymbolLookup;
 36 import java.lang.foreign.UnionLayout;
 37 import java.lang.foreign.ValueLayout;
 38 
 39 import java.lang.invoke.MethodHandle;
 40 import java.lang.invoke.MethodHandles;
 41 import java.lang.invoke.MethodType;
 42 import java.lang.invoke.VarHandle;
 43 import java.util.ArrayList;
 44 import java.util.List;
 45 import java.util.Random;
 46 import java.util.concurrent.ThreadLocalRandom;
 47 import java.util.concurrent.atomic.AtomicReference;
 48 import java.util.function.Consumer;
 49 import java.util.function.UnaryOperator;
 50 import java.util.random.RandomGenerator;
 51 
 52 import static java.lang.foreign.MemoryLayout.PathElement.groupElement;
 53 import static java.lang.foreign.MemoryLayout.PathElement.sequenceElement;
 54 
 55 public class NativeTestHelper {
 56 
 57     public static final boolean IS_WINDOWS = System.getProperty("os.name").startsWith("Windows");
 58 
 59     private static final MethodHandle MH_SAVER;
 60     private static final RandomGenerator DEFAULT_RANDOM;
 61 
 62     static {
 63         int seed = Integer.getInteger("NativeTestHelper.DEFAULT_RANDOM.seed", ThreadLocalRandom.current().nextInt());
 64         System.out.println("NativeTestHelper::DEFAULT_RANDOM.seed = " + seed);
 65         System.out.println("Re-run with '-DNativeTestHelper.DEFAULT_RANDOM.seed=" + seed + "' to reproduce");
 66         DEFAULT_RANDOM = new Random(seed);
 67 
 68         try {
 69             MH_SAVER = MethodHandles.lookup().findStatic(NativeTestHelper.class, "saver",
 70                     MethodType.methodType(Object.class, Object[].class, List.class, AtomicReference.class, SegmentAllocator.class, int.class));
 71         } catch (ReflectiveOperationException e) {
 72             throw new ExceptionInInitializerError(e);
 73         }
 74     }
 75 
 76     public static boolean isIntegral(MemoryLayout layout) {
 77         return layout instanceof ValueLayout valueLayout && isIntegral(valueLayout.carrier());
 78     }
 79 
 80     static boolean isIntegral(Class<?> clazz) {
 81         return clazz == byte.class || clazz == char.class || clazz == short.class
 82                 || clazz == int.class || clazz == long.class;
 83     }
 84 
 85     public static boolean isPointer(MemoryLayout layout) {
 86         return layout instanceof ValueLayout valueLayout && valueLayout.carrier() == MemorySegment.class;
 87     }
 88 
 89     public static final Linker LINKER = Linker.nativeLinker();
 90 
 91     // the constants below are useful aliases for C types. The type/carrier association is only valid for 64-bit platforms.
 92 
 93     /**
 94      * The layout for the {@code bool} C type
 95      */
 96     public static final ValueLayout.OfBoolean C_BOOL = (ValueLayout.OfBoolean) LINKER.canonicalLayouts().get("bool");
 97     /**
 98      * The layout for the {@code char} C type
 99      */
100     public static final ValueLayout.OfByte C_CHAR = (ValueLayout.OfByte) LINKER.canonicalLayouts().get("char");
101     /**
102      * The layout for the {@code short} C type
103      */
104     public static final ValueLayout.OfShort C_SHORT = (ValueLayout.OfShort) LINKER.canonicalLayouts().get("short");
105     /**
106      * The layout for the {@code int} C type
107      */
108     public static final ValueLayout.OfInt C_INT = (ValueLayout.OfInt) LINKER.canonicalLayouts().get("int");
109 
110     /**
111      * The layout for the {@code long long} C type.
112      */
113     public static final ValueLayout.OfLong C_LONG_LONG = (ValueLayout.OfLong) LINKER.canonicalLayouts().get("long long");
114     /**
115      * The layout for the {@code float} C type
116      */
117     public static final ValueLayout.OfFloat C_FLOAT = (ValueLayout.OfFloat) LINKER.canonicalLayouts().get("float");
118     /**
119      * The layout for the {@code double} C type
120      */
121     public static final ValueLayout.OfDouble C_DOUBLE = (ValueLayout.OfDouble) LINKER.canonicalLayouts().get("double");
122     /**
123      * The {@code T*} native type.
124      */
125     public static final AddressLayout C_POINTER = ((AddressLayout) LINKER.canonicalLayouts().get("void*"))
126             .withTargetLayout(MemoryLayout.sequenceLayout(Long.MAX_VALUE, C_CHAR));
127     /**
128      * The layout for the {@code size_t} C type
129      */
130     public static final ValueLayout C_SIZE_T = (ValueLayout) LINKER.canonicalLayouts().get("size_t");
131 
132     // Common layout shared by some tests
133     // struct S_PDI { void* p0; double p1; int p2; };
134     public static final MemoryLayout S_PDI_LAYOUT = switch ((int) ValueLayout.ADDRESS.byteSize()) {
135         case 8 -> MemoryLayout.structLayout(
136             C_POINTER.withName("p0"),
137             C_DOUBLE.withName("p1"),
138             C_INT.withName("p2"),
139             MemoryLayout.paddingLayout(4));
140         case 4 -> MemoryLayout.structLayout(
141             C_POINTER.withName("p0"),
142             C_DOUBLE.withName("p1"),
143             C_INT.withName("p2"));
144         default -> throw new UnsupportedOperationException("Unsupported address size");
145     };
146 
147     private static final MethodHandle FREE = LINKER.downcallHandle(
148             LINKER.defaultLookup().find("free").get(), FunctionDescriptor.ofVoid(C_POINTER));
149 
150     private static final MethodHandle MALLOC = LINKER.downcallHandle(
151             LINKER.defaultLookup().find("malloc").get(), FunctionDescriptor.of(C_POINTER, C_LONG_LONG));
152 
153     public static void freeMemory(MemorySegment address) {
154         try {
155             FREE.invokeExact(address);
156         } catch (Throwable ex) {
157             throw new IllegalStateException(ex);
158         }
159     }
160 
161     public static MemorySegment allocateMemory(long size) {
162         try {
163             return (MemorySegment) MALLOC.invokeExact(size);
164         } catch (Throwable ex) {
165             throw new IllegalStateException(ex);
166         }
167     }
168 
169     public static MemorySegment findNativeOrThrow(String name) {
170         return SymbolLookup.loaderLookup().find(name).orElseThrow();
171     }
172 
173     public static MethodHandle downcallHandle(String symbol, FunctionDescriptor desc, Linker.Option... options) {
174         return LINKER.downcallHandle(findNativeOrThrow(symbol), desc, options);
175     }
176 
177     public static MemorySegment upcallStub(Class<?> holder, String name, FunctionDescriptor descriptor) {
178         try {
179             MethodHandle target = MethodHandles.lookup().findStatic(holder, name, descriptor.toMethodType());
180             return LINKER.upcallStub(target, descriptor, Arena.ofAuto());
181         } catch (ReflectiveOperationException e) {
182             throw new RuntimeException(e);
183         }
184     }
185 
186     public static TestValue[] genTestArgs(FunctionDescriptor descriptor, SegmentAllocator allocator) {
187         return genTestArgs(DEFAULT_RANDOM, descriptor, allocator);
188     }
189 
190     public static TestValue[] genTestArgs(RandomGenerator random, FunctionDescriptor descriptor, SegmentAllocator allocator) {
191         TestValue[] result = new TestValue[descriptor.argumentLayouts().size()];
192         for (int i = 0; i < result.length; i++) {
193             result[i] = genTestValue(random, descriptor.argumentLayouts().get(i), allocator);
194         }
195         return result;
196     }
197 
198     public record TestValue (Object value, Consumer<Object> check) {}
199 
200     public static TestValue genTestValue(MemoryLayout layout, SegmentAllocator allocator) {
201         return genTestValue(DEFAULT_RANDOM, layout, allocator);
202     }
203 
204     public static TestValue genTestValue(RandomGenerator random, MemoryLayout layout, SegmentAllocator allocator) {
205         if (layout instanceof StructLayout struct) {
206             MemorySegment segment = allocator.allocate(struct);
207             List<Consumer<Object>> fieldChecks = new ArrayList<>();
208             for (MemoryLayout fieldLayout : struct.memberLayouts()) {
209                 if (fieldLayout instanceof PaddingLayout) continue;
210                 MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow());
211                 fieldChecks.add(initField(random, segment, struct, fieldLayout, fieldPath, allocator));
212             }
213             return new TestValue(segment, actual -> fieldChecks.forEach(check -> check.accept(actual)));
214         } else if (layout instanceof UnionLayout union) {
215             MemorySegment segment = allocator.allocate(union);
216             List<MemoryLayout> filteredFields = union.memberLayouts().stream()
217                                                                      .filter(l -> !(l instanceof PaddingLayout))
218                                                                      .toList();
219             int fieldIdx = random.nextInt(filteredFields.size());
220             MemoryLayout fieldLayout = filteredFields.get(fieldIdx);
221             MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow());
222             Consumer<Object> check = initField(random, segment, union, fieldLayout, fieldPath, allocator);
223             return new TestValue(segment, check);
224         } else if (layout instanceof SequenceLayout array) {
225             MemorySegment segment = allocator.allocate(array);
226             List<Consumer<Object>> elementChecks = new ArrayList<>();
227             for (int i = 0; i < array.elementCount(); i++) {
228                 elementChecks.add(initField(random, segment, array, array.elementLayout(), sequenceElement(i), allocator));
229             }
230             return new TestValue(segment, actual -> elementChecks.forEach(check -> check.accept(actual)));
231         } else if (layout instanceof AddressLayout) {
232             MemorySegment value = MemorySegment.ofAddress(random.nextLong());
233             return new TestValue(value, actual -> assertEquals(actual, value));
234         }else if (layout instanceof ValueLayout.OfByte) {
235             byte value = (byte) random.nextInt();
236             return new TestValue(value, actual -> assertEquals(actual, value));
237         } else if (layout instanceof ValueLayout.OfShort) {
238             short value = (short) random.nextInt();
239             return new TestValue(value, actual -> assertEquals(actual, value));
240         } else if (layout instanceof ValueLayout.OfInt) {
241             int value = random.nextInt();
242             return new TestValue(value, actual -> assertEquals(actual, value));
243         } else if (layout instanceof ValueLayout.OfLong) {
244             long value = random.nextLong();
245             return new TestValue(value, actual -> assertEquals(actual, value));
246         } else if (layout instanceof ValueLayout.OfFloat) {
247             float value = random.nextFloat();
248             return new TestValue(value, actual -> assertEquals(actual, value));
249         } else if (layout instanceof ValueLayout.OfDouble) {
250             double value = random.nextDouble();
251             return new TestValue(value, actual -> assertEquals(actual, value));
252         }
253 
254         throw new IllegalStateException("Unexpected layout: " + layout);
255     }
256 
257     private static Consumer<Object> initField(RandomGenerator random, MemorySegment container, MemoryLayout containerLayout,
258                                               MemoryLayout fieldLayout, MemoryLayout.PathElement fieldPath,
259                                               SegmentAllocator allocator) {
260         TestValue fieldValue = genTestValue(random, fieldLayout, allocator);
261         Consumer<Object> fieldCheck = fieldValue.check();
262         if (fieldLayout instanceof GroupLayout || fieldLayout instanceof SequenceLayout) {
263             UnaryOperator<MemorySegment> slicer = slicer(containerLayout, fieldPath);
264             MemorySegment slice = slicer.apply(container);
265             slice.copyFrom((MemorySegment) fieldValue.value());
266             return actual -> fieldCheck.accept(slicer.apply((MemorySegment) actual));
267         } else {
268             VarHandle accessor = containerLayout.varHandle(fieldPath);
269             //set value
270             accessor.set(container, 0L, fieldValue.value());
271             return actual -> fieldCheck.accept(accessor.get((MemorySegment) actual, 0L));
272         }
273     }
274 
275     private static UnaryOperator<MemorySegment> slicer(MemoryLayout containerLayout, MemoryLayout.PathElement fieldPath) {
276         MethodHandle slicer = containerLayout.sliceHandle(fieldPath);
277         return container -> {
278               try {
279                 return (MemorySegment) slicer.invokeExact(container, 0L);
280             } catch (Throwable e) {
281                 throw new IllegalStateException(e);
282             }
283         };
284     }
285 
286     private static void assertEquals(Object actual, Object expected) {
287         if (actual.getClass() != expected.getClass()) {
288             throw new AssertionError("Type mismatch: " + actual.getClass() + " != " + expected.getClass());
289         }
290         if (!actual.equals(expected)) {
291             throw new AssertionError("Not equal: " + actual + " != " + expected);
292         }
293     }
294 
295     /**
296      * Make an upcall stub that saves its arguments into the given 'ref' array
297      *
298      * @param fd function descriptor for the upcall stub
299      * @param capturedArgs box to save arguments in
300      * @param arena allocator for making copies of by-value structs
301      * @param retIdx the index of the argument to return
302      * @return return the upcall stub
303      */
304     public static MemorySegment makeArgSaverCB(FunctionDescriptor fd, Arena arena,
305                                                AtomicReference<Object[]> capturedArgs, int retIdx) {
306         MethodHandle target = MethodHandles.insertArguments(MH_SAVER, 1, fd.argumentLayouts(), capturedArgs, arena, retIdx);
307         target = target.asCollector(Object[].class, fd.argumentLayouts().size());
308         target = target.asType(fd.toMethodType());
309         return LINKER.upcallStub(target, fd, arena);
310     }
311 
312     private static Object saver(Object[] o, List<MemoryLayout> argLayouts, AtomicReference<Object[]> ref, SegmentAllocator allocator, int retArg) {
313         for (int i = 0; i < o.length; i++) {
314             if (argLayouts.get(i) instanceof GroupLayout gl) {
315                 MemorySegment ms = (MemorySegment) o[i];
316                 MemorySegment copy = allocator.allocate(gl);
317                 copy.copyFrom(ms);
318                 o[i] = copy;
319             }
320         }
321         ref.set(o);
322         return retArg != -1 ? o[retArg] : null;
323     }
324 }