1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package java.util.concurrent;
 26 
 27 import java.lang.invoke.MethodHandles;
 28 import java.lang.invoke.VarHandle;
 29 import java.time.Duration;
 30 import java.util.Objects;
 31 import java.util.function.Function;
 32 import jdk.internal.misc.ThreadFlock;
 33 import jdk.internal.invoke.MhUtil;
 34 
/**
 * StructuredTaskScope implementation.
 *
 * <p>The scope is backed by a {@link ThreadFlock}; the flock's owner thread is the
 * owner of the scope. {@code fork}, {@code join} and {@code close} must be invoked by
 * the owner (see {@code ensureOwner}). Each forked subtask runs in a thread created by
 * the configured thread factory and started in the flock. The {@code Joiner} is
 * notified when a subtask is forked and when a subtask completes (before the scope is
 * cancelled), and either notification may cause the scope to be cancelled.
 *
 * @param <T> the result type of the subtasks forked in this scope
 * @param <R> the result type returned by {@link #join()}
 */
final class StructuredTaskScopeImpl<T, R> implements StructuredTaskScope<T, R> {
    // VarHandle on the "cancelled" field; cancel() uses compareAndSet on it so that
    // exactly one caller performs the cancellation steps
    private static final VarHandle CANCELLED =
            MhUtil.findVarHandle(MethodHandles.lookup(), "cancelled", boolean.class);

    private final Joiner<? super T, ? extends R> joiner;
    private final ThreadFactory threadFactory;
    private final ThreadFlock flock;

    // state, only accessed by owner thread
    // The values are ordered: ensureNotJoined, ensureJoinedIfOwner and close compare
    // them with <, <= and >, so the numeric order must follow the lifecycle order.
    // join() may be invoked from ST_NEW or ST_FORKED. A join that throws
    // TimeoutException or InterruptedException leaves the state at ST_JOIN_STARTED.
    private static final int ST_NEW            = 0;   // nothing forked yet
    private static final int ST_FORKED         = 1;   // subtasks forked, need to join
    private static final int ST_JOIN_STARTED   = 2;   // join started, can no longer fork
    private static final int ST_JOIN_COMPLETED = 3;   // join completed
    private static final int ST_CLOSED         = 4;   // closed
    private int state;

    // timer task, only accessed by owner thread
    // (null when no timeout is configured)
    private Future<?> timerTask;

    // set or read by any thread
    private volatile boolean cancelled;

    // set by the timer thread, read by the owner thread
    private volatile boolean timeoutExpired;

    /**
     * Creates a scope and opens its underlying thread flock.
     *
     * @param joiner the joiner notified of forks and completions and asked for the result
     * @param threadFactory the factory used to create a thread for each forked subtask
     * @param name the flock name, or {@code null} to name the flock with this scope's
     *        identity string
     */
    @SuppressWarnings("this-escape")
    private StructuredTaskScopeImpl(Joiner<? super T, ? extends R> joiner,
                                    ThreadFactory threadFactory,
                                    String name) {
        this.joiner = joiner;
        this.threadFactory = threadFactory;
        // Objects.toIdentityString(this) lets "this" escape before construction
        // completes, which is why "this-escape" is suppressed above
        this.flock = ThreadFlock.open((name != null) ? name : Objects.toIdentityString(this));
    }

    /**
     * Returns a new {@code StructuredTaskScope} to use the given {@code Joiner} object
     * and with configuration that is the result of applying the given function to the
     * default configuration.
     *
     * <p>The function's result is cast to {@code ConfigImpl}. The function is therefore
     * expected to return a configuration obtained from the one it is given (for example,
     * through its {@code with*} methods). Any other {@code Configuration} implementation
     * results in a {@code ClassCastException}. If the configuration has a timeout, a timer
     * task is scheduled. If scheduling fails, the new scope is closed before the exception
     * propagates.
     *
     * @throws NullPointerException if {@code joiner} or {@code configFunction} is null
     */
    static <T, R> StructuredTaskScope<T, R> open(Joiner<? super T, ? extends R> joiner,
                                                 Function<Configuration, Configuration> configFunction) {
        Objects.requireNonNull(joiner);

        var config = (ConfigImpl) configFunction.apply(ConfigImpl.defaultConfig());
        var scope = new StructuredTaskScopeImpl<T, R>(joiner, config.threadFactory(), config.name());

        // schedule timeout
        Duration timeout = config.timeout();
        if (timeout != null) {
            boolean scheduled = false;
            try {
                scope.scheduleTimeout(timeout);
                scheduled = true;
            } finally {
                if (!scheduled) {
                    scope.close();  // pop if scheduling timeout failed
                }
            }
        }

        return scope;
    }

    /**
     * Throws WrongThreadException if the current thread is not the owner thread.
     */
    private void ensureOwner() {
        if (Thread.currentThread() != flock.owner()) {
            throw new WrongThreadException("Current thread not owner");
        }
    }

    /**
     * Throws IllegalStateException if already joined or scope is closed, i.e. when the
     * state is past ST_FORKED (join started, join completed, or closed). Must only be
     * invoked by the owner thread.
     */
    private void ensureNotJoined() {
        assert Thread.currentThread() == flock.owner();
        if (state > ST_FORKED) {
            throw new IllegalStateException("Already joined or scope is closed");
        }
    }

    /**
     * Throws IllegalStateException if invoked by the owner thread and the owner thread
     * has not joined. Non-owner threads are never rejected.
     *
     * <p>"Joined" means the state is past ST_JOIN_STARTED. Consequently, after a join
     * that threw TimeoutException or InterruptedException (state still ST_JOIN_STARTED),
     * the owner also gets this exception. After the scope is closed, it does not.
     */
    private void ensureJoinedIfOwner() {
        if (Thread.currentThread() == flock.owner() && state <= ST_JOIN_STARTED) {
            throw new IllegalStateException("join not called");
        }
    }

    /**
     * Interrupts all threads in this scope, except the current thread.
     * Best effort: any Throwable thrown by Thread.interrupt is ignored, so that one
     * failure does not stop the remaining threads from being interrupted.
     */
    private void interruptAll() {
        flock.threads()
                .filter(t -> t != Thread.currentThread())
                .forEach(t -> {
                    try {
                        t.interrupt();
                    } catch (Throwable ignore) { }
                });
    }

    /**
     * Cancel the scope if not already cancelled.
     *
     * <p>May be invoked by any thread: the owner, a subtask thread (via onComplete), or
     * the timer task. The plain volatile read is a fast path. The compareAndSet ensures
     * that only one caller performs the steps below.
     */
    private void cancel() {
        if (!cancelled && CANCELLED.compareAndSet(this, false, true)) {
            // prevent new threads from starting
            flock.shutdown();

            // interrupt all unfinished threads
            interruptAll();

            // wakeup join (the owner may be blocked in flock.awaitAll())
            flock.wakeup();
        }
    }

    /**
     * Schedules a task to cancel the scope on timeout.
     * Must be invoked by the owner, at most once (timerTask must still be null).
     */
    private void scheduleTimeout(Duration timeout) {
        assert Thread.currentThread() == flock.owner() && timerTask == null;
        long nanos = TimeUnit.NANOSECONDS.convert(timeout);
        timerTask = ForkJoinPool.commonPool().schedule(() -> {
            if (!cancelled) {
                // timeoutExpired is written BEFORE cancel() on purpose. Both fields are
                // volatile, so any thread that observes the cancellation made here also
                // observes timeoutExpired == true. In particular, join() cannot return
                // normally after subtasks saw the scope cancelled by this timer.
                // NOTE(review): if another thread cancels the scope between the check
                // above and the write below, timeoutExpired is still set and join()
                // may report TimeoutException. Swapping the two statements to avoid
                // this would break the ordering guarantee described above.
                timeoutExpired = true;
                cancel();
            }
        }, nanos, TimeUnit.NANOSECONDS);
    }

    /**
     * Cancels the timer task if set. Uses cancel(false), so a timer task that is
     * already running is not interrupted. timerTask is not cleared, and calling this
     * more than once is harmless.
     */
    private void cancelTimeout() {
        assert Thread.currentThread() == flock.owner();
        if (timerTask != null) {
            timerTask.cancel(false);
        }
    }

    /**
     * Invoked by the thread for a subtask when the subtask completes before scope is cancelled.
     * The subtask's result has already been published, so its state is SUCCESS or
     * FAILED. If the joiner returns true, the scope is cancelled.
     */
    private void onComplete(SubtaskImpl<? extends T> subtask) {
        assert subtask.state() != Subtask.State.UNAVAILABLE;
        if (joiner.onComplete(subtask)) {
            cancel();
        }
    }

    /**
     * Forks a subtask to run the given task.
     *
     * <p>The joiner is notified first, even if the scope is already cancelled, and may
     * cancel the scope. A thread is created and started only if the scope is not
     * cancelled at that point. If the flock refuses to start the thread, the failure is
     * ignored. The returned subtask then never runs, so its state stays UNAVAILABLE.
     *
     * <p>NOTE(review): if the thread factory returns null, RejectedExecutionException is
     * thrown after the joiner has already been notified of the subtask, and the state
     * is not advanced to ST_FORKED.
     *
     * @throws WrongThreadException if not invoked by the owner
     * @throws IllegalStateException if join has started or the scope is closed
     * @throws RejectedExecutionException if the thread factory returns null
     */
    @Override
    public <U extends T> Subtask<U> fork(Callable<? extends U> task) {
        Objects.requireNonNull(task);
        ensureOwner();
        ensureNotJoined();

        var subtask = new SubtaskImpl<U>(this, task);

        // notify joiner, even if cancelled
        if (joiner.onFork(subtask)) {
            cancel();
        }

        if (!cancelled) {
            // create thread to run task
            Thread thread = threadFactory.newThread(subtask);
            if (thread == null) {
                throw new RejectedExecutionException("Rejected by thread factory");
            }

            // attempt to start the thread
            try {
                flock.start(thread);
            } catch (IllegalStateException e) {
                // shutdown by another thread, or underlying flock is shutdown due
                // to unstructured use
            }
        }

        // force owner to join
        state = ST_FORKED;
        return subtask;
    }

    /**
     * Forks a subtask to run the given Runnable; the subtask's result is null.
     * Delegates to {@link #fork(Callable)}.
     */
    @Override
    public <U extends T> Subtask<U> fork(Runnable task) {
        Objects.requireNonNull(task);
        return fork(() -> { task.run(); return null; });
    }

    /**
     * Waits for all subtasks to finish or for the scope to be cancelled, then returns
     * the joiner's result.
     *
     * <p>If the wait is interrupted (presumably InterruptedException from
     * flock.awaitAll()) or the timeout expired, the state remains ST_JOIN_STARTED. In
     * both cases the timer is not cancelled here; close() cancels the scope and the timer.
     * Otherwise the timer is cancelled and the state becomes ST_JOIN_COMPLETED before
     * joiner.result() is invoked. Anything thrown by joiner.result() is wrapped in
     * FailedException.
     *
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public R join() throws InterruptedException {
        ensureOwner();
        ensureNotJoined();

        // join started
        state = ST_JOIN_STARTED;

        // wait for all subtasks, the scope to be cancelled, or interrupt
        flock.awaitAll();

        // throw if timeout expired
        if (timeoutExpired) {
            throw new TimeoutException();
        }
        cancelTimeout();

        // all subtasks completed or cancelled
        state = ST_JOIN_COMPLETED;

        // invoke joiner to get result
        try {
            return joiner.result();
        } catch (Throwable e) {
            throw new FailedException(e);
        }
    }

    /**
     * Returns true if the scope has been cancelled. May be invoked by any thread.
     */
    @Override
    public boolean isCancelled() {
        return cancelled;
    }

    /**
     * Closes the scope. Invoking close on an already-closed scope does nothing.
     *
     * <p>If join did not complete, the scope is cancelled and the timer task is
     * cancelled. The method then waits in flock.close() for remaining threads. The
     * state becomes ST_CLOSED even if flock.close() throws. If subtasks were forked but
     * join was never started, IllegalStateException is thrown after the scope has been
     * closed.
     *
     * @throws WrongThreadException if not invoked by the owner
     */
    @Override
    public void close() {
        ensureOwner();
        int s = state;
        if (s == ST_CLOSED) {
            return;
        }

        // cancel the scope if join did not complete
        if (s < ST_JOIN_COMPLETED) {
            cancel();
            cancelTimeout();
        }

        // wait for stragglers
        try {
            flock.close();
        } finally {
            state = ST_CLOSED;
        }

        // throw ISE if the owner didn't join after forking
        if (s == ST_FORKED) {
            throw new IllegalStateException("Owner did not join after forking");
        }
    }

    /**
     * Returns the scope's name, which is the name of the underlying flock.
     */
    @Override
    public String toString() {
        return flock.name();
    }

    /**
     * Subtask implementation, runs the task specified to the fork method.
     *
     * <p>The outcome is held in the volatile {@code result} field:
     * <ul>
     * <li>{@code null}: not completed, or completed after the scope was cancelled
     *     (state UNAVAILABLE)</li>
     * <li>{@code RESULT_NULL}: completed successfully with a null value</li>
     * <li>{@code AltResult(FAILED, ex)}: completed with exception {@code ex}</li>
     * <li>any other object: completed successfully with that value</li>
     * </ul>
     * {@code result} is only assigned in {@code run()}. This assumes {@code run()}
     * executes at most once, so the value does not change after it becomes non-null.
     */
    static final class SubtaskImpl<T> implements Subtask<T>, Runnable {
        // marker for a successful completion with a null value
        private static final AltResult RESULT_NULL = new AltResult(Subtask.State.SUCCESS);

        // non-value outcome: RESULT_NULL (SUCCESS) or a failure with its exception
        private record AltResult(Subtask.State state, Throwable exception) {
            AltResult(Subtask.State state) {
                this(state, null);
            }
        }

        private final StructuredTaskScopeImpl<? super T, ?> scope;
        private final Callable<? extends T> task;
        private volatile Object result;

        SubtaskImpl(StructuredTaskScopeImpl<? super T, ?> scope, Callable<? extends T> task) {
            this.scope = scope;
            this.task = task;
        }

        /**
         * Runs the task, capturing any Throwable. If the scope is not cancelled at that
         * point, the outcome is published to {@code result} and the scope's onComplete
         * is invoked. The outcome is published first, so the joiner sees a completed
         * subtask.
         */
        @Override
        public void run() {
            T result = null;
            Throwable ex = null;
            try {
                result = task.call();
            } catch (Throwable e) {
                ex = e;
            }

            // nothing to do if scope is cancelled
            if (scope.isCancelled())
                return;

            // set result/exception and invoke onComplete
            if (ex == null) {
                this.result = (result != null) ? result : RESULT_NULL;
            } else {
                this.result = new AltResult(State.FAILED, ex);
            }
            scope.onComplete(this);
        }

        /**
         * Returns the subtask's state, derived from a single read of {@code result}.
         */
        @Override
        public Subtask.State state() {
            Object result = this.result;
            if (result == null) {
                return State.UNAVAILABLE;
            } else if (result instanceof AltResult alt) {
                // RESULT_NULL (SUCCESS with a null value) or failed
                return alt.state();
            } else {
                return State.SUCCESS;
            }
        }

        /**
         * Returns the result of a successfully completed subtask, which may be null.
         *
         * @throws IllegalStateException if invoked by the owner before it has joined,
         *         or if the subtask did not complete successfully
         */
        @Override
        public T get() {
            scope.ensureJoinedIfOwner();
            Object result = this.result;
            if (result instanceof AltResult) {
                if (result == RESULT_NULL) return null;
                // otherwise FAILED: fall through to throw
            } else if (result != null) {
                @SuppressWarnings("unchecked")
                T r = (T) result;
                return r;
            }
            throw new IllegalStateException(
                    "Result is unavailable or subtask did not complete successfully");
        }

        /**
         * Returns the exception thrown by a subtask that failed.
         *
         * @throws IllegalStateException if invoked by the owner before it has joined,
         *         or if the subtask did not fail
         */
        @Override
        public Throwable exception() {
            scope.ensureJoinedIfOwner();
            Object result = this.result;
            if (result instanceof AltResult alt && alt.state() == State.FAILED) {
                return alt.exception();
            }
            throw new IllegalStateException(
                    "Exception is unavailable or subtask did not complete with exception");
        }

        /**
         * Returns the identity string of this subtask followed by its state. For a
         * failed subtask, the exception is included.
         */
        @Override
        public String toString() {
            String stateAsString = switch (state()) {
                case UNAVAILABLE -> "[Unavailable]";
                case SUCCESS     -> "[Completed successfully]";
                case FAILED      -> {
                    // re-reading result is safe: it is not reassigned once non-null
                    Throwable ex = ((AltResult) result).exception();
                    yield "[Failed: " + ex + "]";
                }
            };
            return Objects.toIdentityString(this) + stateAsString;
        }
    }

    /**
     * Configuration implementation.
     *
     * <p>The default configuration uses a virtual thread factory, no name (the flock
     * is then named with the scope's identity string), and no timeout. Each
     * {@code with*} method returns a new instance and rejects null arguments.
     */
    record ConfigImpl(ThreadFactory threadFactory,
                      String name,
                      Duration timeout) implements Configuration {
        static Configuration defaultConfig() {
            return new ConfigImpl(Thread.ofVirtual().factory(), null, null);
        }

        @Override
        public Configuration withThreadFactory(ThreadFactory threadFactory) {
            return new ConfigImpl(Objects.requireNonNull(threadFactory), name, timeout);
        }

        @Override
        public Configuration withName(String name) {
            return new ConfigImpl(threadFactory, Objects.requireNonNull(name), timeout);
        }

        @Override
        public Configuration withTimeout(Duration timeout) {
            return new ConfigImpl(threadFactory, name, Objects.requireNonNull(timeout));
        }
    }
}