1 /* 2 * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import java.lang.foreign.AddressLayout; 25 import java.lang.foreign.Arena; 26 import java.lang.foreign.FunctionDescriptor; 27 import java.lang.foreign.GroupLayout; 28 import java.lang.foreign.Linker; 29 import java.lang.foreign.MemoryLayout; 30 import java.lang.foreign.MemorySegment; 31 import java.lang.foreign.PaddingLayout; 32 import java.lang.foreign.SegmentAllocator; 33 import java.lang.foreign.SequenceLayout; 34 import java.lang.foreign.StructLayout; 35 import java.lang.foreign.SymbolLookup; 36 import java.lang.foreign.UnionLayout; 37 import java.lang.foreign.ValueLayout; 38 39 import java.lang.invoke.MethodHandle; 40 import java.lang.invoke.MethodHandles; 41 import java.lang.invoke.MethodType; 42 import java.lang.invoke.VarHandle; 43 import java.util.ArrayList; 44 import java.util.List; 45 import java.util.Random; 46 import java.util.concurrent.ThreadLocalRandom; 47 import java.util.concurrent.atomic.AtomicReference; 48 import java.util.function.Consumer; 49 import java.util.function.UnaryOperator; 50 import java.util.random.RandomGenerator; 51 52 import static java.lang.foreign.MemoryLayout.PathElement.groupElement; 53 import static java.lang.foreign.MemoryLayout.PathElement.sequenceElement; 54 55 public class NativeTestHelper { 56 57 public static final boolean IS_WINDOWS = System.getProperty("os.name").startsWith("Windows"); 58 59 private static final MethodHandle MH_SAVER; 60 private static final RandomGenerator DEFAULT_RANDOM; 61 62 static { 63 int seed = Integer.getInteger("NativeTestHelper.DEFAULT_RANDOM.seed", ThreadLocalRandom.current().nextInt()); 64 System.out.println("NativeTestHelper::DEFAULT_RANDOM.seed = " + seed); 65 System.out.println("Re-run with '-DNativeTestHelper.DEFAULT_RANDOM.seed=" + seed + "' to reproduce"); 66 DEFAULT_RANDOM = new Random(seed); 67 68 try { 69 MH_SAVER = MethodHandles.lookup().findStatic(NativeTestHelper.class, "saver", 70 MethodType.methodType(Object.class, Object[].class, List.class, AtomicReference.class, SegmentAllocator.class, int.class)); 71 } catch (ReflectiveOperationException e) { 72 throw new ExceptionInInitializerError(e); 73 } 74 } 75 76 public static boolean isIntegral(MemoryLayout layout) { 77 return layout instanceof ValueLayout valueLayout && isIntegral(valueLayout.carrier()); 78 } 79 80 static boolean isIntegral(Class<?> clazz) { 81 return clazz == byte.class || clazz == char.class || clazz == short.class 82 || clazz == int.class || clazz == long.class; 83 } 84 85 public static boolean isPointer(MemoryLayout layout) { 86 return layout instanceof ValueLayout valueLayout && valueLayout.carrier() == MemorySegment.class; 87 } 88 89 public static final Linker LINKER = Linker.nativeLinker(); 90 91 // the constants below are useful aliases for C types. The type/carrier association is only valid for 64-bit platforms. 92 93 /** 94 * The layout for the {@code bool} C type 95 */ 96 public static final ValueLayout.OfBoolean C_BOOL = (ValueLayout.OfBoolean) LINKER.canonicalLayouts().get("bool"); 97 /** 98 * The layout for the {@code char} C type 99 */ 100 public static final ValueLayout.OfByte C_CHAR = (ValueLayout.OfByte) LINKER.canonicalLayouts().get("char"); 101 /** 102 * The layout for the {@code short} C type 103 */ 104 public static final ValueLayout.OfShort C_SHORT = (ValueLayout.OfShort) LINKER.canonicalLayouts().get("short"); 105 /** 106 * The layout for the {@code int} C type 107 */ 108 public static final ValueLayout.OfInt C_INT = (ValueLayout.OfInt) LINKER.canonicalLayouts().get("int"); 109 110 /** 111 * The layout for the {@code long long} C type. 112 */ 113 public static final ValueLayout.OfLong C_LONG_LONG = (ValueLayout.OfLong) LINKER.canonicalLayouts().get("long long"); 114 /** 115 * The layout for the {@code float} C type 116 */ 117 public static final ValueLayout.OfFloat C_FLOAT = (ValueLayout.OfFloat) LINKER.canonicalLayouts().get("float"); 118 /** 119 * The layout for the {@code double} C type 120 */ 121 public static final ValueLayout.OfDouble C_DOUBLE = (ValueLayout.OfDouble) LINKER.canonicalLayouts().get("double"); 122 /** 123 * The {@code T*} native type. 124 */ 125 public static final AddressLayout C_POINTER = ((AddressLayout) LINKER.canonicalLayouts().get("void*")) 126 .withTargetLayout(MemoryLayout.sequenceLayout(Long.MAX_VALUE, C_CHAR)); 127 /** 128 * The layout for the {@code size_t} C type 129 */ 130 public static final ValueLayout C_SIZE_T = (ValueLayout) LINKER.canonicalLayouts().get("size_t"); 131 132 // Common layout shared by some tests 133 // struct S_PDI { void* p0; double p1; int p2; }; 134 public static final MemoryLayout S_PDI_LAYOUT = switch ((int) ValueLayout.ADDRESS.byteSize()) { 135 case 8 -> MemoryLayout.structLayout( 136 C_POINTER.withName("p0"), 137 C_DOUBLE.withName("p1"), 138 C_INT.withName("p2"), 139 MemoryLayout.paddingLayout(4)); 140 case 4 -> MemoryLayout.structLayout( 141 C_POINTER.withName("p0"), 142 C_DOUBLE.withName("p1"), 143 C_INT.withName("p2")); 144 default -> throw new UnsupportedOperationException("Unsupported address size"); 145 }; 146 147 private static final MethodHandle FREE = LINKER.downcallHandle( 148 LINKER.defaultLookup().find("free").get(), FunctionDescriptor.ofVoid(C_POINTER)); 149 150 private static final MethodHandle MALLOC = LINKER.downcallHandle( 151 LINKER.defaultLookup().find("malloc").get(), FunctionDescriptor.of(C_POINTER, C_LONG_LONG)); 152 153 public static void freeMemory(MemorySegment address) { 154 try { 155 FREE.invokeExact(address); 156 } catch (Throwable ex) { 157 throw new IllegalStateException(ex); 158 } 159 } 160 161 public static MemorySegment allocateMemory(long size) { 162 try { 163 return (MemorySegment) MALLOC.invokeExact(size); 164 } catch (Throwable ex) { 165 throw new IllegalStateException(ex); 166 } 167 } 168 169 public static MemorySegment findNativeOrThrow(String name) { 170 return SymbolLookup.loaderLookup().find(name).orElseThrow(); 171 } 172 173 public static MethodHandle downcallHandle(String symbol, FunctionDescriptor desc, Linker.Option... options) { 174 return LINKER.downcallHandle(findNativeOrThrow(symbol), desc, options); 175 } 176 177 public static MemorySegment upcallStub(Class<?> holder, String name, FunctionDescriptor descriptor) { 178 try { 179 MethodHandle target = MethodHandles.lookup().findStatic(holder, name, descriptor.toMethodType()); 180 return LINKER.upcallStub(target, descriptor, Arena.ofAuto()); 181 } catch (ReflectiveOperationException e) { 182 throw new RuntimeException(e); 183 } 184 } 185 186 public static TestValue[] genTestArgs(FunctionDescriptor descriptor, SegmentAllocator allocator) { 187 return genTestArgs(DEFAULT_RANDOM, descriptor, allocator); 188 } 189 190 public static TestValue[] genTestArgs(RandomGenerator random, FunctionDescriptor descriptor, SegmentAllocator allocator) { 191 TestValue[] result = new TestValue[descriptor.argumentLayouts().size()]; 192 for (int i = 0; i < result.length; i++) { 193 result[i] = genTestValue(random, descriptor.argumentLayouts().get(i), allocator); 194 } 195 return result; 196 } 197 198 public record TestValue (Object value, Consumer<Object> check) {} 199 200 public static TestValue genTestValue(MemoryLayout layout, SegmentAllocator allocator) { 201 return genTestValue(DEFAULT_RANDOM, layout, allocator); 202 } 203 204 public static TestValue genTestValue(RandomGenerator random, MemoryLayout layout, SegmentAllocator allocator) { 205 if (layout instanceof StructLayout struct) { 206 MemorySegment segment = allocator.allocate(struct); 207 List<Consumer<Object>> fieldChecks = new ArrayList<>(); 208 for (MemoryLayout fieldLayout : struct.memberLayouts()) { 209 if (fieldLayout instanceof PaddingLayout) continue; 210 MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow()); 211 fieldChecks.add(initField(random, segment, struct, fieldLayout, fieldPath, allocator)); 212 } 213 return new TestValue(segment, actual -> fieldChecks.forEach(check -> check.accept(actual))); 214 } else if (layout instanceof UnionLayout union) { 215 MemorySegment segment = allocator.allocate(union); 216 List<MemoryLayout> filteredFields = union.memberLayouts().stream() 217 .filter(l -> !(l instanceof PaddingLayout)) 218 .toList(); 219 int fieldIdx = random.nextInt(filteredFields.size()); 220 MemoryLayout fieldLayout = filteredFields.get(fieldIdx); 221 MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow()); 222 Consumer<Object> check = initField(random, segment, union, fieldLayout, fieldPath, allocator); 223 return new TestValue(segment, check); 224 } else if (layout instanceof SequenceLayout array) { 225 MemorySegment segment = allocator.allocate(array); 226 List<Consumer<Object>> elementChecks = new ArrayList<>(); 227 for (int i = 0; i < array.elementCount(); i++) { 228 elementChecks.add(initField(random, segment, array, array.elementLayout(), sequenceElement(i), allocator)); 229 } 230 return new TestValue(segment, actual -> elementChecks.forEach(check -> check.accept(actual))); 231 } else if (layout instanceof AddressLayout) { 232 MemorySegment value = MemorySegment.ofAddress(random.nextLong()); 233 return new TestValue(value, actual -> assertEquals(actual, value)); 234 }else if (layout instanceof ValueLayout.OfByte) { 235 byte value = (byte) random.nextInt(); 236 return new TestValue(value, actual -> assertEquals(actual, value)); 237 } else if (layout instanceof ValueLayout.OfShort) { 238 short value = (short) random.nextInt(); 239 return new TestValue(value, actual -> assertEquals(actual, value)); 240 } else if (layout instanceof ValueLayout.OfInt) { 241 int value = random.nextInt(); 242 return new TestValue(value, actual -> assertEquals(actual, value)); 243 } else if (layout instanceof ValueLayout.OfLong) { 244 long value = random.nextLong(); 245 return new TestValue(value, actual -> assertEquals(actual, value)); 246 } else if (layout instanceof ValueLayout.OfFloat) { 247 float value = random.nextFloat(); 248 return new TestValue(value, actual -> assertEquals(actual, value)); 249 } else if (layout instanceof ValueLayout.OfDouble) { 250 double value = random.nextDouble(); 251 return new TestValue(value, actual -> assertEquals(actual, value)); 252 } 253 254 throw new IllegalStateException("Unexpected layout: " + layout); 255 } 256 257 private static Consumer<Object> initField(RandomGenerator random, MemorySegment container, MemoryLayout containerLayout, 258 MemoryLayout fieldLayout, MemoryLayout.PathElement fieldPath, 259 SegmentAllocator allocator) { 260 TestValue fieldValue = genTestValue(random, fieldLayout, allocator); 261 Consumer<Object> fieldCheck = fieldValue.check(); 262 if (fieldLayout instanceof GroupLayout || fieldLayout instanceof SequenceLayout) { 263 UnaryOperator<MemorySegment> slicer = slicer(containerLayout, fieldPath); 264 MemorySegment slice = slicer.apply(container); 265 slice.copyFrom((MemorySegment) fieldValue.value()); 266 return actual -> fieldCheck.accept(slicer.apply((MemorySegment) actual)); 267 } else { 268 VarHandle accessor = containerLayout.varHandle(fieldPath); 269 //set value 270 accessor.set(container, 0L, fieldValue.value()); 271 return actual -> fieldCheck.accept(accessor.get((MemorySegment) actual, 0L)); 272 } 273 } 274 275 private static UnaryOperator<MemorySegment> slicer(MemoryLayout containerLayout, MemoryLayout.PathElement fieldPath) { 276 MethodHandle slicer = containerLayout.sliceHandle(fieldPath); 277 return container -> { 278 try { 279 return (MemorySegment) slicer.invokeExact(container, 0L); 280 } catch (Throwable e) { 281 throw new IllegalStateException(e); 282 } 283 }; 284 } 285 286 private static void assertEquals(Object actual, Object expected) { 287 if (actual.getClass() != expected.getClass()) { 288 throw new AssertionError("Type mismatch: " + actual.getClass() + " != " + expected.getClass()); 289 } 290 if (!actual.equals(expected)) { 291 throw new AssertionError("Not equal: " + actual + " != " + expected); 292 } 293 } 294 295 /** 296 * Make an upcall stub that saves its arguments into the given 'ref' array 297 * 298 * @param fd function descriptor for the upcall stub 299 * @param capturedArgs box to save arguments in 300 * @param arena allocator for making copies of by-value structs 301 * @param retIdx the index of the argument to return 302 * @return return the upcall stub 303 */ 304 public static MemorySegment makeArgSaverCB(FunctionDescriptor fd, Arena arena, 305 AtomicReference<Object[]> capturedArgs, int retIdx) { 306 MethodHandle target = MethodHandles.insertArguments(MH_SAVER, 1, fd.argumentLayouts(), capturedArgs, arena, retIdx); 307 target = target.asCollector(Object[].class, fd.argumentLayouts().size()); 308 target = target.asType(fd.toMethodType()); 309 return LINKER.upcallStub(target, fd, arena); 310 } 311 312 private static Object saver(Object[] o, List<MemoryLayout> argLayouts, AtomicReference<Object[]> ref, SegmentAllocator allocator, int retArg) { 313 for (int i = 0; i < o.length; i++) { 314 if (argLayouts.get(i) instanceof GroupLayout gl) { 315 MemorySegment ms = (MemorySegment) o[i]; 316 MemorySegment copy = allocator.allocate(gl); 317 copy.copyFrom(ms); 318 o[i] = copy; 319 } 320 } 321 ref.set(o); 322 return retArg != -1 ? o[retArg] : null; 323 } 324 }