1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package java.util.concurrent;
 26 
 27 import java.lang.invoke.MethodHandles;
 28 import java.lang.invoke.VarHandle;
 29 import java.security.AccessController;
 30 import java.security.PrivilegedAction;
 31 import java.time.Duration;
 32 import java.util.Objects;
 33 import java.util.function.Function;
 34 import jdk.internal.misc.InnocuousThread;
 35 import jdk.internal.misc.ThreadFlock;
 36 import jdk.internal.invoke.MhUtil;
 37 
 38 /**
 39  * StructuredTaskScope implementation.
 40  */
 41 final class StructuredTaskScopeImpl<T, R> implements StructuredTaskScope<T, R> {
 42     private static final VarHandle CANCELLED =
 43             MhUtil.findVarHandle(MethodHandles.lookup(), "cancelled", boolean.class);
 44 
 45     private final Joiner<? super T, ? extends R> joiner;
 46     private final ThreadFactory threadFactory;
 47     private final ThreadFlock flock;
 48 
 49     // state, only accessed by owner thread
 50     private static final int ST_NEW            = 0;
 51     private static final int ST_FORKED         = 1;   // subtasks forked, need to join
 52     private static final int ST_JOIN_STARTED   = 2;   // join started, can no longer fork
 53     private static final int ST_JOIN_COMPLETED = 3;   // join completed
 54     private static final int ST_CLOSED         = 4;   // closed
 55     private int state;
 56 
 57     // timer task, only accessed by owner thread
 58     private Future<?> timerTask;
 59 
 60     // set or read by any thread
 61     private volatile boolean cancelled;
 62 
 63     // set by the timer thread, read by the owner thread
 64     private volatile boolean timeoutExpired;
 65 
 66     @SuppressWarnings("this-escape")
 67     private StructuredTaskScopeImpl(Joiner<? super T, ? extends R> joiner,
 68                                     ThreadFactory threadFactory,
 69                                     String name) {
 70         this.joiner = joiner;
 71         this.threadFactory = threadFactory;
 72         this.flock = ThreadFlock.open((name != null) ? name : Objects.toIdentityString(this));
 73     }
 74 
 75     /**
 76      * Returns a new {@code StructuredTaskScope} to use the given {@code Joiner} object
 77      * and with configuration that is the result of applying the given function to the
 78      * default configuration.
 79      */
 80     static <T, R> StructuredTaskScope<T, R> open(Joiner<? super T, ? extends R> joiner,
 81                                                  Function<Config, Config> configFunction) {
 82         Objects.requireNonNull(joiner);
 83 
 84         var config = (ConfigImpl) configFunction.apply(ConfigImpl.defaultConfig());
 85         var scope = new StructuredTaskScopeImpl<T, R>(joiner, config.threadFactory(), config.name());
 86 
 87         // schedule timeout
 88         Duration timeout = config.timeout();
 89         if (timeout != null) {
 90             boolean scheduled = false;
 91             try {
 92                 scope.scheduleTimeout(timeout);
 93                 scheduled = true;
 94             } finally {
 95                 if (!scheduled) {
 96                     scope.close();  // pop if scheduling timeout failed
 97                 }
 98             }
 99         }
100 
101         return scope;
102     }
103 
104     /**
105      * Throws WrongThreadException if the current thread is not the owner thread.
106      */
107     private void ensureOwner() {
108         if (Thread.currentThread() != flock.owner()) {
109             throw new WrongThreadException("Current thread not owner");
110         }
111     }
112 
113     /**
114      * Throws IllegalStateException if already joined or scope is closed.
115      */
116     private void ensureNotJoined() {
117         assert Thread.currentThread() == flock.owner();
118         if (state > ST_FORKED) {
119             throw new IllegalStateException("Already joined or scope is closed");
120         }
121     }
122 
123     /**
124      * Throws IllegalStateException if invoked by the owner thread and the owner thread
125      * has not joined.
126      */
127     private void ensureJoinedIfOwner() {
128         if (Thread.currentThread() == flock.owner() && state <= ST_JOIN_STARTED) {
129             throw new IllegalStateException("join not called");
130         }
131     }
132 
133     /**
134      * Interrupts all threads in this scope, except the current thread.
135      */
136     private void interruptAll() {
137         flock.threads()
138                 .filter(t -> t != Thread.currentThread())
139                 .forEach(t -> {
140                     try {
141                         t.interrupt();
142                     } catch (Throwable ignore) { }
143                 });
144     }
145 
146     /**
147      * Cancel the scope if not already cancelled.
148      */
149     private void cancel() {
150         if (!cancelled && CANCELLED.compareAndSet(this, false, true)) {
151             // prevent new threads from starting
152             flock.shutdown();
153 
154             // interrupt all unfinished threads
155             interruptAll();
156 
157             // wakeup join
158             flock.wakeup();
159         }
160     }
161 
162     /**
163      * Schedules a task to cancel the scope on timeout.
164      */
165     private void scheduleTimeout(Duration timeout) {
166         assert Thread.currentThread() == flock.owner() && timerTask == null;
167         timerTask = TimerSupport.schedule(timeout, () -> {
168             if (!cancelled) {
169                 timeoutExpired = true;
170                 cancel();
171             }
172         });
173     }
174 
175     /**
176      * Cancels the timer task if set.
177      */
178     private void cancelTimeout() {
179         assert Thread.currentThread() == flock.owner();
180         if (timerTask != null) {
181             timerTask.cancel(false);
182         }
183     }
184 
185     /**
186      * Invoked by the thread for a subtask when the subtask completes before scope is cancelled.
187      */
188     private void onComplete(SubtaskImpl<? extends T> subtask) {
189         assert subtask.state() != Subtask.State.UNAVAILABLE;
190         if (joiner.onComplete(subtask)) {
191             cancel();
192         }
193     }
194 
195     @Override
196     public <U extends T> Subtask<U> fork(Callable<? extends U> task) {
197         Objects.requireNonNull(task);
198         ensureOwner();
199         ensureNotJoined();
200 
201         var subtask = new SubtaskImpl<U>(this, task);
202 
203         // notify joiner, even if cancelled
204         if (joiner.onFork(subtask)) {
205             cancel();
206         }
207 
208         if (!cancelled) {
209             // create thread to run task
210             Thread thread = threadFactory.newThread(subtask);
211             if (thread == null) {
212                 throw new RejectedExecutionException("Rejected by thread factory");
213             }
214 
215             // attempt to start the thread
216             try {
217                 flock.start(thread);
218             } catch (IllegalStateException e) {
219                 // shutdown by another thread, or underlying flock is shutdown due
220                 // to unstructured use
221             }
222         }
223 
224         // force owner to join
225         state = ST_FORKED;
226         return subtask;
227     }
228 
229     @Override
230     public <U extends T> Subtask<U> fork(Runnable task) {
231         Objects.requireNonNull(task);
232         return fork(() -> { task.run(); return null; });
233     }
234 
235     @Override
236     public R join() throws InterruptedException {
237         ensureOwner();
238         ensureNotJoined();
239 
240         // join started
241         state = ST_JOIN_STARTED;
242 
243         // wait for all subtasks, the scope to be cancelled, or interrupt
244         flock.awaitAll();
245 
246         // throw if timeout expired
247         if (timeoutExpired) {
248             throw new TimeoutException();
249         }
250         cancelTimeout();
251 
252         // all subtasks completed or cancelled
253         state = ST_JOIN_COMPLETED;
254 
255         // invoke joiner to get result
256         try {
257             return joiner.result();
258         } catch (Throwable e) {
259             throw new FailedException(e);
260         }
261     }
262 
263     @Override
264     public boolean isCancelled() {
265         return cancelled;
266     }
267 
268     @Override
269     public void close() {
270         ensureOwner();
271         int s = state;
272         if (s == ST_CLOSED) {
273             return;
274         }
275 
276         // cancel the scope if join did not complete
277         if (s < ST_JOIN_COMPLETED) {
278             cancel();
279             cancelTimeout();
280         }
281 
282         // wait for stragglers
283         try {
284             flock.close();
285         } finally {
286             state = ST_CLOSED;
287         }
288 
289         // throw ISE if the owner didn't join after forking
290         if (s == ST_FORKED) {
291             throw new IllegalStateException("Owner did not join after forking");
292         }
293     }
294 
295     @Override
296     public String toString() {
297         return flock.name();
298     }
299 
300     /**
301      * Subtask implementation, runs the task specified to the fork method.
302      */
303     static final class SubtaskImpl<T> implements Subtask<T>, Runnable {
304         private static final AltResult RESULT_NULL = new AltResult(Subtask.State.SUCCESS);
305 
306         private record AltResult(Subtask.State state, Throwable exception) {
307             AltResult(Subtask.State state) {
308                 this(state, null);
309             }
310         }
311 
312         private final StructuredTaskScopeImpl<? super T, ?> scope;
313         private final Callable<? extends T> task;
314         private volatile Object result;
315 
316         SubtaskImpl(StructuredTaskScopeImpl<? super T, ?> scope, Callable<? extends T> task) {
317             this.scope = scope;
318             this.task = task;
319         }
320 
321         @Override
322         public void run() {
323             T result = null;
324             Throwable ex = null;
325             try {
326                 result = task.call();
327             } catch (Throwable e) {
328                 ex = e;
329             }
330 
331             // nothing to do if scope is cancelled
332             if (scope.isCancelled())
333                 return;
334 
335             // set result/exception and invoke onComplete
336             if (ex == null) {
337                 this.result = (result != null) ? result : RESULT_NULL;
338             } else {
339                 this.result = new AltResult(State.FAILED, ex);
340             }
341             scope.onComplete(this);
342         }
343 
344         @Override
345         public Subtask.State state() {
346             Object result = this.result;
347             if (result == null) {
348                 return State.UNAVAILABLE;
349             } else if (result instanceof AltResult alt) {
350                 // null or failed
351                 return alt.state();
352             } else {
353                 return State.SUCCESS;
354             }
355         }
356 
357         @Override
358         public T get() {
359             scope.ensureJoinedIfOwner();
360             Object result = this.result;
361             if (result instanceof AltResult) {
362                 if (result == RESULT_NULL) return null;
363             } else if (result != null) {
364                 @SuppressWarnings("unchecked")
365                 T r = (T) result;
366                 return r;
367             }
368             throw new IllegalStateException(
369                     "Result is unavailable or subtask did not complete successfully");
370         }
371 
372         @Override
373         public Throwable exception() {
374             scope.ensureJoinedIfOwner();
375             Object result = this.result;
376             if (result instanceof AltResult alt && alt.state() == State.FAILED) {
377                 return alt.exception();
378             }
379             throw new IllegalStateException(
380                     "Exception is unavailable or subtask did not complete with exception");
381         }
382 
383         @Override
384         public String toString() {
385             String stateAsString = switch (state()) {
386                 case UNAVAILABLE -> "[Unavailable]";
387                 case SUCCESS     -> "[Completed successfully]";
388                 case FAILED      -> {
389                     Throwable ex = ((AltResult) result).exception();
390                     yield "[Failed: " + ex + "]";
391                 }
392             };
393             return Objects.toIdentityString(this) + stateAsString;
394         }
395     }
396 
397     /**
398      * Config implementation.
399      */
400     record ConfigImpl(ThreadFactory threadFactory,
401                       String name,
402                       Duration timeout) implements Config {
403         static Config defaultConfig() {
404             return new ConfigImpl(Thread.ofVirtual().factory(), null, null);
405         }
406 
407         @Override
408         public Config withThreadFactory(ThreadFactory threadFactory) {
409             return new ConfigImpl(Objects.requireNonNull(threadFactory), name, timeout);
410         }
411 
412         @Override
413         public Config withName(String name) {
414             return new ConfigImpl(threadFactory, Objects.requireNonNull(name), timeout);
415         }
416 
417         @Override
418         public Config withTimeout(Duration timeout) {
419             return new ConfigImpl(threadFactory, name, Objects.requireNonNull(timeout));
420         }
421     }
422 
423     /**
424      * Used to schedule a task to cancel the scope when a timeout expires.
425      */
426     private static class TimerSupport {
427         private static final ScheduledExecutorService DELAYED_TASK_SCHEDULER;
428         static {
429             ScheduledThreadPoolExecutor stpe = (ScheduledThreadPoolExecutor)
430                     Executors.newScheduledThreadPool(1, task -> {
431                         Thread t = InnocuousThread.newThread("StructuredTaskScope-Timer", task);
432                         t.setDaemon(true);
433                         return t;
434                     });
435             stpe.setRemoveOnCancelPolicy(true);
436             DELAYED_TASK_SCHEDULER = stpe;
437         }
438 
439         static Future<?> schedule(Duration timeout, Runnable task) {
440             long nanos = TimeUnit.NANOSECONDS.convert(timeout);
441             return DELAYED_TASK_SCHEDULER.schedule(task, nanos, TimeUnit.NANOSECONDS);
442         }
443     }
444 }