< prev index next >

test/jdk/java/foreign/TestUpcall.java

Print this page
*** 21,51 ***
   *  questions.
   *
   */
  
  /*
!  * @test
   * @requires ((os.arch == "amd64" | os.arch == "x86_64") & sun.arch.data.model == "64") | os.arch == "aarch64"
   * @modules jdk.incubator.foreign/jdk.internal.foreign
   * @build NativeTestHelper CallGeneratorHelper TestUpcall
   *
   * @run testng/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:-VerifyDependencies
   *   --enable-native-access=ALL-UNNAMED -Dgenerator.sample.factor=17
   *   TestUpcall
   */
  
  import jdk.incubator.foreign.CLinker;
  import jdk.incubator.foreign.FunctionDescriptor;
  import jdk.incubator.foreign.SymbolLookup;
  import jdk.incubator.foreign.MemoryAddress;
  import jdk.incubator.foreign.MemoryLayout;
  import jdk.incubator.foreign.MemorySegment;
  
  import jdk.incubator.foreign.ResourceScope;
  import org.testng.annotations.BeforeClass;
  import org.testng.annotations.Test;
  
  import java.lang.invoke.MethodHandle;
  import java.lang.invoke.MethodHandles;
  import java.lang.invoke.MethodType;
  import java.util.ArrayList;
  import java.util.List;
  import java.util.concurrent.atomic.AtomicReference;
  import java.util.function.Consumer;
  import java.util.stream.Collectors;
  
  import static java.lang.invoke.MethodHandles.insertArguments;
- import static jdk.incubator.foreign.CLinker.C_POINTER;
  import static org.testng.Assert.assertEquals;
  
  
  public class TestUpcall extends CallGeneratorHelper {
  
      static {
          System.loadLibrary("TestUpcall");
      }
!     static CLinker abi = CLinker.getInstance();
  
      static final SymbolLookup LOOKUP = SymbolLookup.loaderLookup();
  
      static MethodHandle DUMMY;
      static MethodHandle PASS_AND_SAVE;
--- 21,90 ---
   *  questions.
   *
   */
  
  /*
!  * @test id=scope
   * @requires ((os.arch == "amd64" | os.arch == "x86_64") & sun.arch.data.model == "64") | os.arch == "aarch64"
   * @modules jdk.incubator.foreign/jdk.internal.foreign
   * @build NativeTestHelper CallGeneratorHelper TestUpcall
   *
   * @run testng/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:-VerifyDependencies
   *   --enable-native-access=ALL-UNNAMED -Dgenerator.sample.factor=17
+  *   -DUPCALL_TEST_TYPE=SCOPE
   *   TestUpcall
   */
  
+ /*
+  * @test id=no_scope
+  * @requires ((os.arch == "amd64" | os.arch == "x86_64") & sun.arch.data.model == "64") | os.arch == "aarch64"
+  * @modules jdk.incubator.foreign/jdk.internal.foreign
+  * @build NativeTestHelper CallGeneratorHelper TestUpcall
+  *
+  * @run testng/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:-VerifyDependencies
+  *   --enable-native-access=ALL-UNNAMED -Dgenerator.sample.factor=17
+  *   -DUPCALL_TEST_TYPE=NO_SCOPE
+  *   TestUpcall
+  */
+ 
+ /*
+  * @test id=async
+  * @requires ((os.arch == "amd64" | os.arch == "x86_64") & sun.arch.data.model == "64") | os.arch == "aarch64"
+  * @modules jdk.incubator.foreign/jdk.internal.foreign
+  * @build NativeTestHelper CallGeneratorHelper TestUpcall
+  *
+  * @run testng/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:-VerifyDependencies
+  *   --enable-native-access=ALL-UNNAMED -Dgenerator.sample.factor=17
+  *   -DUPCALL_TEST_TYPE=ASYNC
+  *   TestUpcall
+  */
+ 
+ import jdk.incubator.foreign.Addressable;
  import jdk.incubator.foreign.CLinker;
  import jdk.incubator.foreign.FunctionDescriptor;
+ import jdk.incubator.foreign.NativeSymbol;
+ import jdk.incubator.foreign.SegmentAllocator;
  import jdk.incubator.foreign.SymbolLookup;
  import jdk.incubator.foreign.MemoryAddress;
  import jdk.incubator.foreign.MemoryLayout;
  import jdk.incubator.foreign.MemorySegment;
  
  import jdk.incubator.foreign.ResourceScope;
+ import org.testng.SkipException;
  import org.testng.annotations.BeforeClass;
  import org.testng.annotations.Test;
  
  import java.lang.invoke.MethodHandle;
  import java.lang.invoke.MethodHandles;
  import java.lang.invoke.MethodType;
  import java.util.ArrayList;
+ import java.util.HashMap;
  import java.util.List;
+ import java.util.Map;
  import java.util.concurrent.atomic.AtomicReference;
  import java.util.function.Consumer;
  import java.util.stream.Collectors;
  
  import static java.lang.invoke.MethodHandles.insertArguments;
  import static org.testng.Assert.assertEquals;
  
  
  public class TestUpcall extends CallGeneratorHelper {
  
+     /**
+      * Group of upcall tests run by one JVM invocation; selected through the
+      * {@code -DUPCALL_TEST_TYPE=...} option of the jtreg {@code @test} configurations.
+      */
+     private enum TestType {
+         SCOPE,     // testUpcalls
+         NO_SCOPE,  // NOTE(review): no test in this view selects this group — confirm
+         ASYNC      // testUpcallsAsync
+     }
+ 
+     // Read once at class initialization. Enum.valueOf throws NullPointerException when
+     // the property is unset, so running the class outside the jtreg configurations fails early.
+     private static final TestType UPCALL_TEST_TYPE = TestType.valueOf(System.getProperty("UPCALL_TEST_TYPE"));
+ 
      static {
+         // Native test functions; AsyncInvokers presumably provides the call_async_*
+         // symbols used by asyncInvoker — TODO confirm.
          System.loadLibrary("TestUpcall");
+         System.loadLibrary("AsyncInvokers");
      }
!     // Platform C linker used for the downcall handles and upcall stubs in this test.
!     static CLinker abi = CLinker.systemCLinker();
  
+     // Resolves symbols from the libraries loaded by this class's loader (see above).
      static final SymbolLookup LOOKUP = SymbolLookup.loaderLookup();
  
      static MethodHandle DUMMY;
      static MethodHandle PASS_AND_SAVE;

*** 78,59 ***
          } catch (Throwable ex) {
              throw new IllegalStateException(ex);
          }
      }
  
!     static MemoryAddress dummyStub;
  
      @BeforeClass
      void setup() {
          dummyStub = abi.upcallStub(DUMMY, FunctionDescriptor.ofVoid(), ResourceScope.newImplicitScope());
      }
  
      @Test(dataProvider="functions", dataProviderClass=CallGeneratorHelper.class)
      public void testUpcalls(int count, String fName, Ret ret, List<ParamType> paramTypes, List<StructFieldType> fields) throws Throwable {
          List<Consumer<Object>> returnChecks = new ArrayList<>();
          List<Consumer<Object[]>> argChecks = new ArrayList<>();
!         MemoryAddress addr = LOOKUP.lookup(fName).get();
!         MethodType mtype = methodType(ret, paramTypes, fields);
!         try (NativeScope scope = new NativeScope()) {
!             MethodHandle mh = abi.downcallHandle(addr, scope, mtype, function(ret, paramTypes, fields));
!             Object[] args = makeArgs(scope.scope(), ret, paramTypes, fields, returnChecks, argChecks);
              Object[] callArgs = args;
              Object res = mh.invokeWithArguments(callArgs);
              argChecks.forEach(c -> c.accept(args));
              if (ret == Ret.NON_VOID) {
                  returnChecks.forEach(c -> c.accept(res));
              }
          }
      }
  
      @Test(dataProvider="functions", dataProviderClass=CallGeneratorHelper.class)
!     public void testUpcallsNoScope(int count, String fName, Ret ret, List<ParamType> paramTypes, List<StructFieldType> fields) throws Throwable {
          List<Consumer<Object>> returnChecks = new ArrayList<>();
          List<Consumer<Object[]>> argChecks = new ArrayList<>();
!         MemoryAddress addr = LOOKUP.lookup(fName).get();
!         MethodType mtype = methodType(ret, paramTypes, fields);
!         MethodHandle mh = abi.downcallHandle(addr, IMPLICIT_ALLOCATOR, mtype, function(ret, paramTypes, fields));
!         Object[] args = makeArgs(ResourceScope.newImplicitScope(), ret, paramTypes, fields, returnChecks, argChecks);
!         Object[] callArgs = args;
!         Object res = mh.invokeWithArguments(callArgs);
!         argChecks.forEach(c -> c.accept(args));
!         if (ret == Ret.NON_VOID) {
!             returnChecks.forEach(c -> c.accept(res));
          }
      }
  
!     static MethodType methodType(Ret ret, List<ParamType> params, List<StructFieldType> fields) {
!         MethodType mt = ret == Ret.VOID ?
!                 MethodType.methodType(void.class) : MethodType.methodType(paramCarrier(params.get(0).layout(fields)));
!         for (ParamType p : params) {
!             mt = mt.appendParameterTypes(paramCarrier(p.layout(fields)));
          }
!         mt = mt.appendParameterTypes(MemoryAddress.class); //the callback
!         return mt;
      }
  
      static FunctionDescriptor function(Ret ret, List<ParamType> params, List<StructFieldType> fields) {
          List<MemoryLayout> paramLayouts = params.stream().map(p -> p.layout(fields)).collect(Collectors.toList());
          paramLayouts.add(C_POINTER); // the callback
--- 117,94 ---
          } catch (Throwable ex) {
              throw new IllegalStateException(ex);
          }
      }
  
!     // Upcall stub wrapping DUMMY (presumably the no-op dummy()); makeCallback hands it
!     // out for functions that have no parameters to check.
!     static NativeSymbol dummyStub;
  
+     // Created once for the whole class; the implicit scope is never closed explicitly.
      @BeforeClass
      void setup() {
          dummyStub = abi.upcallStub(DUMMY, FunctionDescriptor.ofVoid(), ResourceScope.newImplicitScope());
      }
  
+     /**
+      * Skips the calling test unless {@code type} is the group selected through the
+      * {@code UPCALL_TEST_TYPE} system property, so each jtreg {@code @test}
+      * configuration only runs its own group of tests.
+      *
+      * @param type the group the calling test belongs to
+      * @throws SkipException if {@code type} is not the selected group
+      */
+     private static void checkSelected(TestType type) {
+         if (UPCALL_TEST_TYPE != type) {
+             throw new SkipException("Skipping tests that were not selected");
+         }
+     }
+ 
      @Test(dataProvider="functions", dataProviderClass=CallGeneratorHelper.class)
      public void testUpcalls(int count, String fName, Ret ret, List<ParamType> paramTypes, List<StructFieldType> fields) throws Throwable {
+         // Calls native fName through a downcall handle whose last argument is an upcall
+         // stub (built by makeArgs/makeCallback), then checks what the upcall received
+         // and, for non-void functions, the returned value. Everything allocated here
+         // lives in a confined scope closed at the end of the test.
+         checkSelected(TestType.SCOPE);
+ 
          List<Consumer<Object>> returnChecks = new ArrayList<>();
          List<Consumer<Object[]>> argChecks = new ArrayList<>();
!         NativeSymbol addr = LOOKUP.lookup(fName).get();
!         try (ResourceScope scope = ResourceScope.newConfinedScope()) {
!             SegmentAllocator allocator = SegmentAllocator.newNativeArena(scope);
!             MethodHandle mh = downcallHandle(abi, addr, allocator, function(ret, paramTypes, fields));
!             Object[] args = makeArgs(scope, ret, paramTypes, fields, returnChecks, argChecks);
              Object[] callArgs = args;
              Object res = mh.invokeWithArguments(callArgs);
              argChecks.forEach(c -> c.accept(args));
              if (ret == Ret.NON_VOID) {
                  returnChecks.forEach(c -> c.accept(res));
              }
          }
      }
  
      @Test(dataProvider="functions", dataProviderClass=CallGeneratorHelper.class)
!     public void testUpcallsAsync(int count, String fName, Ret ret, List<ParamType> paramTypes, List<StructFieldType> fields) throws Throwable {
+         // Async variant of testUpcalls: the downcall to fName, with its arguments pre-bound,
+         // is wrapped in an upcall stub that a native call_async_* invoker (see asyncInvoker)
+         // calls. A shared scope is used, presumably because the invoker may call the stub
+         // from another thread — TODO confirm.
+         checkSelected(TestType.ASYNC);
          List<Consumer<Object>> returnChecks = new ArrayList<>();
          List<Consumer<Object[]>> argChecks = new ArrayList<>();
!         NativeSymbol addr = LOOKUP.lookup(fName).get();
!         try (ResourceScope scope = ResourceScope.newSharedScope()) {
!             SegmentAllocator allocator = SegmentAllocator.newNativeArena(scope);
!             FunctionDescriptor descriptor = function(ret, paramTypes, fields);
!             MethodHandle mh = reverse(downcallHandle(abi, addr, allocator, descriptor));
!             Object[] args = makeArgs(ResourceScope.newImplicitScope(), ret, paramTypes, fields, returnChecks, argChecks);
! 
!             // Bind all downcall arguments up front so the upcall stub takes no parameters.
!             mh = mh.asSpreader(Object[].class, args.length);
!             mh = MethodHandles.insertArguments(mh, 0, (Object) args);
+             // The stub's descriptor only mirrors the downcall's return layout (if any).
+             FunctionDescriptor callbackDesc = descriptor.returnLayout()
+                     .map(FunctionDescriptor::of)
+                     .orElse(FunctionDescriptor.ofVoid());
+             NativeSymbol callback = abi.upcallStub(mh, callbackDesc, scope);
+ 
+             // Non-void test functions return a value of the first parameter's type.
+             MethodHandle invoker = asyncInvoker(ret, ret == Ret.VOID ? null : paramTypes.get(0), fields);
+ 
+             // Invokers returning a struct (MemorySegment) also need an allocator.
+             Object res = invoker.type().returnType() == MemorySegment.class
+                     ? invoker.invoke(allocator, callback)
+                     : invoker.invoke(callback);
+             argChecks.forEach(c -> c.accept(args));
+             if (ret == Ret.NON_VOID) {
+                 returnChecks.forEach(c -> c.accept(res));
+             }
          }
      }
  
!     // Downcall handles for the native call_async_* invokers, keyed by symbol name.
!     private static final Map<String, MethodHandle> INVOKERS = new HashMap<>();
! 
!     /**
!      * Returns (creating and caching on first use) a downcall handle to the native invoker
!      * matching the callback's return: "call_async_V" for void, otherwise "call_async_"
!      * plus the first letter of the ParamType name, with "_" + sigCode(fields) appended
!      * for structs. Each invoker takes the callback pointer; when the handle returns a
!      * MemorySegment, callers also pass a SegmentAllocator first.
!      */
!     private MethodHandle asyncInvoker(Ret ret, ParamType returnType, List<StructFieldType> fields) {
!         if (ret == Ret.VOID) {
!             String name = "call_async_V";
+             return INVOKERS.computeIfAbsent(name, symbol ->
+                 abi.downcallHandle(
+                     LOOKUP.lookup(symbol).orElseThrow(),
+                         FunctionDescriptor.ofVoid(C_POINTER)));
          }
! 
!         String name = "call_async_" + returnType.name().charAt(0)
+                 + (returnType == ParamType.STRUCT ? "_" + sigCode(fields) : "");
+ 
+         return INVOKERS.computeIfAbsent(name, symbol -> {
+             NativeSymbol invokerSymbol = LOOKUP.lookup(symbol).orElseThrow();
+             MemoryLayout returnLayout = returnType.layout(fields);
+             FunctionDescriptor desc = FunctionDescriptor.of(returnLayout, C_POINTER);
+ 
+             return abi.downcallHandle(invokerSymbol, desc);
+         });
      }
  
      static FunctionDescriptor function(Ret ret, List<ParamType> params, List<StructFieldType> fields) {
          List<MemoryLayout> paramLayouts = params.stream().map(p -> p.layout(fields)).collect(Collectors.toList());
          paramLayouts.add(C_POINTER); // the callback

*** 148,23 ***
          args[params.size()] = makeCallback(scope, ret, params, fields, checks, argChecks);
          return args;
      }
  
      @SuppressWarnings("unchecked")
!     static MemoryAddress makeCallback(ResourceScope scope, Ret ret, List<ParamType> params, List<StructFieldType> fields, List<Consumer<Object>> checks, List<Consumer<Object[]>> argChecks) {
          if (params.isEmpty()) {
!             return dummyStub.address();
          }
  
          AtomicReference<Object[]> box = new AtomicReference<>();
          MethodHandle mh = insertArguments(PASS_AND_SAVE, 1, box);
          mh = mh.asCollector(Object[].class, params.size());
  
          for (int i = 0; i < params.size(); i++) {
              ParamType pt = params.get(i);
              MemoryLayout layout = pt.layout(fields);
!             Class<?> carrier = paramCarrier(layout);
              mh = mh.asType(mh.type().changeParameterType(i, carrier));
  
              final int finalI = i;
              if (carrier == MemorySegment.class) {
                  argChecks.add(o -> assertStructEquals((MemorySegment) box.get()[finalI], (MemorySegment) o[finalI], layout));
--- 222,23 ---
          args[params.size()] = makeCallback(scope, ret, params, fields, checks, argChecks);
          return args;
      }
  
      @SuppressWarnings("unchecked")
!     static NativeSymbol makeCallback(ResourceScope scope, Ret ret, List<ParamType> params, List<StructFieldType> fields, List<Consumer<Object>> checks, List<Consumer<Object[]>> argChecks) {
          if (params.isEmpty()) {
!             return dummyStub;
          }
  
          AtomicReference<Object[]> box = new AtomicReference<>();
          MethodHandle mh = insertArguments(PASS_AND_SAVE, 1, box);
          mh = mh.asCollector(Object[].class, params.size());
  
          for (int i = 0; i < params.size(); i++) {
              ParamType pt = params.get(i);
              MemoryLayout layout = pt.layout(fields);
!             Class<?> carrier = carrier(layout, false);
              mh = mh.asType(mh.type().changeParameterType(i, carrier));
  
              final int finalI = i;
              if (carrier == MemorySegment.class) {
                  argChecks.add(o -> assertStructEquals((MemorySegment) box.get()[finalI], (MemorySegment) o[finalI], layout));

*** 173,11 ***
              }
          }
  
          ParamType firstParam = params.get(0);
          MemoryLayout firstlayout = firstParam.layout(fields);
!         Class<?> firstCarrier = paramCarrier(firstlayout);
  
          if (firstCarrier == MemorySegment.class) {
              checks.add(o -> assertStructEquals((MemorySegment) box.get()[0], (MemorySegment) o, firstlayout));
          } else {
              checks.add(o -> assertEquals(o, box.get()[0]));
--- 247,11 ---
              }
          }
  
          ParamType firstParam = params.get(0);
          MemoryLayout firstlayout = firstParam.layout(fields);
!         Class<?> firstCarrier = carrier(firstlayout, true);
  
          if (firstCarrier == MemorySegment.class) {
              checks.add(o -> assertStructEquals((MemorySegment) box.get()[0], (MemorySegment) o, firstlayout));
          } else {
              checks.add(o -> assertEquals(o, box.get()[0]));

*** 206,6 ***
--- 280,19 ---
      }
  
      static void dummy() {
!         // Intentionally empty: no-op upcall target (presumably what DUMMY refers to).
      }
+ 
+     /**
+      * Adapts a downcall handle so it can be used as an upcall target: a
+      * {@code MemoryAddress} return type becomes {@code Addressable}, and every
+      * {@code Addressable} parameter becomes {@code MemoryAddress}. This presumably
+      * matches the carrier types upcall stubs use — TODO confirm against CLinker docs.
+      *
+      * @param handle the downcall handle to adapt
+      * @return {@code handle} adapted with {@link MethodHandle#asType}
+      */
+     static MethodHandle reverse(MethodHandle handle) {
+         MethodType type = handle.type();
+         if (type.returnType().equals(MemoryAddress.class)) {
+             type = type.changeReturnType(Addressable.class);
+         }
+         for (int i = 0 ; i < type.parameterCount() ; i++) {
+             if (type.parameterType(i).equals(Addressable.class)) {
+                 // MethodType is immutable: keep the returned instance, otherwise the
+                 // parameter change is silently lost.
+                 type = type.changeParameterType(i, MemoryAddress.class);
+             }
+         }
+         return handle.asType(type);
+     }
  }
< prev index next >