1 /* 2 * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import java.lang.foreign.AddressLayout; 25 import java.lang.foreign.Arena; 26 import java.lang.foreign.FunctionDescriptor; 27 import java.lang.foreign.GroupLayout; 28 import java.lang.foreign.Linker; 29 import java.lang.foreign.MemoryLayout; 30 import java.lang.foreign.MemorySegment; 31 import java.lang.foreign.PaddingLayout; 32 import java.lang.foreign.SegmentAllocator; 33 import java.lang.foreign.SequenceLayout; 34 import java.lang.foreign.StructLayout; 35 import java.lang.foreign.SymbolLookup; 36 import java.lang.foreign.UnionLayout; 37 import java.lang.foreign.ValueLayout; 38 39 import java.lang.invoke.MethodHandle; 40 import java.lang.invoke.MethodHandles; 41 import java.lang.invoke.MethodType; 42 import java.lang.invoke.VarHandle; 43 import java.util.ArrayList; 44 import java.util.List; 45 import java.util.Random; 46 import java.util.concurrent.ThreadLocalRandom; 47 import java.util.concurrent.atomic.AtomicReference; 48 import java.util.function.Consumer; 49 import java.util.function.UnaryOperator; 50 import java.util.random.RandomGenerator; 51 52 import static java.lang.foreign.MemoryLayout.PathElement.groupElement; 53 import static java.lang.foreign.MemoryLayout.PathElement.sequenceElement; 54 55 public class NativeTestHelper { 56 57 public static final boolean IS_WINDOWS = System.getProperty("os.name").startsWith("Windows"); 58 59 private static final MethodHandle MH_SAVER; 60 private static final RandomGenerator DEFAULT_RANDOM; 61 62 static { 63 int seed = Integer.getInteger("NativeTestHelper.DEFAULT_RANDOM.seed", ThreadLocalRandom.current().nextInt()); 64 System.out.println("NativeTestHelper::DEFAULT_RANDOM.seed = " + seed); 65 System.out.println("Re-run with '-DNativeTestHelper.DEFAULT_RANDOM.seed=" + seed + "' to reproduce"); 66 DEFAULT_RANDOM = new Random(seed); 67 68 try { 69 MH_SAVER = MethodHandles.lookup().findStatic(NativeTestHelper.class, "saver", 70 MethodType.methodType(Object.class, Object[].class, List.class, AtomicReference.class, SegmentAllocator.class, int.class)); 71 } catch (ReflectiveOperationException e) { 72 throw new ExceptionInInitializerError(e); 73 } 74 } 75 76 public static boolean isIntegral(MemoryLayout layout) { 77 return layout instanceof ValueLayout valueLayout && isIntegral(valueLayout.carrier()); 78 } 79 80 static boolean isIntegral(Class<?> clazz) { 81 return clazz == byte.class || clazz == char.class || clazz == short.class 82 || clazz == int.class || clazz == long.class; 83 } 84 85 public static boolean isPointer(MemoryLayout layout) { 86 return layout instanceof ValueLayout valueLayout && valueLayout.carrier() == MemorySegment.class; 87 } 88 89 // the constants below are useful aliases for C types. The type/carrier association is only valid for 64-bit platforms. 90 91 /** 92 * The layout for the {@code bool} C type 93 */ 94 public static final ValueLayout.OfBoolean C_BOOL = ValueLayout.JAVA_BOOLEAN; 95 /** 96 * The layout for the {@code char} C type 97 */ 98 public static final ValueLayout.OfByte C_CHAR = ValueLayout.JAVA_BYTE; 99 /** 100 * The layout for the {@code short} C type 101 */ 102 public static final ValueLayout.OfShort C_SHORT = ValueLayout.JAVA_SHORT; 103 /** 104 * The layout for the {@code int} C type 105 */ 106 public static final ValueLayout.OfInt C_INT = ValueLayout.JAVA_INT; 107 108 /** 109 * The layout for the {@code long long} C type. 110 */ 111 public static final ValueLayout.OfLong C_LONG_LONG = ValueLayout.JAVA_LONG; 112 /** 113 * The layout for the {@code float} C type 114 */ 115 public static final ValueLayout.OfFloat C_FLOAT = ValueLayout.JAVA_FLOAT; 116 /** 117 * The layout for the {@code double} C type 118 */ 119 public static final ValueLayout.OfDouble C_DOUBLE = ValueLayout.JAVA_DOUBLE; 120 /** 121 * The {@code T*} native type. 122 */ 123 public static final AddressLayout C_POINTER = ValueLayout.ADDRESS 124 .withTargetLayout(MemoryLayout.sequenceLayout(C_CHAR)); 125 126 public static final Linker LINKER = Linker.nativeLinker(); 127 128 private static final MethodHandle FREE = LINKER.downcallHandle( 129 LINKER.defaultLookup().find("free").get(), FunctionDescriptor.ofVoid(C_POINTER)); 130 131 private static final MethodHandle MALLOC = LINKER.downcallHandle( 132 LINKER.defaultLookup().find("malloc").get(), FunctionDescriptor.of(C_POINTER, C_LONG_LONG)); 133 134 public static void freeMemory(MemorySegment address) { 135 try { 136 FREE.invokeExact(address); 137 } catch (Throwable ex) { 138 throw new IllegalStateException(ex); 139 } 140 } 141 142 public static MemorySegment allocateMemory(long size) { 143 try { 144 return (MemorySegment) MALLOC.invokeExact(size); 145 } catch (Throwable ex) { 146 throw new IllegalStateException(ex); 147 } 148 } 149 150 public static MemorySegment findNativeOrThrow(String name) { 151 return SymbolLookup.loaderLookup().find(name).orElseThrow(); 152 } 153 154 public static MethodHandle downcallHandle(String symbol, FunctionDescriptor desc, Linker.Option... options) { 155 return LINKER.downcallHandle(findNativeOrThrow(symbol), desc, options); 156 } 157 158 public static MemorySegment upcallStub(Class<?> holder, String name, FunctionDescriptor descriptor) { 159 try { 160 MethodHandle target = MethodHandles.lookup().findStatic(holder, name, descriptor.toMethodType()); 161 return LINKER.upcallStub(target, descriptor, Arena.ofAuto()); 162 } catch (ReflectiveOperationException e) { 163 throw new RuntimeException(e); 164 } 165 } 166 167 public static TestValue[] genTestArgs(FunctionDescriptor descriptor, SegmentAllocator allocator) { 168 return genTestArgs(DEFAULT_RANDOM, descriptor, allocator); 169 } 170 171 public static TestValue[] genTestArgs(RandomGenerator random, FunctionDescriptor descriptor, SegmentAllocator allocator) { 172 TestValue[] result = new TestValue[descriptor.argumentLayouts().size()]; 173 for (int i = 0; i < result.length; i++) { 174 result[i] = genTestValue(random, descriptor.argumentLayouts().get(i), allocator); 175 } 176 return result; 177 } 178 179 public record TestValue (Object value, Consumer<Object> check) {} 180 181 public static TestValue genTestValue(MemoryLayout layout, SegmentAllocator allocator) { 182 return genTestValue(DEFAULT_RANDOM, layout, allocator); 183 } 184 185 public static TestValue genTestValue(RandomGenerator random, MemoryLayout layout, SegmentAllocator allocator) { 186 if (layout instanceof StructLayout struct) { 187 MemorySegment segment = allocator.allocate(struct); 188 List<Consumer<Object>> fieldChecks = new ArrayList<>(); 189 for (MemoryLayout fieldLayout : struct.memberLayouts()) { 190 if (fieldLayout instanceof PaddingLayout) continue; 191 MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow()); 192 fieldChecks.add(initField(random, segment, struct, fieldLayout, fieldPath, allocator)); 193 } 194 return new TestValue(segment, actual -> fieldChecks.forEach(check -> check.accept(actual))); 195 } else if (layout instanceof UnionLayout union) { 196 MemorySegment segment = allocator.allocate(union); 197 List<MemoryLayout> filteredFields = union.memberLayouts().stream() 198 .filter(l -> !(l instanceof PaddingLayout)) 199 .toList(); 200 int fieldIdx = random.nextInt(filteredFields.size()); 201 MemoryLayout fieldLayout = filteredFields.get(fieldIdx); 202 MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow()); 203 Consumer<Object> check = initField(random, segment, union, fieldLayout, fieldPath, allocator); 204 return new TestValue(segment, check); 205 } else if (layout instanceof SequenceLayout array) { 206 MemorySegment segment = allocator.allocate(array); 207 List<Consumer<Object>> elementChecks = new ArrayList<>(); 208 for (int i = 0; i < array.elementCount(); i++) { 209 elementChecks.add(initField(random, segment, array, array.elementLayout(), sequenceElement(i), allocator)); 210 } 211 return new TestValue(segment, actual -> elementChecks.forEach(check -> check.accept(actual))); 212 } else if (layout instanceof AddressLayout) { 213 MemorySegment value = MemorySegment.ofAddress(random.nextLong()); 214 return new TestValue(value, actual -> assertEquals(actual, value)); 215 }else if (layout instanceof ValueLayout.OfByte) { 216 byte value = (byte) random.nextInt(); 217 return new TestValue(value, actual -> assertEquals(actual, value)); 218 } else if (layout instanceof ValueLayout.OfShort) { 219 short value = (short) random.nextInt(); 220 return new TestValue(value, actual -> assertEquals(actual, value)); 221 } else if (layout instanceof ValueLayout.OfInt) { 222 int value = random.nextInt(); 223 return new TestValue(value, actual -> assertEquals(actual, value)); 224 } else if (layout instanceof ValueLayout.OfLong) { 225 long value = random.nextLong(); 226 return new TestValue(value, actual -> assertEquals(actual, value)); 227 } else if (layout instanceof ValueLayout.OfFloat) { 228 float value = random.nextFloat(); 229 return new TestValue(value, actual -> assertEquals(actual, value)); 230 } else if (layout instanceof ValueLayout.OfDouble) { 231 double value = random.nextDouble(); 232 return new TestValue(value, actual -> assertEquals(actual, value)); 233 } 234 235 throw new IllegalStateException("Unexpected layout: " + layout); 236 } 237 238 private static Consumer<Object> initField(RandomGenerator random, MemorySegment container, MemoryLayout containerLayout, 239 MemoryLayout fieldLayout, MemoryLayout.PathElement fieldPath, 240 SegmentAllocator allocator) { 241 TestValue fieldValue = genTestValue(random, fieldLayout, allocator); 242 Consumer<Object> fieldCheck = fieldValue.check(); 243 if (fieldLayout instanceof GroupLayout || fieldLayout instanceof SequenceLayout) { 244 UnaryOperator<MemorySegment> slicer = slicer(containerLayout, fieldPath); 245 MemorySegment slice = slicer.apply(container); 246 slice.copyFrom((MemorySegment) fieldValue.value()); 247 return actual -> fieldCheck.accept(slicer.apply((MemorySegment) actual)); 248 } else { 249 VarHandle accessor = containerLayout.varHandle(fieldPath); 250 //set value 251 accessor.set(container, fieldValue.value()); 252 return actual -> fieldCheck.accept(accessor.get((MemorySegment) actual)); 253 } 254 } 255 256 private static UnaryOperator<MemorySegment> slicer(MemoryLayout containerLayout, MemoryLayout.PathElement fieldPath) { 257 MethodHandle slicer = containerLayout.sliceHandle(fieldPath); 258 return container -> { 259 try { 260 return (MemorySegment) slicer.invokeExact(container); 261 } catch (Throwable e) { 262 throw new IllegalStateException(e); 263 } 264 }; 265 } 266 267 private static void assertEquals(Object actual, Object expected) { 268 if (actual.getClass() != expected.getClass()) { 269 throw new AssertionError("Type mismatch: " + actual.getClass() + " != " + expected.getClass()); 270 } 271 if (!actual.equals(expected)) { 272 throw new AssertionError("Not equal: " + actual + " != " + expected); 273 } 274 } 275 276 /** 277 * Make an upcall stub that saves its arguments into the given 'ref' array 278 * 279 * @param fd function descriptor for the upcall stub 280 * @param capturedArgs box to save arguments in 281 * @param arena allocator for making copies of by-value structs 282 * @param retIdx the index of the argument to return 283 * @return return the upcall stub 284 */ 285 public static MemorySegment makeArgSaverCB(FunctionDescriptor fd, Arena arena, 286 AtomicReference<Object[]> capturedArgs, int retIdx) { 287 MethodHandle target = MethodHandles.insertArguments(MH_SAVER, 1, fd.argumentLayouts(), capturedArgs, arena, retIdx); 288 target = target.asCollector(Object[].class, fd.argumentLayouts().size()); 289 target = target.asType(fd.toMethodType()); 290 return LINKER.upcallStub(target, fd, arena); 291 } 292 293 private static Object saver(Object[] o, List<MemoryLayout> argLayouts, AtomicReference<Object[]> ref, SegmentAllocator allocator, int retArg) { 294 for (int i = 0; i < o.length; i++) { 295 if (argLayouts.get(i) instanceof GroupLayout gl) { 296 MemorySegment ms = (MemorySegment) o[i]; 297 MemorySegment copy = allocator.allocate(gl); 298 copy.copyFrom(ms); 299 o[i] = copy; 300 } 301 } 302 ref.set(o); 303 return retArg != -1 ? o[retArg] : null; 304 } 305 }