1 /*
  2  * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package jdk.test.lib.thread;
 25 
 26 import java.lang.foreign.Arena;
 27 import java.lang.foreign.FunctionDescriptor;
 28 import java.lang.foreign.Linker;
 29 import java.lang.foreign.MemorySegment;
 30 import java.lang.foreign.SymbolLookup;
 31 import java.lang.foreign.ValueLayout;
 32 
 33 import java.lang.invoke.*;
 34 import java.nio.file.Path;
 35 import java.time.Duration;
 36 import java.util.concurrent.atomic.AtomicReference;
 37 
 38 /**
 39  * Helper class to allow tests run an action in a virtual thread while pinning its carrier.
 40  *
 41  * It defines the {@code runPinned} method to run an action with a native frame on the stack.
 42  */
 43 public class VThreadPinner {
 44     private static final Path JAVA_LIBRARY_PATH = Path.of(System.getProperty("java.library.path"));
 45     private static final Path LIB_PATH = JAVA_LIBRARY_PATH.resolve(System.mapLibraryName("VThreadPinner"));
 46 
 47     // method handle to call the native function
 48     private static final MethodHandle INVOKER = invoker();
 49 
 50     // function pointer to call
 51     private static final MemorySegment UPCALL_STUB = upcallStub();
 52 
 53     /**
 54      * Thread local with the action to run.
 55      */
 56     private static final ThreadLocal<ActionRunner> ACTION_RUNNER = new ThreadLocal<>();
 57 
 58     /**
 59      * Runs an action, capturing any exception or error thrown.
 60      */
 61     private static class ActionRunner implements Runnable {
 62         private final ThrowingAction<?> action;
 63         private Throwable throwable;
 64 
 65         ActionRunner(ThrowingAction<?> action) {
 66             this.action = action;
 67         }
 68 
 69         @Override
 70         public void run() {
 71             try {
 72                 action.run();
 73             } catch (Throwable ex) {
 74                 throwable = ex;
 75             }
 76         }
 77 
 78         Throwable exception() {
 79             return throwable;
 80         }
 81     }
 82 
 83     /**
 84      * Called by the native function to run the action stashed in the thread local. The
 85      * action runs with the native frame on the stack.
 86      */
 87     private static void callback() {
 88         ACTION_RUNNER.get().run();
 89     }
 90 
 91     /**
 92      * A function to run from a virtual thread pinned to its carrier.
 93      */
 94     @FunctionalInterface
 95     public interface ThrowingAction<X extends Throwable> {
 96         void run() throws X;
 97     }
 98 
 99     /**
100      * Runs the given action on virtual thread pinned to its carrier.
101      */
102     public static <X extends Throwable> void runPinned(ThrowingAction<X> action) throws X {
103         if (!Thread.currentThread().isVirtual()) {
104             throw new IllegalCallerException("Not a virtual thread");
105         }
106         var runner = new ActionRunner(action);
107         ACTION_RUNNER.set(runner);
108         try {
109             INVOKER.invoke(UPCALL_STUB);
110         } catch (Throwable e) {
111             throw new RuntimeException(e);
112         } finally {
113             ACTION_RUNNER.remove();
114         }
115         Throwable ex = runner.exception();
116         if (ex != null) {
117             if (ex instanceof RuntimeException e)
118                 throw e;
119             if (ex instanceof Error e)
120                 throw e;
121             throw (X) ex;
122         }
123     }
124 
125     /**
126      * Returns a method handle to the native function void call(void *(*f)(void *)).
127      */
128     private static MethodHandle invoker() {
129         Linker abi = Linker.nativeLinker();
130         try {
131             SymbolLookup lib = SymbolLookup.libraryLookup(LIB_PATH, Arena.global());
132             MemorySegment symbol = lib.find("call").orElseThrow();
133             FunctionDescriptor desc = FunctionDescriptor.ofVoid(ValueLayout.ADDRESS);
134             return abi.downcallHandle(symbol, desc);
135         } catch (Throwable e) {
136             throw new RuntimeException(e);
137         }
138     }
139 
140     /**
141      * Returns an upcall stub to use as a function pointer to invoke the callback method.
142      */
143     private static MemorySegment upcallStub() {
144         Linker abi = Linker.nativeLinker();
145         try {
146             MethodHandle callback = MethodHandles.lookup()
147                     .findStatic(VThreadPinner.class, "callback", MethodType.methodType(void.class));
148             return abi.upcallStub(callback, FunctionDescriptor.ofVoid(), Arena.global());
149         } catch (Throwable e) {
150             throw new RuntimeException(e);
151         }
152     }
153 }