/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package java.util.concurrent;

// NOTE(review): AccessController and PrivilegedAction are not referenced anywhere in
// this file; they look like candidates for removal.
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Duration;
import java.util.Objects;
import java.util.function.Function;
import jdk.internal.misc.InnocuousThread;
import jdk.internal.misc.ThreadFlock;
import jdk.internal.invoke.MhUtil;

/**
 * StructuredTaskScope implementation.
 *
 * <p>Each forked subtask runs in a thread created by the configured
 * {@code ThreadFactory} and started in this scope's {@link ThreadFlock}. Forking,
 * joining and closing are restricted to the owner thread, i.e. the thread returned by
 * {@code flock.owner()} (presumably the thread that opened the scope).
 *
 * <p>The owner-side lifecycle is tracked in {@code state}:
 * {@code ST_NEW -> ST_FORKED -> ST_JOIN_STARTED -> ST_JOIN_COMPLETED -> ST_CLOSED}.
 * {@code close()} may move to {@code ST_CLOSED} from any earlier state; when it does so
 * from {@code ST_FORKED} it throws {@code IllegalStateException}.
 *
 * <p>The scope is cancelled at most once, guarded by a CAS on {@code cancelled}. It is
 * cancelled when:
 * <ul>
 * <li>the joiner's {@code onFork} or {@code onComplete} returns {@code true};</li>
 * <li>the configured timeout expires; or</li>
 * <li>the scope is closed before join has completed.</li>
 * </ul>
 */
final class StructuredTaskScopeImpl<T, R> implements StructuredTaskScope<T, R> {
    // VarHandle on the volatile "cancelled" field; cancel() uses it to flip the flag once
    private static final VarHandle CANCELLED =
            MhUtil.findVarHandle(MethodHandles.lookup(), "cancelled", boolean.class);

    private final Joiner<? super T, ? extends R> joiner;
    private final ThreadFactory threadFactory;
    private final ThreadFlock flock;

    // state, only accessed by owner thread
    // The values are ordered: ensureNotJoined, ensureJoinedIfOwner and close compare
    // them with <, <= and >. ST_NEW is never assigned explicitly; it is the default
    // value (0) of the state field.
    private static final int ST_NEW = 0;
    private static final int ST_FORKED = 1; // subtasks forked, need to join
    private static final int ST_JOIN_STARTED = 2; // join started, can no longer fork
    private static final int ST_JOIN_COMPLETED = 3; // join completed
    private static final int ST_CLOSED = 4; // closed
    private int state;

    // timer task, only accessed by owner thread
    private Future<?> timerTask;

    // set or read by any thread
    private volatile boolean cancelled;

    // set by the timer thread, read by the owner thread
    private volatile boolean timeoutExpired;

    /**
     * Creates a scope and opens its thread flock.
     *
     * @param joiner the joiner that is notified of forks and completions and that
     *        produces the result of {@code join()}
     * @param threadFactory the factory used to create one thread per forked subtask
     * @param name the flock name, or {@code null} to use this object's identity string
     */
    @SuppressWarnings("this-escape")
    private StructuredTaskScopeImpl(Joiner<? super T, ? extends R> joiner,
                                    ThreadFactory threadFactory,
                                    String name) {
        this.joiner = joiner;
        this.threadFactory = threadFactory;
        // "this" is passed to toIdentityString before construction completes,
        // hence the "this-escape" suppression above
        this.flock = ThreadFlock.open((name != null) ? name : Objects.toIdentityString(this));
    }

    /**
     * Returns a new {@code StructuredTaskScope} to use the given {@code Joiner} object
     * and with configuration that is the result of applying the given function to the
     * default configuration.
     *
     * <p>If the configuration has a timeout, it is scheduled before this method
     * returns. If scheduling fails, the new scope is closed and the failure propagates
     * to the caller.
     *
     * @throws NullPointerException if {@code joiner} is null
     */
    static <T, R> StructuredTaskScope<T, R> open(Joiner<? super T, ? extends R> joiner,
                                                 Function<Config, Config> configFunction) {
        Objects.requireNonNull(joiner);

        // NOTE(review): assumes configFunction returns a ConfigImpl (for example, one
        // derived from the default config via its with* methods); any other Config
        // implementation fails this cast with ClassCastException.
        var config = (ConfigImpl) configFunction.apply(ConfigImpl.defaultConfig());
        var scope = new StructuredTaskScopeImpl<T, R>(joiner, config.threadFactory(), config.name());

        // schedule timeout
        Duration timeout = config.timeout();
        if (timeout != null) {
            boolean scheduled = false;
            try {
                scope.scheduleTimeout(timeout);
                scheduled = true;
            } finally {
                // no catch: any exception from scheduleTimeout propagates after the close
                if (!scheduled) {
                    scope.close(); // pop if scheduling timeout failed
                }
            }
        }

        return scope;
    }

    /**
     * Throws WrongThreadException if the current thread is not the owner thread.
     */
    private void ensureOwner() {
        if (Thread.currentThread() != flock.owner()) {
            throw new WrongThreadException("Current thread not owner");
        }
    }

    /**
     * Throws IllegalStateException if already joined or scope is closed.
     *
     * <p>Must only be called by the owner thread (asserted). "Joined" here means that
     * join has at least started ({@code state > ST_FORKED}).
     */
    private void ensureNotJoined() {
        assert Thread.currentThread() == flock.owner();
        if (state > ST_FORKED) {
            throw new IllegalStateException("Already joined or scope is closed");
        }
    }

    /**
     * Throws IllegalStateException if invoked by the owner thread and the owner thread
     * has not joined.
     *
     * <p>Calls from non-owner threads always pass. Because {@code &&} short-circuits,
     * {@code state} is only read when the caller is the owner, which keeps it
     * owner-confined.
     */
    private void ensureJoinedIfOwner() {
        if (Thread.currentThread() == flock.owner() && state <= ST_JOIN_STARTED) {
            throw new IllegalStateException("join not called");
        }
    }

    /**
     * Interrupts all threads in this scope, except the current thread.
     *
     * <p>Best effort: a failure to interrupt one thread is ignored so that the remaining
     * threads are still interrupted.
     */
    private void interruptAll() {
        flock.threads()
            .filter(t -> t != Thread.currentThread())
            .forEach(t -> {
                try {
                    t.interrupt();
                } catch (Throwable ignore) { }
            });
    }

    /**
     * Cancel the scope if not already cancelled.
     *
     * <p>Called from several threads:
     * <ul>
     * <li>the owner, via {@code fork} and {@code close};</li>
     * <li>subtask threads, via {@code onComplete};</li>
     * <li>the timer thread, via the task scheduled in {@code scheduleTimeout}.</li>
     * </ul>
     * Only the caller that wins the CAS on {@code cancelled} runs the shutdown
     * sequence; every later call is a no-op.
     */
    private void cancel() {
        // the plain volatile read avoids the CAS when the scope is already cancelled
        if (!cancelled && CANCELLED.compareAndSet(this, false, true)) {
            // prevent new threads from starting
            flock.shutdown();

            // interrupt all unfinished threads
            interruptAll();

            // wakeup join
            flock.wakeup();
        }
    }

    /**
     * Schedules a task to cancel the scope on timeout.
     *
     * <p>When it fires, the task does nothing if the scope is already cancelled.
     * Otherwise it records that the timeout expired and then cancels the scope. The
     * flag is set before {@code cancel()} so that the owner, woken from
     * {@code flock.awaitAll()} in {@code join()}, can observe it and throw
     * {@code TimeoutException}.
     */
    private void scheduleTimeout(Duration timeout) {
        assert Thread.currentThread() == flock.owner() && timerTask == null;
        timerTask = TimerSupport.schedule(timeout, () -> {
            if (!cancelled) {
                timeoutExpired = true;
                cancel();
            }
        });
    }

    /**
     * Cancels the timer task if set.
     *
     * <p>Uses {@code cancel(false)}, so a timer task that is already running is not
     * interrupted.
     */
    private void cancelTimeout() {
        assert Thread.currentThread() == flock.owner();
        if (timerTask != null) {
            timerTask.cancel(false);
        }
    }

    /**
     * Invoked by the thread for a subtask when the subtask completes before scope is cancelled.
     * Cancels the scope if the joiner's {@code onComplete} returns {@code true}.
     */
    private void onComplete(SubtaskImpl<? extends T> subtask) {
        assert subtask.state() != Subtask.State.UNAVAILABLE;
        if (joiner.onComplete(subtask)) {
            cancel();
        }
    }

    /**
     * Forks a subtask that executes the given task in a new thread.
     *
     * <p>The joiner's {@code onFork} is called even when the scope is already
     * cancelled. A thread is created and started only if the scope is still not
     * cancelled after that call. If the flock refuses to start the thread, the subtask
     * is returned anyway. Its {@code run()} is then presumably never invoked, so it
     * would stay {@code UNAVAILABLE} (TODO confirm against ThreadFlock.start).
     *
     * <p>The state moves to {@code ST_FORKED} only when this method returns normally.
     * When the thread factory returns {@code null}, the joiner has already been
     * notified but the state is left unchanged.
     *
     * @throws NullPointerException if {@code task} is null
     * @throws WrongThreadException if the caller is not the owner thread
     * @throws IllegalStateException if already joined or the scope is closed
     * @throws RejectedExecutionException if the thread factory returns {@code null}
     */
    @Override
    public <U extends T> Subtask<U> fork(Callable<? extends U> task) {
        Objects.requireNonNull(task);
        ensureOwner();
        ensureNotJoined();

        var subtask = new SubtaskImpl<U>(this, task);

        // notify joiner, even if cancelled
        if (joiner.onFork(subtask)) {
            cancel();
        }

        if (!cancelled) {
            // create thread to run task
            Thread thread = threadFactory.newThread(subtask);
            if (thread == null) {
                throw new RejectedExecutionException("Rejected by thread factory");
            }

            // attempt to start the thread
            try {
                flock.start(thread);
            } catch (IllegalStateException e) {
                // shutdown by another thread, or underlying flock is shutdown due
                // to unstructured use
            }
        }

        // force owner to join
        state = ST_FORKED;
        return subtask;
    }

    /**
     * Forks a subtask that runs the given {@code Runnable}. On success the subtask's
     * result is {@code null}. Delegates to {@link #fork(Callable)}.
     *
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <U extends T> Subtask<U> fork(Runnable task) {
        Objects.requireNonNull(task);
        return fork(() -> { task.run(); return null; });
    }

    /**
     * Waits for all subtasks to finish, or for the scope to be cancelled, and then
     * returns the joiner's result.
     *
     * <p>If the wait is interrupted, or the timeout expired, the state stays at
     * {@code ST_JOIN_STARTED}. The owner therefore cannot join again, and
     * {@code close()} will cancel the scope.
     *
     * @throws WrongThreadException if the caller is not the owner thread
     * @throws IllegalStateException if already joined or the scope is closed
     * @throws InterruptedException if interrupted while waiting
     * @throws TimeoutException if the configured timeout expired
     * @throws FailedException wrapping any exception thrown by {@code joiner.result()}
     */
    @Override
    public R join() throws InterruptedException {
        ensureOwner();
        ensureNotJoined();

        // join started
        state = ST_JOIN_STARTED;

        // wait for all subtasks, the scope to be cancelled, or interrupt
        flock.awaitAll();

        // throw if timeout expired
        if (timeoutExpired) {
            throw new TimeoutException();
        }
        cancelTimeout();

        // all subtasks completed or cancelled
        state = ST_JOIN_COMPLETED;

        // invoke joiner to get result
        try {
            return joiner.result();
        } catch (Throwable e) {
            throw new FailedException(e);
        }
    }

    /**
     * Returns whether this scope has been cancelled. May be called from any thread.
     */
    @Override
    public boolean isCancelled() {
        return cancelled;
    }

    /**
     * Closes this scope and waits for any remaining subtask threads.
     *
     * <p>This method has no effect if the scope is already closed. If join did not
     * complete, the scope and the timer task (if any) are cancelled first. The state is
     * set to {@code ST_CLOSED} even if closing the flock throws.
     *
     * @throws WrongThreadException if the caller is not the owner thread
     * @throws IllegalStateException if subtasks were forked but join was never called;
     *         this is thrown only after the scope has been closed
     */
    @Override
    public void close() {
        ensureOwner();
        // snapshot: the checks below need the state as it was before closing
        int s = state;
        if (s == ST_CLOSED) {
            return;
        }

        // cancel the scope if join did not complete
        if (s < ST_JOIN_COMPLETED) {
            cancel();
            cancelTimeout();
        }

        // wait for stragglers
        try {
            flock.close();
        } finally {
            state = ST_CLOSED;
        }

        // throw ISE if the owner didn't join after forking
        if (s == ST_FORKED) {
            throw new IllegalStateException("Owner did not join after forking");
        }
    }

    /**
     * Returns the flock's name: the configured name, or this scope's identity string
     * when no name was configured.
     */
    @Override
    public String toString() {
        return flock.name();
    }

    /**
     * Subtask implementation, runs the task specified to the fork method.
     *
     * <p>{@code result} stays {@code null} (state {@code UNAVAILABLE}) unless the task
     * finishes while the scope is not cancelled. It then holds one of:
     * <ul>
     * <li>the task's non-null return value;</li>
     * <li>{@code RESULT_NULL}, for a successful {@code null} result;</li>
     * <li>an {@code AltResult} with state {@code FAILED} and the exception thrown.</li>
     * </ul>
     */
    static final class SubtaskImpl<T> implements Subtask<T>, Runnable {
        // shared sentinel for "completed successfully with a null result"
        private static final AltResult RESULT_NULL = new AltResult(Subtask.State.SUCCESS);

        /**
         * Holder for outcomes that are not a plain non-null value: a successful
         * {@code null} result ({@code RESULT_NULL}) or a failure together with its
         * exception.
         */
        private record AltResult(Subtask.State state, Throwable exception) {
            AltResult(Subtask.State state) {
                this(state, null);
            }
        }

        private final StructuredTaskScopeImpl<? super T, ?> scope;
        private final Callable<? extends T> task;
        // written by the subtask's thread, read by any thread
        private volatile Object result;

        SubtaskImpl(StructuredTaskScopeImpl<? super T, ?> scope, Callable<? extends T> task) {
            this.scope = scope;
            this.task = task;
        }

        /**
         * Runs the task. If the scope is not cancelled afterwards, the outcome is
         * published and the scope is notified.
         *
         * <p>NOTE(review): the cancellation check is not atomic with publishing the
         * result. If the scope is cancelled just after the check, this subtask still
         * publishes its outcome and calls {@code onComplete}.
         */
        @Override
        public void run() {
            T result = null;
            Throwable ex = null;
            try {
                result = task.call();
            } catch (Throwable e) {
                ex = e;
            }

            // nothing to do if scope is cancelled
            if (scope.isCancelled())
                return;

            // set result/exception and invoke onComplete
            if (ex == null) {
                this.result = (result != null) ? result : RESULT_NULL;
            } else {
                this.result = new AltResult(State.FAILED, ex);
            }
            scope.onComplete(this);
        }

        /**
         * Derives the subtask state from a single read of {@code result}.
         */
        @Override
        public Subtask.State state() {
            Object result = this.result;
            if (result == null) {
                return State.UNAVAILABLE;
            } else if (result instanceof AltResult alt) {
                // null or failed
                return alt.state();
            } else {
                return State.SUCCESS;
            }
        }

        /**
         * Returns the task's result, which may be {@code null}.
         *
         * @throws IllegalStateException if the owner thread calls this before join has
         *         completed, or if the subtask has not completed successfully
         */
        @Override
        public T get() {
            scope.ensureJoinedIfOwner();
            Object result = this.result;
            if (result instanceof AltResult) {
                // any AltResult other than RESULT_NULL is a failure; fall through to throw
                if (result == RESULT_NULL) return null;
            } else if (result != null) {
                @SuppressWarnings("unchecked")
                T r = (T) result;
                return r;
            }
            throw new IllegalStateException(
                    "Result is unavailable or subtask did not complete successfully");
        }

        /**
         * Returns the exception thrown by the task.
         *
         * @throws IllegalStateException if the owner thread calls this before join has
         *         completed, or if the subtask did not fail
         */
        @Override
        public Throwable exception() {
            scope.ensureJoinedIfOwner();
            Object result = this.result;
            if (result instanceof AltResult alt && alt.state() == State.FAILED) {
                return alt.exception();
            }
            throw new IllegalStateException(
                    "Exception is unavailable or subtask did not complete with exception");
        }

        /**
         * Returns the identity string followed by a bracketed description of the state.
         */
        @Override
        public String toString() {
            String stateAsString = switch (state()) {
                case UNAVAILABLE -> "[Unavailable]";
                case SUCCESS -> "[Completed successfully]";
                case FAILED -> {
                    // re-reads result after state(). This relies on result being
                    // written at most once (run() is presumably invoked only once, by
                    // the subtask's thread), so a FAILED state implies an AltResult here.
                    Throwable ex = ((AltResult) result).exception();
                    yield "[Failed: " + ex + "]";
                }
            };
            return Objects.toIdentityString(this) + stateAsString;
        }
    }

    /**
     * Config implementation.
     *
     * <p>Immutable. Each {@code with*} method rejects {@code null} and returns a copy in
     * which one component is replaced. The method's parameter shadows the component of
     * the same name, so the other two values are taken from this record.
     */
    record ConfigImpl(ThreadFactory threadFactory,
                      String name,
                      Duration timeout) implements Config {
        /**
         * Returns the default configuration: a virtual thread factory, no name, and no
         * timeout.
         */
        static Config defaultConfig() {
            return new ConfigImpl(Thread.ofVirtual().factory(), null, null);
        }

        @Override
        public Config withThreadFactory(ThreadFactory threadFactory) {
            return new ConfigImpl(Objects.requireNonNull(threadFactory), name, timeout);
        }

        @Override
        public Config withName(String name) {
            return new ConfigImpl(threadFactory, Objects.requireNonNull(name), timeout);
        }

        @Override
        public Config withTimeout(Duration timeout) {
            return new ConfigImpl(threadFactory, name, Objects.requireNonNull(timeout));
        }
    }

    /**
     * Used to schedule a task to cancel the scope when a timeout expires.
     *
     * <p>All scopes share one scheduler: a single-thread pool whose thread is a daemon.
     * Because the remove-on-cancel policy is enabled, timer tasks cancelled by
     * {@code cancelTimeout()} are removed from the work queue when they are cancelled.
     */
    private static class TimerSupport {
        private static final ScheduledExecutorService DELAYED_TASK_SCHEDULER;
        static {
            ScheduledThreadPoolExecutor stpe = (ScheduledThreadPoolExecutor)
                Executors.newScheduledThreadPool(1, task -> {
                    Thread t = InnocuousThread.newThread("StructuredTaskScope-Timer", task);
                    t.setDaemon(true);
                    return t;
                });
            stpe.setRemoveOnCancelPolicy(true);
            DELAYED_TASK_SCHEDULER = stpe;
        }

        /**
         * Schedules {@code task} to run once after {@code timeout} has elapsed.
         *
         * <p>{@code TimeUnit.convert(Duration)} saturates rather than throwing on
         * overflow, so very large durations are accepted.
         */
        static Future<?> schedule(Duration timeout, Runnable task) {
            long nanos = TimeUnit.NANOSECONDS.convert(timeout);
            return DELAYED_TASK_SCHEDULER.schedule(task, nanos, TimeUnit.NANOSECONDS);
        }
    }
}