1 /* 2 * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 package java.util.concurrent; 26 27 import java.lang.invoke.MethodHandles; 28 import java.lang.invoke.VarHandle; 29 import java.time.Duration; 30 import java.util.Objects; 31 import java.util.function.UnaryOperator; 32 import jdk.internal.misc.ThreadFlock; 33 import jdk.internal.invoke.MhUtil; 34 import jdk.internal.vm.annotation.Stable; 35 36 /** 37 * StructuredTaskScope implementation. 38 */ 39 final class StructuredTaskScopeImpl<T, R> implements StructuredTaskScope<T, R> { 40 private static final VarHandle CANCELLED = 41 MhUtil.findVarHandle(MethodHandles.lookup(), "cancelled", boolean.class); 42 43 private final Joiner<? super T, ? extends R> joiner; 44 private final ThreadFactory threadFactory; 45 private final ThreadFlock flock; 46 47 // scope state, set by owner thread, read by any thread 48 private static final int ST_FORKED = 1, // subtasks forked, need to join 49 ST_JOIN_STARTED = 2, // join started, can no longer fork 50 ST_JOIN_COMPLETED = 3, // join completed 51 ST_CLOSED = 4; // closed 52 private volatile int state; 53 54 // set or read by any thread 55 private volatile boolean cancelled; 56 57 // timer task, only accessed by owner thread 58 private Future<?> timerTask; 59 60 // set by the timer thread, read by the owner thread 61 private volatile boolean timeoutExpired; 62 63 @SuppressWarnings("this-escape") 64 private StructuredTaskScopeImpl(Joiner<? super T, ? extends R> joiner, 65 ThreadFactory threadFactory, 66 String name) { 67 this.joiner = joiner; 68 this.threadFactory = threadFactory; 69 this.flock = ThreadFlock.open((name != null) ? name : Objects.toIdentityString(this)); 70 } 71 72 /** 73 * Returns a new {@code StructuredTaskScope} to use the given {@code Joiner} object 74 * and with configuration that is the result of applying the given function to the 75 * default configuration. 76 */ 77 static <T, R> StructuredTaskScope<T, R> open(Joiner<? super T, ? extends R> joiner, 78 UnaryOperator<Configuration> configOperator) { 79 Objects.requireNonNull(joiner); 80 81 var config = (ConfigImpl) configOperator.apply(ConfigImpl.defaultConfig()); 82 var scope = new StructuredTaskScopeImpl<T, R>(joiner, config.threadFactory(), config.name()); 83 84 // schedule timeout 85 Duration timeout = config.timeout(); 86 if (timeout != null) { 87 boolean scheduled = false; 88 try { 89 scope.scheduleTimeout(timeout); 90 scheduled = true; 91 } finally { 92 if (!scheduled) { 93 scope.close(); // pop if scheduling timeout failed 94 } 95 } 96 } 97 98 return scope; 99 } 100 101 /** 102 * Throws WrongThreadException if the current thread is not the owner thread. 103 */ 104 private void ensureOwner() { 105 if (Thread.currentThread() != flock.owner()) { 106 throw new WrongThreadException("Current thread not owner"); 107 } 108 } 109 110 /** 111 * Returns true if join has been invoked and there is an outcome. 112 */ 113 private boolean isJoinCompleted() { 114 return state >= ST_JOIN_COMPLETED; 115 } 116 117 /** 118 * Interrupts all threads in this scope, except the current thread. 119 */ 120 private void interruptAll() { 121 flock.threads() 122 .filter(t -> t != Thread.currentThread()) 123 .forEach(t -> { 124 try { 125 t.interrupt(); 126 } catch (Throwable ignore) { } 127 }); 128 } 129 130 /** 131 * Cancel the scope if not already cancelled. 132 */ 133 private void cancel() { 134 if (!cancelled && CANCELLED.compareAndSet(this, false, true)) { 135 // prevent new threads from starting 136 flock.shutdown(); 137 138 // interrupt all unfinished threads 139 interruptAll(); 140 141 // wakeup join 142 flock.wakeup(); 143 } 144 } 145 146 /** 147 * Schedules a task to cancel the scope on timeout. 148 */ 149 private void scheduleTimeout(Duration timeout) { 150 assert Thread.currentThread() == flock.owner() && timerTask == null; 151 long nanos = TimeUnit.NANOSECONDS.convert(timeout); 152 timerTask = ForkJoinPool.commonPool().schedule(() -> { 153 if (!cancelled) { 154 timeoutExpired = true; 155 cancel(); 156 } 157 }, nanos, TimeUnit.NANOSECONDS); 158 } 159 160 /** 161 * Cancels the timer task if set. 162 */ 163 private void cancelTimeout() { 164 assert Thread.currentThread() == flock.owner(); 165 if (timerTask != null) { 166 timerTask.cancel(false); 167 } 168 } 169 170 /** 171 * Invoked by the thread for a subtask when the subtask completes before scope is cancelled. 172 */ 173 private void onComplete(SubtaskImpl<? extends T> subtask) { 174 assert subtask.state() != Subtask.State.UNAVAILABLE; 175 if (joiner.onComplete(subtask)) { 176 cancel(); 177 } 178 } 179 180 @Override 181 public <U extends T> Subtask<U> fork(Callable<? extends U> task) { 182 Objects.requireNonNull(task); 183 ensureOwner(); 184 int s = state; 185 if (s > ST_FORKED) { 186 throw new IllegalStateException("join already called or scope is closed"); 187 } 188 189 var subtask = new SubtaskImpl<U>(this, task); 190 191 // notify joiner, even if cancelled 192 if (joiner.onFork(subtask)) { 193 cancel(); 194 } 195 196 if (!cancelled) { 197 // create thread to run task 198 Thread thread = threadFactory.newThread(subtask); 199 if (thread == null) { 200 throw new RejectedExecutionException("Rejected by thread factory"); 201 } 202 203 // attempt to start the thread 204 subtask.setThread(thread); 205 try { 206 flock.start(thread); 207 } catch (IllegalStateException e) { 208 // shutdown by another thread, or underlying flock is shutdown due 209 // to unstructured use 210 } 211 } 212 213 // force owner to join 214 if (s < ST_FORKED) { 215 state = ST_FORKED; 216 } 217 return subtask; 218 } 219 220 @Override 221 public <U extends T> Subtask<U> fork(Runnable task) { 222 Objects.requireNonNull(task); 223 return fork(() -> { task.run(); return null; }); 224 } 225 226 @Override 227 public R join() throws InterruptedException { 228 ensureOwner(); 229 if (state >= ST_JOIN_COMPLETED) { 230 throw new IllegalStateException("Already joined or scope is closed"); 231 } 232 233 // wait for all subtasks, the scope to be cancelled, or interrupt 234 try { 235 flock.awaitAll(); 236 } catch (InterruptedException e) { 237 state = ST_JOIN_STARTED; // joining not completed, prevent new forks 238 throw e; 239 } 240 241 // all subtasks completed or scope cancelled 242 state = ST_JOIN_COMPLETED; 243 244 // invoke joiner onTimeout if timeout expired 245 if (timeoutExpired) { 246 cancel(); // ensure cancelled before calling onTimeout 247 joiner.onTimeout(); 248 } else { 249 cancelTimeout(); 250 } 251 252 // invoke joiner to get result 253 try { 254 return joiner.result(); 255 } catch (Throwable e) { 256 throw new FailedException(e); 257 } 258 } 259 260 @Override 261 public boolean isCancelled() { 262 return cancelled; 263 } 264 265 @Override 266 public void close() { 267 ensureOwner(); 268 int s = state; 269 if (s == ST_CLOSED) { 270 return; 271 } 272 273 // cancel the scope if join did not complete 274 if (s < ST_JOIN_COMPLETED) { 275 cancel(); 276 cancelTimeout(); 277 } 278 279 // wait for stragglers 280 try { 281 flock.close(); 282 } finally { 283 state = ST_CLOSED; 284 } 285 286 // throw ISE if the owner didn't join after forking 287 if (s == ST_FORKED) { 288 throw new IllegalStateException("Owner did not join after forking"); 289 } 290 } 291 292 @Override 293 public String toString() { 294 return flock.name(); 295 } 296 297 /** 298 * Subtask implementation, runs the task specified to the fork method. 299 */ 300 static final class SubtaskImpl<T> implements Subtask<T>, Runnable { 301 private static final AltResult RESULT_NULL = new AltResult(Subtask.State.SUCCESS); 302 303 private record AltResult(Subtask.State state, Throwable exception) { 304 AltResult(Subtask.State state) { 305 this(state, null); 306 } 307 } 308 309 private final StructuredTaskScopeImpl<? super T, ?> scope; 310 private final Callable<? extends T> task; 311 private volatile Object result; 312 @Stable private Thread thread; 313 314 SubtaskImpl(StructuredTaskScopeImpl<? super T, ?> scope, Callable<? extends T> task) { 315 this.scope = scope; 316 this.task = task; 317 } 318 319 /** 320 * Sets the thread for this subtask. 321 */ 322 void setThread(Thread thread) { 323 assert thread.getState() == Thread.State.NEW; 324 this.thread = thread; 325 } 326 327 /** 328 * Throws IllegalStateException if the caller thread is not the subtask and 329 * the scope owner has not joined. 330 */ 331 private void ensureJoinedIfNotSubtask() { 332 if (Thread.currentThread() != thread && !scope.isJoinCompleted()) { 333 throw new IllegalStateException(); 334 } 335 } 336 337 @Override 338 public void run() { 339 if (Thread.currentThread() != thread) { 340 throw new WrongThreadException(); 341 } 342 343 T result = null; 344 Throwable ex = null; 345 try { 346 result = task.call(); 347 } catch (Throwable e) { 348 ex = e; 349 } 350 351 // nothing to do if scope is cancelled 352 if (scope.isCancelled()) 353 return; 354 355 // set result/exception and invoke onComplete 356 if (ex == null) { 357 this.result = (result != null) ? result : RESULT_NULL; 358 } else { 359 this.result = new AltResult(State.FAILED, ex); 360 } 361 scope.onComplete(this); 362 } 363 364 @Override 365 public Subtask.State state() { 366 Object result = this.result; 367 if (result == null) { 368 return State.UNAVAILABLE; 369 } else if (result instanceof AltResult alt) { 370 // null or failed 371 return alt.state(); 372 } else { 373 return State.SUCCESS; 374 } 375 } 376 377 @Override 378 public T get() { 379 ensureJoinedIfNotSubtask(); 380 Object result = this.result; 381 if (result instanceof AltResult) { 382 if (result == RESULT_NULL) return null; 383 } else if (result != null) { 384 @SuppressWarnings("unchecked") 385 T r = (T) result; 386 return r; 387 } 388 throw new IllegalStateException( 389 "Result is unavailable or subtask did not complete successfully"); 390 } 391 392 @Override 393 public Throwable exception() { 394 ensureJoinedIfNotSubtask(); 395 Object result = this.result; 396 if (result instanceof AltResult alt && alt.state() == State.FAILED) { 397 return alt.exception(); 398 } 399 throw new IllegalStateException( 400 "Exception is unavailable or subtask did not complete with exception"); 401 } 402 403 @Override 404 public String toString() { 405 String stateAsString = switch (state()) { 406 case UNAVAILABLE -> "[Unavailable]"; 407 case SUCCESS -> "[Completed successfully]"; 408 case FAILED -> "[Failed: " + ((AltResult) result).exception() + "]"; 409 }; 410 return Objects.toIdentityString(this) + stateAsString; 411 } 412 } 413 414 /** 415 * Configuration implementation. 416 */ 417 record ConfigImpl(ThreadFactory threadFactory, 418 String name, 419 Duration timeout) implements Configuration { 420 static Configuration defaultConfig() { 421 return new ConfigImpl(Thread.ofVirtual().factory(), null, null); 422 } 423 424 @Override 425 public Configuration withThreadFactory(ThreadFactory threadFactory) { 426 return new ConfigImpl(Objects.requireNonNull(threadFactory), name, timeout); 427 } 428 429 @Override 430 public Configuration withName(String name) { 431 return new ConfigImpl(threadFactory, Objects.requireNonNull(name), timeout); 432 } 433 434 @Override 435 public Configuration withTimeout(Duration timeout) { 436 return new ConfigImpl(threadFactory, name, Objects.requireNonNull(timeout)); 437 } 438 } 439 }