1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package java.util.concurrent;
 26 
 27 import java.lang.invoke.MethodHandles;
 28 import java.lang.invoke.VarHandle;
 29 import java.time.Duration;
 30 import java.util.Objects;
 31 import java.util.function.UnaryOperator;
 32 import jdk.internal.misc.ThreadFlock;
 33 import jdk.internal.invoke.MhUtil;
 34 
 35 /**
 36  * StructuredTaskScope implementation.
 37  */
 38 final class StructuredTaskScopeImpl<T, R> implements StructuredTaskScope<T, R> {
 39     private static final VarHandle CANCELLED =
 40             MhUtil.findVarHandle(MethodHandles.lookup(), "cancelled", boolean.class);
 41 
 42     private final Joiner<? super T, ? extends R> joiner;
 43     private final ThreadFactory threadFactory;
 44     private final ThreadFlock flock;
 45 
 46     // state, only accessed by owner thread
 47     private static final int ST_NEW            = 0,
 48                              ST_FORKED         = 1,   // subtasks forked, need to join
 49                              ST_JOIN_STARTED   = 2,   // join started, can no longer fork
 50                              ST_JOIN_COMPLETED = 3,   // join completed
 51                              ST_CLOSED         = 4;   // closed
 52     private int state;
 53 
 54     // timer task, only accessed by owner thread
 55     private Future<?> timerTask;
 56 
 57     // set or read by any thread
 58     private volatile boolean cancelled;
 59 
 60     // set by the timer thread, read by the owner thread
 61     private volatile boolean timeoutExpired;
 62 
 63     @SuppressWarnings("this-escape")
 64     private StructuredTaskScopeImpl(Joiner<? super T, ? extends R> joiner,
 65                                     ThreadFactory threadFactory,
 66                                     String name) {
 67         this.joiner = joiner;
 68         this.threadFactory = threadFactory;
 69         this.flock = ThreadFlock.open((name != null) ? name : Objects.toIdentityString(this));
 70         this.state = ST_NEW;
 71     }
 72 
 73     /**
 74      * Returns a new {@code StructuredTaskScope} to use the given {@code Joiner} object
 75      * and with configuration that is the result of applying the given function to the
 76      * default configuration.
 77      */
 78     static <T, R> StructuredTaskScope<T, R> open(Joiner<? super T, ? extends R> joiner,
 79                                                  UnaryOperator<Configuration> configOperator) {
 80         Objects.requireNonNull(joiner);
 81 
 82         var config = (ConfigImpl) configOperator.apply(ConfigImpl.defaultConfig());
 83         var scope = new StructuredTaskScopeImpl<T, R>(joiner, config.threadFactory(), config.name());
 84 
 85         // schedule timeout
 86         Duration timeout = config.timeout();
 87         if (timeout != null) {
 88             boolean scheduled = false;
 89             try {
 90                 scope.scheduleTimeout(timeout);
 91                 scheduled = true;
 92             } finally {
 93                 if (!scheduled) {
 94                     scope.close();  // pop if scheduling timeout failed
 95                 }
 96             }
 97         }
 98 
 99         return scope;
100     }
101 
102     /**
103      * Throws WrongThreadException if the current thread is not the owner thread.
104      */
105     private void ensureOwner() {
106         if (Thread.currentThread() != flock.owner()) {
107             throw new WrongThreadException("Current thread not owner");
108         }
109     }
110 
111     /**
112      * Throws IllegalStateException if invoked by the owner thread and the owner thread
113      * has not joined.
114      */
115     private void ensureJoinedIfOwner() {
116         if (Thread.currentThread() == flock.owner() && state < ST_JOIN_STARTED) {
117             throw new IllegalStateException("join not called");
118         }
119     }
120 
121     /**
122      * Interrupts all threads in this scope, except the current thread.
123      */
124     private void interruptAll() {
125         flock.threads()
126                 .filter(t -> t != Thread.currentThread())
127                 .forEach(t -> {
128                     try {
129                         t.interrupt();
130                     } catch (Throwable ignore) { }
131                 });
132     }
133 
134     /**
135      * Cancel the scope if not already cancelled.
136      */
137     private void cancel() {
138         if (!cancelled && CANCELLED.compareAndSet(this, false, true)) {
139             // prevent new threads from starting
140             flock.shutdown();
141 
142             // interrupt all unfinished threads
143             interruptAll();
144 
145             // wakeup join
146             flock.wakeup();
147         }
148     }
149 
150     /**
151      * Schedules a task to cancel the scope on timeout.
152      */
153     private void scheduleTimeout(Duration timeout) {
154         assert Thread.currentThread() == flock.owner() && timerTask == null;
155         long nanos = TimeUnit.NANOSECONDS.convert(timeout);
156         timerTask = ForkJoinPool.commonPool().schedule(() -> {
157             if (!cancelled) {
158                 timeoutExpired = true;
159                 cancel();
160             }
161         }, nanos, TimeUnit.NANOSECONDS);
162     }
163 
164     /**
165      * Cancels the timer task if set.
166      */
167     private void cancelTimeout() {
168         assert Thread.currentThread() == flock.owner();
169         if (timerTask != null) {
170             timerTask.cancel(false);
171         }
172     }
173 
174     /**
175      * Invoked by the thread for a subtask when the subtask completes before scope is cancelled.
176      */
177     private void onComplete(SubtaskImpl<? extends T> subtask) {
178         assert subtask.state() != Subtask.State.UNAVAILABLE;
179         if (joiner.onComplete(subtask)) {
180             cancel();
181         }
182     }
183 
184     @Override
185     public <U extends T> Subtask<U> fork(Callable<? extends U> task) {
186         Objects.requireNonNull(task);
187         ensureOwner();
188         if (state > ST_FORKED) {
189             throw new IllegalStateException("join already called or scope is closed");
190         }
191 
192         var subtask = new SubtaskImpl<U>(this, task);
193 
194         // notify joiner, even if cancelled
195         if (joiner.onFork(subtask)) {
196             cancel();
197         }
198 
199         if (!cancelled) {
200             // create thread to run task
201             Thread thread = threadFactory.newThread(subtask);
202             if (thread == null) {
203                 throw new RejectedExecutionException("Rejected by thread factory");
204             }
205 
206             // attempt to start the thread
207             try {
208                 flock.start(thread);
209             } catch (IllegalStateException e) {
210                 // shutdown by another thread, or underlying flock is shutdown due
211                 // to unstructured use
212             }
213         }
214 
215         // force owner to join
216         state = ST_FORKED;
217         return subtask;
218     }
219 
220     @Override
221     public <U extends T> Subtask<U> fork(Runnable task) {
222         Objects.requireNonNull(task);
223         return fork(() -> { task.run(); return null; });
224     }
225 
226     @Override
227     public R join() throws InterruptedException {
228         ensureOwner();
229         if (state >= ST_JOIN_COMPLETED) {
230             throw new IllegalStateException("Already joined or scope is closed");
231         }
232 
233         // join started
234         state = ST_JOIN_STARTED;
235 
236         // wait for all subtasks, the scope to be cancelled, or interrupt
237         flock.awaitAll();
238 
239         // all subtasks completed or scope cancelled
240         state = ST_JOIN_COMPLETED;
241 
242         // invoke joiner onTimeout if timeout expired
243         if (timeoutExpired) {
244             cancel();  // ensure cancelled before calling onTimeout
245             joiner.onTimeout();
246         } else {
247             cancelTimeout();
248         }
249 
250         // invoke joiner to get result
251         try {
252             return joiner.result();
253         } catch (Throwable e) {
254             throw new FailedException(e);
255         }
256     }
257 
258     @Override
259     public boolean isCancelled() {
260         return cancelled;
261     }
262 
263     @Override
264     public void close() {
265         ensureOwner();
266         int s = state;
267         if (s == ST_CLOSED) {
268             return;
269         }
270 
271         // cancel the scope if join did not complete
272         if (s < ST_JOIN_COMPLETED) {
273             cancel();
274             cancelTimeout();
275         }
276 
277         // wait for stragglers
278         try {
279             flock.close();
280         } finally {
281             state = ST_CLOSED;
282         }
283 
284         // throw ISE if the owner didn't join after forking
285         if (s == ST_FORKED) {
286             throw new IllegalStateException("Owner did not join after forking");
287         }
288     }
289 
290     @Override
291     public String toString() {
292         return flock.name();
293     }
294 
295     /**
296      * Subtask implementation, runs the task specified to the fork method.
297      */
298     static final class SubtaskImpl<T> implements Subtask<T>, Runnable {
299         private static final AltResult RESULT_NULL = new AltResult(Subtask.State.SUCCESS);
300 
301         private record AltResult(Subtask.State state, Throwable exception) {
302             AltResult(Subtask.State state) {
303                 this(state, null);
304             }
305         }
306 
307         private final StructuredTaskScopeImpl<? super T, ?> scope;
308         private final Callable<? extends T> task;
309         private volatile Object result;
310 
311         SubtaskImpl(StructuredTaskScopeImpl<? super T, ?> scope, Callable<? extends T> task) {
312             this.scope = scope;
313             this.task = task;
314         }
315 
316         @Override
317         public void run() {
318             T result = null;
319             Throwable ex = null;
320             try {
321                 result = task.call();
322             } catch (Throwable e) {
323                 ex = e;
324             }
325 
326             // nothing to do if scope is cancelled
327             if (scope.isCancelled())
328                 return;
329 
330             // set result/exception and invoke onComplete
331             if (ex == null) {
332                 this.result = (result != null) ? result : RESULT_NULL;
333             } else {
334                 this.result = new AltResult(State.FAILED, ex);
335             }
336             scope.onComplete(this);
337         }
338 
339         @Override
340         public Subtask.State state() {
341             Object result = this.result;
342             if (result == null) {
343                 return State.UNAVAILABLE;
344             } else if (result instanceof AltResult alt) {
345                 // null or failed
346                 return alt.state();
347             } else {
348                 return State.SUCCESS;
349             }
350         }
351 
352         @Override
353         public T get() {
354             scope.ensureJoinedIfOwner();
355             Object result = this.result;
356             if (result instanceof AltResult) {
357                 if (result == RESULT_NULL) return null;
358             } else if (result != null) {
359                 @SuppressWarnings("unchecked")
360                 T r = (T) result;
361                 return r;
362             }
363             throw new IllegalStateException(
364                     "Result is unavailable or subtask did not complete successfully");
365         }
366 
367         @Override
368         public Throwable exception() {
369             scope.ensureJoinedIfOwner();
370             Object result = this.result;
371             if (result instanceof AltResult alt && alt.state() == State.FAILED) {
372                 return alt.exception();
373             }
374             throw new IllegalStateException(
375                     "Exception is unavailable or subtask did not complete with exception");
376         }
377 
378         @Override
379         public String toString() {
380             String stateAsString = switch (state()) {
381                 case UNAVAILABLE -> "[Unavailable]";
382                 case SUCCESS     -> "[Completed successfully]";
383                 case FAILED      -> "[Failed: " + ((AltResult) result).exception() + "]";
384             };
385             return Objects.toIdentityString(this) + stateAsString;
386         }
387     }
388 
389     /**
390      * Configuration implementation.
391      */
392     record ConfigImpl(ThreadFactory threadFactory,
393                       String name,
394                       Duration timeout) implements Configuration {
395         static Configuration defaultConfig() {
396             return new ConfigImpl(Thread.ofVirtual().factory(), null, null);
397         }
398 
399         @Override
400         public Configuration withThreadFactory(ThreadFactory threadFactory) {
401             return new ConfigImpl(Objects.requireNonNull(threadFactory), name, timeout);
402         }
403 
404         @Override
405         public Configuration withName(String name) {
406             return new ConfigImpl(threadFactory, Objects.requireNonNull(name), timeout);
407         }
408 
409         @Override
410         public Configuration withTimeout(Duration timeout) {
411             return new ConfigImpl(threadFactory, name, Objects.requireNonNull(timeout));
412         }
413     }
414 }