1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test id=default
 26  * @summary Test using a custom scheduler as the default virtual thread scheduler
 27  * @requires vm.continuations
 28  * @library /test/lib
 29  * @run junit/othervm -Djdk.virtualThreadScheduler.implClass=CustomDefaultScheduler$CustomScheduler1
 30  *     --enable-native-access=ALL-UNNAMED CustomDefaultScheduler
 31  * @run junit/othervm -Djdk.virtualThreadScheduler.implClass=CustomDefaultScheduler$CustomScheduler2
 32  *     --enable-native-access=ALL-UNNAMED CustomDefaultScheduler
 33  */
 34 
 35 /*
 36  * @test id=poller-modes
 37  * @requires vm.continuations
 38  * @requires (os.family == "linux") | (os.family == "mac")
 39  * @library /test/lib
 40  * @run junit/othervm -Djdk.pollerMode=3
 41  *     -Djdk.virtualThreadScheduler.implClass=CustomDefaultScheduler$CustomScheduler1
 42  *     --enable-native-access=ALL-UNNAMED CustomDefaultScheduler
 43  * @run junit/othervm -Djdk.pollerMode=3
 44  *     -Djdk.virtualThreadScheduler.implClass=CustomDefaultScheduler$CustomScheduler2
 45  *     --enable-native-access=ALL-UNNAMED CustomDefaultScheduler
 46  */
 47 
 48 import java.io.Closeable;
 49 import java.io.IOException;
 50 import java.lang.Thread.VirtualThreadScheduler;
 51 import java.lang.Thread.VirtualThreadTask;
 52 import java.net.InetAddress;
 53 import java.net.InetSocketAddress;
 54 import java.net.ServerSocket;
 55 import java.net.Socket;
 56 import java.util.Set;
 57 import java.util.concurrent.ConcurrentHashMap;
 58 import java.util.concurrent.Executors;
 59 import java.util.concurrent.ExecutorService;
 60 import java.util.concurrent.ThreadFactory;
 61 import java.util.concurrent.CountDownLatch;
 62 import java.util.concurrent.atomic.AtomicBoolean;
 63 import java.util.concurrent.atomic.AtomicReference;
 64 import java.util.concurrent.locks.LockSupport;
 65 
 66 import jdk.test.lib.thread.VThreadRunner;
 67 
 68 import org.junit.jupiter.api.Test;
 69 import org.junit.jupiter.api.BeforeAll;
 70 import static org.junit.jupiter.api.Assertions.*;
 71 import static org.junit.jupiter.api.Assumptions.*;
 72 
 73 class CustomDefaultScheduler {
 74     private static String schedulerClassName;
 75 
 76     @BeforeAll
 77     static void setup() {
 78         schedulerClassName = System.getProperty("jdk.virtualThreadScheduler.implClass");
 79     }
 80 
 81     /**
 82      * Custom scheduler that uses a thread pool.
 83      */
 84     public static class CustomScheduler1 implements VirtualThreadScheduler {
 85         private final ExecutorService pool;
 86 
 87         public CustomScheduler1() {
 88             ThreadFactory factory = Thread.ofPlatform().daemon().factory();
 89             pool = Executors.newFixedThreadPool(1, factory);
 90         }
 91 
 92         @Override
 93         public void onStart(VirtualThreadTask task) {
 94             pool.execute(task);
 95         }
 96 
 97         @Override
 98         public void onContinue(VirtualThreadTask task) {
 99             pool.execute(task);
100         }
101     }
102 
103     /**
104      * Custom scheduler that delegates to the built-in default scheduler.
105      */
106     public static class CustomScheduler2 implements VirtualThreadScheduler {
107         private final VirtualThreadScheduler builtinScheduler;
108 
109         // the set of threads that executed with this scheduler
110         private final Set<Thread> executed = ConcurrentHashMap.newKeySet();
111 
112         public CustomScheduler2(VirtualThreadScheduler builtinScheduler) {
113             this.builtinScheduler = builtinScheduler;
114         }
115 
116         VirtualThreadScheduler builtinScheduler() {
117             return builtinScheduler;
118         }
119 
120         @Override
121         public void onStart(VirtualThreadTask task) {
122             executed.add(task.thread());
123             builtinScheduler.onStart(task);
124         }
125 
126         @Override
127         public void onContinue(VirtualThreadTask task) {
128             executed.add(task.thread());
129             builtinScheduler.onContinue(task);
130         }
131 
132         Set<Thread> threadsExecuted() {
133             return executed;
134         }
135     }
136 
137     /**
138      * Test that a virtual thread uses the custom default scheduler.
139      */
140     @Test
141     void testUseCustomScheduler() throws Exception {
142         var ref = new AtomicReference<VirtualThreadScheduler>();
143         Thread.startVirtualThread(() -> {
144             ref.set(VirtualThreadScheduler.current());
145         }).join();
146         VirtualThreadScheduler scheduler = ref.get();
147         assertEquals(schedulerClassName, scheduler.getClass().getName());
148     }
149 
150     /**
151      * Test virtual thread park/unpark when using custom default scheduler.
152      */
153     @Test
154     void testPark() throws Exception {
155         var done = new AtomicBoolean();
156         var thread = Thread.startVirtualThread(() -> {
157             while (!done.get()) {
158                 LockSupport.park();
159             }
160         });
161         try {
162             await(thread, Thread.State.WAITING);
163         } finally {
164             done.set(true);
165             LockSupport.unpark(thread);
166             thread.join();
167         }
168     }
169 
170     /**
171      * Test virtual thread blocking on a monitor when using custom default scheduler.
172      */
173     @Test
174     void testBlockMonitor() throws Exception {
175         var ready = new CountDownLatch(1);
176         var lock = new Object();
177         var thread = Thread.ofVirtual().unstarted(() -> {
178             ready.countDown();
179             synchronized (lock) {
180             }
181         });
182         synchronized (lock) {
183             thread.start();
184             ready.await();
185             await(thread, Thread.State.BLOCKED);
186         }
187         thread.join();
188     }
189 
190     /**
191      * Test virtual thread blocking on a socket I/O when using custom default scheduler.
192      */
193     @Test
194     void testBlockSocket() throws Exception {
195         VThreadRunner.run(() -> {
196             try (var connection = new Connection()) {
197                 Socket s1 = connection.socket1();
198                 Socket s2 = connection.socket2();
199 
200                 // write bytes after current virtual thread has parked
201                 byte[] ba1 = "XXX".getBytes("UTF-8");
202                 runAfterParkedAsync(() -> s1.getOutputStream().write(ba1));
203 
204                 byte[] ba2 = new byte[10];
205                 int n = s2.getInputStream().read(ba2);
206                 assertTrue(n > 0);
207                 assertTrue(ba2[0] == 'X');
208             }
209         });
210     }
211 
212     /**
213      * Test one virtual thread starting a second virtual thread when both are scheduled
214      * by a custom default scheduler delegating to builtin default scheduler.
215      */
216     @Test
217     void testDelegatingToBuiltin() throws Exception {
218         assumeTrue(schedulerClassName.equals("CustomDefaultScheduler$CustomScheduler2"));
219 
220         var schedulerRef = new AtomicReference<VirtualThreadScheduler>();
221         var vthreadRef = new AtomicReference<Thread>();
222 
223         var vthread1 = Thread.ofVirtual().start(() -> {
224             schedulerRef.set(VirtualThreadScheduler.current());
225             Thread vthread2 = Thread.ofVirtual().start(() -> {
226                 assertTrue(VirtualThreadScheduler.current() == schedulerRef.get());
227                 vthreadRef.set(Thread.currentThread());
228             });
229             try {
230                 vthread2.join();
231             } catch (InterruptedException e) {
232                 // fail();
233             }
234         });
235 
236         vthread1.join();
237         Thread vthread2 = vthreadRef.get();
238 
239         var customScheduler = (CustomScheduler2) schedulerRef.get();
240         assertTrue(customScheduler.threadsExecuted().contains(vthread1));
241         assertTrue(customScheduler.threadsExecuted().contains(vthread2));
242     }
243 
244     /**
245      * Waits for the given thread to reach a given state.
246      */
247     private void await(Thread thread, Thread.State expectedState) throws InterruptedException {
248         Thread.State state = thread.getState();
249         while (state != expectedState) {
250             assertTrue(state != Thread.State.TERMINATED, "Thread has terminated");
251             Thread.sleep(10);
252             state = thread.getState();
253         }
254     }
255 
256     @FunctionalInterface
257     private interface ThrowingRunnable {
258         void run() throws Exception;
259     }
260 
261     /**
262      * Runs the given task asynchronously after the current virtual thread has parked.
263      * @return the thread started to run the task
264      */
265     private static Thread runAfterParkedAsync(ThrowingRunnable task) {
266         Thread target = Thread.currentThread();
267         if (!target.isVirtual())
268             throw new WrongThreadException();
269         return Thread.ofPlatform().daemon().start(() -> {
270             try {
271                 Thread.State state = target.getState();
272                 while (state != Thread.State.WAITING
273                         && state != Thread.State.TIMED_WAITING) {
274                     Thread.sleep(20);
275                     state = target.getState();
276                 }
277                 Thread.sleep(20);  // give a bit more time to release carrier
278                 task.run();
279             } catch (Exception e) {
280                 e.printStackTrace();
281             }
282         });
283     }
284 
285     /**
286      * Creates a loopback connection
287      */
288     private static class Connection implements Closeable {
289         private final Socket s1;
290         private final Socket s2;
291         Connection() throws IOException {
292             var lh = InetAddress.getLoopbackAddress();
293             try (var listener = new ServerSocket()) {
294                 listener.bind(new InetSocketAddress(lh, 0));
295                 Socket s1 = new Socket();
296                 Socket s2;
297                 try {
298                     s1.connect(listener.getLocalSocketAddress());
299                     s2 = listener.accept();
300                 } catch (IOException ioe) {
301                     s1.close();
302                     throw ioe;
303                 }
304                 this.s1 = s1;
305                 this.s2 = s2;
306             }
307 
308         }
309         Socket socket1() {
310             return s1;
311         }
312         Socket socket2() {
313             return s2;
314         }
315         @Override
316         public void close() throws IOException {
317             s1.close();
318             s2.close();
319         }
320     }
321 }