1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package java.util.concurrent;
 26 
 27 import java.lang.invoke.MethodHandles;
 28 import java.lang.invoke.VarHandle;
 29 import java.time.Duration;
 30 import java.util.Objects;
 31 import java.util.function.UnaryOperator;
 32 import jdk.internal.misc.ThreadFlock;
 33 import jdk.internal.invoke.MhUtil;
 34 import jdk.internal.vm.annotation.Stable;
 35 
 36 /**
 37  * StructuredTaskScope implementation.
 38  */
 39 final class StructuredTaskScopeImpl<T, R> implements StructuredTaskScope<T, R> {
 40     private static final VarHandle CANCELLED =
 41             MhUtil.findVarHandle(MethodHandles.lookup(), "cancelled", boolean.class);
 42 
 43     private final Joiner<? super T, ? extends R> joiner;
 44     private final ThreadFactory threadFactory;
 45     private final ThreadFlock flock;
 46 
 47     // scope state, set by owner thread, read by any thread
 48     private static final int ST_FORKED         = 1,   // subtasks forked, need to join
 49                              ST_JOIN_STARTED   = 2,   // join started, can no longer fork
 50                              ST_JOIN_COMPLETED = 3,   // join completed
 51                              ST_CLOSED         = 4;   // closed
 52     private volatile int state;
 53 
 54     // set or read by any thread
 55     private volatile boolean cancelled;
 56 
 57     // timer task, only accessed by owner thread
 58     private Future<?> timerTask;
 59 
 60     // set by the timer thread, read by the owner thread
 61     private volatile boolean timeoutExpired;
 62 
 63     @SuppressWarnings("this-escape")
 64     private StructuredTaskScopeImpl(Joiner<? super T, ? extends R> joiner,
 65                                     ThreadFactory threadFactory,
 66                                     String name) {
 67         this.joiner = joiner;
 68         this.threadFactory = threadFactory;
 69         this.flock = ThreadFlock.open((name != null) ? name : Objects.toIdentityString(this));
 70     }
 71 
 72     /**
 73      * Returns a new {@code StructuredTaskScope} to use the given {@code Joiner} object
 74      * and with configuration that is the result of applying the given function to the
 75      * default configuration.
 76      */
 77     static <T, R> StructuredTaskScope<T, R> open(Joiner<? super T, ? extends R> joiner,
 78                                                  UnaryOperator<Configuration> configOperator) {
 79         Objects.requireNonNull(joiner);
 80 
 81         var config = (ConfigImpl) configOperator.apply(ConfigImpl.defaultConfig());
 82         var scope = new StructuredTaskScopeImpl<T, R>(joiner, config.threadFactory(), config.name());
 83 
 84         // schedule timeout
 85         Duration timeout = config.timeout();
 86         if (timeout != null) {
 87             boolean scheduled = false;
 88             try {
 89                 scope.scheduleTimeout(timeout);
 90                 scheduled = true;
 91             } finally {
 92                 if (!scheduled) {
 93                     scope.close();  // pop if scheduling timeout failed
 94                 }
 95             }
 96         }
 97 
 98         return scope;
 99     }
100 
101     /**
102      * Throws WrongThreadException if the current thread is not the owner thread.
103      */
104     private void ensureOwner() {
105         if (Thread.currentThread() != flock.owner()) {
106             throw new WrongThreadException("Current thread not owner");
107         }
108     }
109 
110     /**
111      * Returns true if join has been invoked and there is an outcome.
112      */
113     private boolean isJoinCompleted() {
114         return state >= ST_JOIN_COMPLETED;
115     }
116 
117     /**
118      * Interrupts all threads in this scope, except the current thread.
119      */
120     private void interruptAll() {
121         flock.threads()
122                 .filter(t -> t != Thread.currentThread())
123                 .forEach(t -> {
124                     try {
125                         t.interrupt();
126                     } catch (Throwable ignore) { }
127                 });
128     }
129 
130     /**
131      * Cancel the scope if not already cancelled.
132      */
133     private void cancel() {
134         if (!cancelled && CANCELLED.compareAndSet(this, false, true)) {
135             // prevent new threads from starting
136             flock.shutdown();
137 
138             // interrupt all unfinished threads
139             interruptAll();
140 
141             // wakeup join
142             flock.wakeup();
143         }
144     }
145 
146     /**
147      * Schedules a task to cancel the scope on timeout.
148      */
149     private void scheduleTimeout(Duration timeout) {
150         assert Thread.currentThread() == flock.owner() && timerTask == null;
151         long nanos = TimeUnit.NANOSECONDS.convert(timeout);
152         timerTask = ForkJoinPool.commonPool().schedule(() -> {
153             if (!cancelled) {
154                 timeoutExpired = true;
155                 cancel();
156             }
157         }, nanos, TimeUnit.NANOSECONDS);
158     }
159 
160     /**
161      * Cancels the timer task if set.
162      */
163     private void cancelTimeout() {
164         assert Thread.currentThread() == flock.owner();
165         if (timerTask != null) {
166             timerTask.cancel(false);
167         }
168     }
169 
170     /**
171      * Invoked by the thread for a subtask when the subtask completes before scope is cancelled.
172      */
173     private void onComplete(SubtaskImpl<? extends T> subtask) {
174         assert subtask.state() != Subtask.State.UNAVAILABLE;
175         if (joiner.onComplete(subtask)) {
176             cancel();
177         }
178     }
179 
180     @Override
181     public <U extends T> Subtask<U> fork(Callable<? extends U> task) {
182         Objects.requireNonNull(task);
183         ensureOwner();
184         int s = state;
185         if (s > ST_FORKED) {
186             throw new IllegalStateException("join already called or scope is closed");
187         }
188 
189         var subtask = new SubtaskImpl<U>(this, task);
190 
191         // notify joiner, even if cancelled
192         if (joiner.onFork(subtask)) {
193             cancel();
194         }
195 
196         if (!cancelled) {
197             // create thread to run task
198             Thread thread = threadFactory.newThread(subtask);
199             if (thread == null) {
200                 throw new RejectedExecutionException("Rejected by thread factory");
201             }
202 
203             // attempt to start the thread
204             subtask.setThread(thread);
205             try {
206                 flock.start(thread);
207             } catch (IllegalStateException e) {
208                 // shutdown by another thread, or underlying flock is shutdown due
209                 // to unstructured use
210             }
211         }
212 
213         // force owner to join
214         if (s < ST_FORKED) {
215             state = ST_FORKED;
216         }
217         return subtask;
218     }
219 
220     @Override
221     public <U extends T> Subtask<U> fork(Runnable task) {
222         Objects.requireNonNull(task);
223         return fork(() -> { task.run(); return null; });
224     }
225 
226     @Override
227     public R join() throws InterruptedException {
228         ensureOwner();
229         if (state >= ST_JOIN_COMPLETED) {
230             throw new IllegalStateException("Already joined or scope is closed");
231         }
232 
233         // wait for all subtasks, the scope to be cancelled, or interrupt
234         try {
235             flock.awaitAll();
236         } catch (InterruptedException e) {
237             state = ST_JOIN_STARTED;  // joining not completed, prevent new forks
238             throw e;
239         }
240 
241         // all subtasks completed or scope cancelled
242         state = ST_JOIN_COMPLETED;
243 
244         // invoke joiner onTimeout if timeout expired
245         if (timeoutExpired) {
246             cancel();  // ensure cancelled before calling onTimeout
247             joiner.onTimeout();
248         } else {
249             cancelTimeout();
250         }
251 
252         // invoke joiner to get result
253         try {
254             return joiner.result();
255         } catch (Throwable e) {
256             throw new FailedException(e);
257         }
258     }
259 
260     @Override
261     public boolean isCancelled() {
262         return cancelled;
263     }
264 
265     @Override
266     public void close() {
267         ensureOwner();
268         int s = state;
269         if (s == ST_CLOSED) {
270             return;
271         }
272 
273         // cancel the scope if join did not complete
274         if (s < ST_JOIN_COMPLETED) {
275             cancel();
276             cancelTimeout();
277         }
278 
279         // wait for stragglers
280         try {
281             flock.close();
282         } finally {
283             state = ST_CLOSED;
284         }
285 
286         // throw ISE if the owner didn't join after forking
287         if (s == ST_FORKED) {
288             throw new IllegalStateException("Owner did not join after forking");
289         }
290     }
291 
292     @Override
293     public String toString() {
294         return flock.name();
295     }
296 
297     /**
298      * Subtask implementation, runs the task specified to the fork method.
299      */
300     static final class SubtaskImpl<T> implements Subtask<T>, Runnable {
301         private static final AltResult RESULT_NULL = new AltResult(Subtask.State.SUCCESS);
302 
303         private record AltResult(Subtask.State state, Throwable exception) {
304             AltResult(Subtask.State state) {
305                 this(state, null);
306             }
307         }
308 
309         private final StructuredTaskScopeImpl<? super T, ?> scope;
310         private final Callable<? extends T> task;
311         private volatile Object result;
312         @Stable private Thread thread;
313 
314         SubtaskImpl(StructuredTaskScopeImpl<? super T, ?> scope, Callable<? extends T> task) {
315             this.scope = scope;
316             this.task = task;
317         }
318 
319         /**
320          * Sets the thread for this subtask.
321          */
322         void setThread(Thread thread) {
323             assert thread.getState() == Thread.State.NEW;
324             this.thread = thread;
325         }
326 
327         /**
328          * Throws IllegalStateException if the caller thread is not the subtask and
329          * the scope owner has not joined.
330          */
331         private void ensureJoinedIfNotSubtask() {
332             if (Thread.currentThread() != thread && !scope.isJoinCompleted()) {
333                 throw new IllegalStateException();
334             }
335         }
336 
337         @Override
338         public void run() {
339             if (Thread.currentThread() != thread) {
340                 throw new WrongThreadException();
341             }
342 
343             T result = null;
344             Throwable ex = null;
345             try {
346                 result = task.call();
347             } catch (Throwable e) {
348                 ex = e;
349             }
350 
351             // nothing to do if scope is cancelled
352             if (scope.isCancelled())
353                 return;
354 
355             // set result/exception and invoke onComplete
356             if (ex == null) {
357                 this.result = (result != null) ? result : RESULT_NULL;
358             } else {
359                 this.result = new AltResult(State.FAILED, ex);
360             }
361             scope.onComplete(this);
362         }
363 
364         @Override
365         public Subtask.State state() {
366             Object result = this.result;
367             if (result == null) {
368                 return State.UNAVAILABLE;
369             } else if (result instanceof AltResult alt) {
370                 // null or failed
371                 return alt.state();
372             } else {
373                 return State.SUCCESS;
374             }
375         }
376 
377         @Override
378         public T get() {
379             ensureJoinedIfNotSubtask();
380             Object result = this.result;
381             if (result instanceof AltResult) {
382                 if (result == RESULT_NULL) return null;
383             } else if (result != null) {
384                 @SuppressWarnings("unchecked")
385                 T r = (T) result;
386                 return r;
387             }
388             throw new IllegalStateException(
389                     "Result is unavailable or subtask did not complete successfully");
390         }
391 
392         @Override
393         public Throwable exception() {
394             ensureJoinedIfNotSubtask();
395             Object result = this.result;
396             if (result instanceof AltResult alt && alt.state() == State.FAILED) {
397                 return alt.exception();
398             }
399             throw new IllegalStateException(
400                     "Exception is unavailable or subtask did not complete with exception");
401         }
402 
403         @Override
404         public String toString() {
405             String stateAsString = switch (state()) {
406                 case UNAVAILABLE -> "[Unavailable]";
407                 case SUCCESS     -> "[Completed successfully]";
408                 case FAILED      -> "[Failed: " + ((AltResult) result).exception() + "]";
409             };
410             return Objects.toIdentityString(this) + stateAsString;
411         }
412     }
413 
414     /**
415      * Configuration implementation.
416      */
417     record ConfigImpl(ThreadFactory threadFactory,
418                       String name,
419                       Duration timeout) implements Configuration {
420         static Configuration defaultConfig() {
421             return new ConfigImpl(Thread.ofVirtual().factory(), null, null);
422         }
423 
424         @Override
425         public Configuration withThreadFactory(ThreadFactory threadFactory) {
426             return new ConfigImpl(Objects.requireNonNull(threadFactory), name, timeout);
427         }
428 
429         @Override
430         public Configuration withName(String name) {
431             return new ConfigImpl(threadFactory, Objects.requireNonNull(name), timeout);
432         }
433 
434         @Override
435         public Configuration withTimeout(Duration timeout) {
436             return new ConfigImpl(threadFactory, name, Objects.requireNonNull(timeout));
437         }
438     }
439 }