< prev index next >

src/java.base/share/classes/java/util/concurrent/StructuredTaskScopeImpl.java

Print this page
@@ -26,13 +26,14 @@
  
  import java.lang.invoke.MethodHandles;
  import java.lang.invoke.VarHandle;
  import java.time.Duration;
  import java.util.Objects;
- import java.util.function.Function;
+ import java.util.function.UnaryOperator;
  import jdk.internal.misc.ThreadFlock;
  import jdk.internal.invoke.MhUtil;
+ import jdk.internal.vm.annotation.Stable;
  
  /**
   * StructuredTaskScope implementation.
   */
  final class StructuredTaskScopeImpl<T, R> implements StructuredTaskScope<T, R> {

@@ -41,47 +42,45 @@
  
      private final Joiner<? super T, ? extends R> joiner;
      private final ThreadFactory threadFactory;
      private final ThreadFlock flock;
  
-     // state, only accessed by owner thread
-     private static final int ST_NEW            = 0,
-                              ST_FORKED         = 1,   // subtasks forked, need to join
+     // scope state, set by owner thread, read by any thread. The initial state is the
+     // int default 0 (no subtasks forked yet); join() sets ST_JOIN_STARTED only when
+     // waiting for the subtasks is interrupted, otherwise it moves to ST_JOIN_COMPLETED.
+     private static final int ST_FORKED         = 1,   // subtasks forked, need to join
                               ST_JOIN_STARTED   = 2,   // join started, can no longer fork
                               ST_JOIN_COMPLETED = 3,   // join completed
                               ST_CLOSED         = 4;   // closed
-     private int state;
- 
-     // timer task, only accessed by owner thread
-     private Future<?> timerTask;
+     private volatile int state;
  
      // set or read by any thread
      private volatile boolean cancelled;
  
+     // timer task, only accessed by owner thread
+     private Future<?> timerTask;
+ 
      // set by the timer thread, read by the owner thread
      private volatile boolean timeoutExpired;
  
      @SuppressWarnings("this-escape")
      private StructuredTaskScopeImpl(Joiner<? super T, ? extends R> joiner,
                                      ThreadFactory threadFactory,
                                      String name) {
          this.joiner = joiner;
          this.threadFactory = threadFactory;
+         // "this" escapes here when no name is given: the flock is then named with this
+         // scope's identity string, hence the "this-escape" suppression on the constructor
          this.flock = ThreadFlock.open((name != null) ? name : Objects.toIdentityString(this));
-         this.state = ST_NEW;
+         // state is deliberately not assigned: it starts at the int default 0 (not forked)
      }
  
      /**
       * Returns a new {@code StructuredTaskScope} to use the given {@code Joiner} object
       * and with configuration that is the result of applying the given function to the
       * default configuration.
       */
      static <T, R> StructuredTaskScope<T, R> open(Joiner<? super T, ? extends R> joiner,
-                                                  Function<Configuration, Configuration> configFunction) {
+                                                  UnaryOperator<Configuration> configOperator) {
          Objects.requireNonNull(joiner);
  
-         var config = (ConfigImpl) configFunction.apply(ConfigImpl.defaultConfig());
+         var config = (ConfigImpl) configOperator.apply(ConfigImpl.defaultConfig());
          var scope = new StructuredTaskScopeImpl<T, R>(joiner, config.threadFactory(), config.name());
  
          // schedule timeout
          Duration timeout = config.timeout();
          if (timeout != null) {

@@ -107,27 +106,14 @@
              throw new WrongThreadException("Current thread not owner");
          }
      }
  
      /**
-      * Throws IllegalStateException if already joined or scope is closed.
+      * Returns true if join has completed or the scope is closed, i.e. the state is
+      * ST_JOIN_COMPLETED or ST_CLOSED.
       */
-     private void ensureNotJoined() {
-         assert Thread.currentThread() == flock.owner();
-         if (state > ST_FORKED) {
-             throw new IllegalStateException("Already joined or scope is closed");
-         }
-     }
- 
-     /**
-      * Throws IllegalStateException if invoked by the owner thread and the owner thread
-      * has not joined.
-      */
-     private void ensureJoinedIfOwner() {
-         if (Thread.currentThread() == flock.owner() && state <= ST_JOIN_STARTED) {
-             throw new IllegalStateException("join not called");
-         }
+     private boolean isJoinCompleted() {
+         return state >= ST_JOIN_COMPLETED;
      }
  
      /**
       * Interrupts all threads in this scope, except the current thread.
       */

@@ -193,11 +179,14 @@
  
      @Override
      public <U extends T> Subtask<U> fork(Callable<? extends U> task) {
          Objects.requireNonNull(task);
          ensureOwner();
-         ensureNotJoined();
+         int s = state;
+         if (s > ST_FORKED) {
+             throw new IllegalStateException("join already called or scope is closed");
+         }
  
          var subtask = new SubtaskImpl<U>(this, task);
  
          // notify joiner, even if cancelled
          if (joiner.onFork(subtask)) {

@@ -210,20 +199,23 @@
              if (thread == null) {
                  throw new RejectedExecutionException("Rejected by thread factory");
              }
  
              // attempt to start the thread
+             subtask.setThread(thread);
              try {
                  flock.start(thread);
              } catch (IllegalStateException e) {
                  // shutdown by another thread, or underlying flock is shutdown due
                  // to unstructured use
              }
          }
  
          // force owner to join
-         state = ST_FORKED;
+         if (s < ST_FORKED) {
+             state = ST_FORKED;
+         }
          return subtask;
      }
  
      @Override
      public <U extends T> Subtask<U> fork(Runnable task) {

@@ -232,27 +224,33 @@
      }
  
      @Override
      public R join() throws InterruptedException {
          ensureOwner();
-         ensureNotJoined();
- 
-         // join started
-         state = ST_JOIN_STARTED;
+         if (state >= ST_JOIN_COMPLETED) {
+             throw new IllegalStateException("Already joined or scope is closed");
+         }
  
          // wait for all subtasks, the scope to be cancelled, or interrupt
-         flock.awaitAll();
- 
-         // throw if timeout expired
-         if (timeoutExpired) {
-             throw new TimeoutException();
+         try {
+             flock.awaitAll();
+         } catch (InterruptedException e) {
+             state = ST_JOIN_STARTED;  // joining not completed, prevent new forks
+             throw e;
          }
-         cancelTimeout();
  
-         // all subtasks completed or cancelled
+         // all subtasks completed or scope cancelled
          state = ST_JOIN_COMPLETED;
  
+         // invoke joiner onTimeout if timeout expired
+         if (timeoutExpired) {
+             cancel();  // ensure cancelled before calling onTimeout
+             joiner.onTimeout();
+         } else {
+             cancelTimeout();
+         }
+ 
          // invoke joiner to get result
          try {
              return joiner.result();
          } catch (Throwable e) {
              throw new FailedException(e);

@@ -309,18 +307,41 @@
          }
  
          private final StructuredTaskScopeImpl<? super T, ?> scope;
          private final Callable<? extends T> task;
          private volatile Object result;
+         // thread that executes this subtask; assigned by setThread (from fork, before the
+         // thread is started) and stays null if setThread is never called
+         @Stable private Thread thread;
  
          SubtaskImpl(StructuredTaskScopeImpl<? super T, ?> scope, Callable<? extends T> task) {
              this.scope = scope;
              this.task = task;
          }
  
+         /**
+          * Records the not-yet-started thread that will execute this subtask.
+          */
+         void setThread(Thread thread) {
+             assert Thread.State.NEW == thread.getState();
+             this.thread = thread;
+         }
+ 
+         /**
+          * Throws IllegalStateException if the caller is not the thread executing this
+          * subtask and the scope owner has not completed join (a closed scope counts as
+          * joined).
+          */
+         private void ensureJoinedIfNotSubtask() {
+             if (Thread.currentThread() != thread && !scope.isJoinCompleted()) {
+                 throw new IllegalStateException("Scope owner has not joined");
+             }
+         }
+ 
          @Override
          public void run() {
+             if (Thread.currentThread() != thread) {
+                 throw new WrongThreadException();
+             }
+ 
              T result = null;
              Throwable ex = null;
              try {
                  result = task.call();
              } catch (Throwable e) {

@@ -353,11 +374,11 @@
              }
          }
  
          @Override
          public T get() {
-             scope.ensureJoinedIfOwner();
+             ensureJoinedIfNotSubtask();
              Object result = this.result;
              if (result instanceof AltResult) {
                  if (result == RESULT_NULL) return null;
              } else if (result != null) {
                  @SuppressWarnings("unchecked")

@@ -368,11 +389,11 @@
                      "Result is unavailable or subtask did not complete successfully");
          }
  
          @Override
          public Throwable exception() {
-             scope.ensureJoinedIfOwner();
+             ensureJoinedIfNotSubtask();
              Object result = this.result;
              if (result instanceof AltResult alt && alt.state() == State.FAILED) {
                  return alt.exception();
              }
              throw new IllegalStateException(
< prev index next >