1 /*
  2  * Copyright (c) 2019, 2021, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package java.util.concurrent;
 26 
 27 import java.lang.invoke.MethodHandles;
 28 import java.lang.invoke.VarHandle;
 29 import java.security.Permission;
 30 import java.util.ArrayList;
 31 import java.util.Collection;
 32 import java.util.Iterator;
 33 import java.util.List;
 34 import java.util.Objects;
 35 import java.util.Set;
 36 import java.util.concurrent.locks.LockSupport;
 37 import java.util.stream.Stream;
 38 import static java.util.concurrent.TimeUnit.NANOSECONDS;
 39 import jdk.internal.vm.SharedThreadContainer;
 40 
 41 /**
 42  * An ExecutorService that starts a new thread for each task. The number of
 43  * threads is unbounded.
 44  */
 45 class ThreadPerTaskExecutor implements ExecutorService {
 46     private static final Permission MODIFY_THREAD = new RuntimePermission("modifyThread");
 47     private static final VarHandle STATE;
 48     static {
 49         try {
 50             MethodHandles.Lookup l = MethodHandles.lookup();
 51             STATE = l.findVarHandle(ThreadPerTaskExecutor.class, "state", int.class);
 52         } catch (Exception e) {
 53             throw new InternalError(e);
 54         }
 55     }
 56 
 57     private final Set<Thread> threads = ConcurrentHashMap.newKeySet();
 58     private final CountDownLatch terminationSignal = new CountDownLatch(1);
 59 
 60     private final ThreadFactory factory;
 61     private final SharedThreadContainer container;
 62 
 63     // states: RUNNING -> SHUTDOWN -> TERMINATED
 64     private static final int RUNNING    = 0;
 65     private static final int SHUTDOWN   = 1;
 66     private static final int TERMINATED = 2;
 67     private volatile int state;
 68 
 69     /**
 70      * Constructs a thread-per-task executor that creates threads using the given
 71      * factory
 72      */
 73     ThreadPerTaskExecutor(ThreadFactory factory) {
 74         this.factory = Objects.requireNonNull(factory);
 75         String name = getClass().getName() + "@" + System.identityHashCode(this);
 76         this.container = SharedThreadContainer.create(name, this::threads);
 77     }
 78 
 79     /**
 80      * Throws SecurityException if there is a security manager set and it denies
 81      * RuntimePermission("modifyThread").
 82      */
 83     @SuppressWarnings("removal")
 84     private void checkPermission() {
 85         SecurityManager sm = System.getSecurityManager();
 86         if (sm != null) {
 87             sm.checkPermission(MODIFY_THREAD);
 88         }
 89     }
 90 
 91     /**
 92      * Throws RejectedExecutionException if the executor has been shutdown.
 93      */
 94     private void ensureNotShutdown() {
 95         if (state >= SHUTDOWN) {
 96             // shutdown or terminated
 97             throw new RejectedExecutionException();
 98         }
 99     }
100 
101     /**
102      * Attempts to terminate if already shutdown. If this method terminates the
103      * executor then it signals any threads that are waiting for termination.
104      */
105     private void tryTerminate() {
106         assert state >= SHUTDOWN;
107         if (threads.isEmpty()
108             && STATE.compareAndSet(this, SHUTDOWN, TERMINATED)) {
109 
110             // signal waiters
111             terminationSignal.countDown();
112 
113             // remove from registry
114             container.close();
115         }
116     }
117 
118     /**
119      * Attempts to shutdown and terminate the executor.
120      * If interruptThreads is true then all running threads are interrupted.
121      */
122     private void tryShutdownAndTerminate(boolean interruptThreads) {
123         if (STATE.compareAndSet(this, RUNNING, SHUTDOWN))
124             tryTerminate();
125         if (interruptThreads) {
126             threads.forEach(Thread::interrupt);
127         }
128     }
129 
130     private Stream<Thread> threads() {
131         return threads.stream();
132     }
133 
134     @Override
135     public void shutdown() {
136         checkPermission();
137         if (!isShutdown())
138             tryShutdownAndTerminate(false);
139     }
140 
141     @Override
142     public List<Runnable> shutdownNow() {
143         checkPermission();
144         if (!isTerminated())
145             tryShutdownAndTerminate(true);
146         return List.of();
147     }
148 
149     @Override
150     public boolean isShutdown() {
151         return state >= SHUTDOWN;
152     }
153 
154     @Override
155     public boolean isTerminated() {
156         return state >= TERMINATED;
157     }
158 
159     @Override
160     public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
161         Objects.requireNonNull(unit);
162         if (isTerminated()) {
163             return true;
164         } else {
165             return terminationSignal.await(timeout, unit);
166         }
167     }
168 
169     /**
170      * Waits for executor to terminate.
171      */
172     private void awaitTermination() {
173         boolean terminated = isTerminated();
174         if (!terminated) {
175             tryShutdownAndTerminate(false);
176             boolean interrupted = false;
177             while (!terminated) {
178                 try {
179                     terminated = awaitTermination(1L, TimeUnit.DAYS);
180                 } catch (InterruptedException e) {
181                     if (!interrupted) {
182                         tryShutdownAndTerminate(true);
183                         interrupted = true;
184                     }
185                 }
186             }
187             if (interrupted) {
188                 Thread.currentThread().interrupt();
189             }
190         }
191     }
192 
193     @Override
194     public void close() {
195         checkPermission();
196         awaitTermination();
197     }
198 
199     /**
200      * Creates a thread to run the given task.
201      */
202     private Thread newThread(Runnable task) {
203         Thread thread = factory.newThread(task);
204         if (thread == null)
205             throw new RejectedExecutionException();
206         return thread;
207     }
208 
209     /**
210      * Notify the executor that the task executed by the given thread is complete.
211      * If the executor has been shutdown then this method will attempt to terminate
212      * the executor.
213      */
214     private void taskComplete(Thread thread) {
215         boolean removed = threads.remove(thread);
216         assert removed;
217         if (state == SHUTDOWN) {
218             tryTerminate();
219         }
220     }
221 
222     /**
223      * Adds a thread to the set of threads and starts it.
224      * @throws RejectedExecutionException
225      */
226     private void start(Thread thread) {
227         assert thread.getState() == Thread.State.NEW;
228         threads.add(thread);
229 
230         boolean started = false;
231         try {
232             if (state == RUNNING) {
233                 container.start(thread);
234                 started = true;
235             }
236         } finally {
237             if (!started) {
238                 taskComplete(thread);
239             }
240         }
241 
242         // throw REE if thread not started and no exception thrown
243         if (!started) {
244             throw new RejectedExecutionException();
245         }
246     }
247 
248     /**
249      * Starts a thread to execute the given task.
250      * @throws RejectedExecutionException
251      */
252     private Thread start(Runnable task) {
253         Objects.requireNonNull(task);
254         ensureNotShutdown();
255         Thread thread = newThread(new TaskRunner(this, task));
256         start(thread);
257         return thread;
258     }
259 
260     @Override
261     public void execute(Runnable task) {
262         start(task);
263     }
264     
265     @Override
266     public <T> Future<T> submit(Callable<T> task) {
267         Objects.requireNonNull(task);
268         ensureNotShutdown();
269         var future = new ThreadBoundFuture<>(this, task);
270         Thread thread = future.thread();
271         start(thread);
272         return future;
273     }
274 
275     @Override
276     public Future<?> submit(Runnable task) {
277         return submit(Executors.callable(task));
278     }
279 
280     @Override
281     public <T> Future<T> submit(Runnable task, T result) {
282         return submit(Executors.callable(task, result));
283     }
284 
285     /**
286      * Runs a task and notifies the executor when it completes.
287      */
288     private static class TaskRunner implements Runnable {
289         final ThreadPerTaskExecutor executor;
290         final Runnable task;
291         TaskRunner(ThreadPerTaskExecutor executor, Runnable task) {
292             this.executor = executor;
293             this.task = task;
294         }
295         @Override
296         public void run() {
297             try {
298                 task.run();
299             } finally {
300                 executor.taskComplete(Thread.currentThread());
301             }
302         }
303     }
304 
305     /**
306      * A Future for a task that runs in its own thread. The thread is
307      * created (but not started) when the Future is created. The thread
308      * is interrupted when the future is cancelled. The executor is
309      * notified when the task completes.
310      */
311     private static class ThreadBoundFuture<T>
312             extends CompletableFuture<T> implements Runnable {
313 
314         final ThreadPerTaskExecutor executor;
315         final Callable<T> task;
316         final Thread thread;
317 
318         ThreadBoundFuture(ThreadPerTaskExecutor executor, Callable<T> task) {
319             this.executor = executor;
320             this.task = task;
321             this.thread = executor.newThread(this);
322         }
323 
324         Thread thread() {
325             return thread;
326         }
327 
328         @Override
329         public void run() {
330             if (Thread.currentThread() != thread) {
331                 // should not happen except where something casts this object
332                 // to a Runnable and invokes the run method.
333                 throw new IllegalCallerException();
334             }
335             try {
336                 T result = task.call();
337                 complete(result);
338             } catch (Throwable e) {
339                 completeExceptionally(e);
340             } finally {
341                 executor.taskComplete(thread);
342             }
343         }
344 
345         @Override
346         public boolean cancel(boolean mayInterruptIfRunning) {
347             boolean cancelled = super.cancel(mayInterruptIfRunning);
348             if (cancelled && mayInterruptIfRunning)
349                 thread.interrupt();
350             return cancelled;
351         }
352     }
353 
354     @Override
355     public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)
356             throws InterruptedException {
357 
358         Objects.requireNonNull(tasks);
359         List<Future<T>> futures = new ArrayList<>();
360         int j = 0;
361         try {
362             for (Callable<T> t : tasks) {
363                 Future<T> f = submit(t);
364                 futures.add(f);
365             }
366             for (int size = futures.size(); j < size; j++) {
367                 Future<T> f = futures.get(j);
368                 if (!f.isDone()) {
369                     try {
370                         f.get();
371                     } catch (ExecutionException | CancellationException ignore) { }
372                 }
373             }
374             return futures;
375         } finally {
376             cancelAll(futures, j);
377         }
378     }
379 
380     @Override
381     public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks,
382                                          long timeout, TimeUnit unit)
383             throws InterruptedException {
384 
385         Objects.requireNonNull(tasks);
386         long deadline = System.nanoTime() + unit.toNanos(timeout);
387         List<Future<T>> futures = new ArrayList<>();
388         int j = 0;
389         try {
390             for (Callable<T> t : tasks) {
391                 Future<T> f = submit(t);
392                 futures.add(f);
393             }
394             for (int size = futures.size(); j < size; j++) {
395                 Future<T> f = futures.get(j);
396                 if (!f.isDone()) {
397                     try {
398                         f.get(deadline - System.nanoTime(), NANOSECONDS);
399                     } catch (TimeoutException e) {
400                         break;
401                     } catch (ExecutionException | CancellationException ignore) { }
402                 }
403             }
404             return futures;
405         } finally {
406             cancelAll(futures, j);
407         }
408     }
409 
410     private <T> void cancelAll(List<Future<T>> futures, int j) {
411         for (int size = futures.size(); j < size; j++)
412             futures.get(j).cancel(true);
413     }
414 
415     @Override
416     public <T> T invokeAny(Collection<? extends Callable<T>> tasks)
417             throws InterruptedException, ExecutionException {
418         try {
419             return invokeAny(tasks, false, 0, null);
420         } catch (TimeoutException e) {
421             // should not happen
422             throw new InternalError(e);
423         }
424     }
425 
426     @Override
427     public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
428             throws InterruptedException, ExecutionException, TimeoutException {
429         Objects.requireNonNull(unit);
430         return invokeAny(tasks, true, timeout, unit);
431     }
432 
433     private <T> T invokeAny(Collection<? extends Callable<T>> tasks,
434                             boolean timed,
435                             long timeout,
436                             TimeUnit unit)
437             throws InterruptedException, ExecutionException, TimeoutException {
438 
439         int size = tasks.size();
440         if (size == 0) {
441             throw new IllegalArgumentException("'tasks' is empty");
442         }
443 
444         var holder = new AnyResultHolder<T>(Thread.currentThread());
445         var threadList = new ArrayList<Thread>(size);
446         long nanos = (timed) ? unit.toNanos(timeout) : 0;
447         long startNanos = (timed) ? System.nanoTime() : 0;
448 
449         try {
450             int count = 0;
451             Iterator<? extends Callable<T>> iterator = tasks.iterator();
452             while (count < size && iterator.hasNext()) {
453                 Callable<T> task = iterator.next();
454                 Objects.requireNonNull(task);
455                 Thread thread = start(() -> {
456                     try {
457                         T r = task.call();
458                         holder.complete(r);
459                     } catch (Throwable e) {
460                         holder.completeExceptionally(e);
461                     }
462                 });
463                 threadList.add(thread);
464                 count++;
465             }
466             if (count == 0) {
467                 throw new IllegalArgumentException("'tasks' is empty");
468             }
469 
470             if (Thread.interrupted())
471                 throw new InterruptedException();
472             T result = holder.result();
473             while (result == null && holder.exceptionCount() < count) {
474                 if (timed) {
475                     long remainingNanos = nanos - (System.nanoTime() - startNanos);
476                     if (remainingNanos <= 0)
477                         throw new TimeoutException();
478                     LockSupport.parkNanos(remainingNanos);
479                 } else {
480                     LockSupport.park();
481                 }
482                 if (Thread.interrupted())
483                     throw new InterruptedException();
484                 result = holder.result();
485             }
486 
487             if (result != null) {
488                 return (result != AnyResultHolder.NULL) ? result : null;
489             } else {
490                 throw new ExecutionException(holder.firstException());
491             }
492 
493         } finally {
494             // interrupt any threads that are still running
495             for (Thread t : threadList) {
496                 if (t.isAlive()) {
497                     t.interrupt();
498                 }
499             }
500         }
501     }
502 
503     /**
504      * An object for use by invokeAny to hold the result of the first task
505      * to complete normally and/or the first exception thrown. The object
506      * also maintains a count of the number of tasks that attempted to
507      * complete up to when the first tasks completes normally.
508      */
509     private static class AnyResultHolder<T> {
510         private static final VarHandle RESULT;
511         private static final VarHandle EXCEPTION;
512         private static final VarHandle EXCEPTION_COUNT;
513         static {
514             try {
515                 MethodHandles.Lookup l = MethodHandles.lookup();
516                 RESULT = l.findVarHandle(AnyResultHolder.class, "result", Object.class);
517                 EXCEPTION = l.findVarHandle(AnyResultHolder.class, "exception", Throwable.class);
518                 EXCEPTION_COUNT = l.findVarHandle(AnyResultHolder.class, "exceptionCount", int.class);
519             } catch (Exception e) {
520                 throw new InternalError(e);
521             }
522         }
523         private static final Object NULL = new Object();
524 
525         private final Thread owner;
526         private volatile T result;
527         private volatile Throwable exception;
528         private volatile int exceptionCount;
529 
530         AnyResultHolder(Thread owner) {
531             this.owner = owner;
532         }
533 
534         /**
535          * Complete with the given result if not already completed. The winner
536          * unparks the owner thread.
537          */
538         void complete(T value) {
539             @SuppressWarnings("unchecked")
540             T v = (value != null) ? value : (T) NULL;
541             if (result == null && RESULT.compareAndSet(this, null, v)) {
542                 LockSupport.unpark(owner);
543             }
544         }
545 
546         /**
547          * Complete with the given exception. If the result is not already
548          * set then it unparks the owner thread.
549          */
550         void completeExceptionally(Throwable exc) {
551             if (result == null) {
552                 if (exception == null)
553                     EXCEPTION.compareAndSet(this, null, exc);
554                 EXCEPTION_COUNT.getAndAdd(this, 1);
555                 LockSupport.unpark(owner);
556             }
557         }
558 
559         /**
560          * Returns non-null if a task completed successfully. The result is
561          * NULL if completed with null.
562          */
563         T result() {
564             return result;
565         }
566 
567         /**
568          * Returns the first exception thrown if recorded by this object.
569          *
570          * @apiNote The result() method should be used to test if there is
571          * a result before invoking the exception method.
572          */
573         Throwable firstException() {
574             return exception;
575         }
576 
577         /**
578          * Returns the number of tasks that terminated with an exception before
579          * a task completed normally.
580          */
581         int exceptionCount() {
582             return exceptionCount;
583         }
584     }
585 }