1 /*
  2  * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.internal.vm;
 26 
 27 import java.lang.ref.ReferenceQueue;
 28 import java.lang.ref.WeakReference;
 29 import java.util.Optional;
 30 import java.util.Set;
 31 import java.util.concurrent.ConcurrentHashMap;
 32 import java.util.concurrent.atomic.LongAdder;
 33 import java.util.stream.Stream;
 34 import jdk.internal.access.JavaLangAccess;
 35 import jdk.internal.access.SharedSecrets;
 36 import sun.security.action.GetPropertyAction;
 37 
 38 /**
 39  * This class consists exclusively of static methods to support groupings of threads.
 40  */
 41 public class ThreadContainers {
 42     private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();
 43 
 44     // true if all threads are tracked
 45     private static final boolean TRACK_ALL_THREADS;
 46 
 47     // the root container
 48     private static final RootContainer ROOT_CONTAINER;
 49 
 50     // the set of thread containers registered with this class
 51     private static final Set<WeakReference<ThreadContainer>> CONTAINER_REGISTRY = ConcurrentHashMap.newKeySet();
 52     private static final ReferenceQueue<Object> QUEUE = new ReferenceQueue<>();
 53 
 54     static {
 55         String s = GetPropertyAction.privilegedGetProperty("jdk.trackAllThreads");
 56         if (s != null && (s.isEmpty() || Boolean.parseBoolean(s))) {
 57             TRACK_ALL_THREADS = true;
 58             ROOT_CONTAINER = new RootContainer.TrackingRootContainer();
 59         } else {
 60             TRACK_ALL_THREADS = false;
 61             ROOT_CONTAINER = new RootContainer.CountingRootContainer();
 62         }
 63     }
 64 
 65     private ThreadContainers() { }
 66 
 67     /**
 68      * Expunge stale entries from the container registry.
 69      */
 70     private static void expungeStaleEntries() {
 71         Object key;
 72         while ((key = QUEUE.poll()) != null) {
 73             CONTAINER_REGISTRY.remove(key);
 74         }
 75     }
 76 
 77     /**
 78      * Returns true if all threads are tracked.
 79      */
 80     public static boolean trackAllThreads() {
 81         return TRACK_ALL_THREADS;
 82     }
 83 
 84     /**
 85      * Registers a thread container to be tracked this class, returning a key
 86      * that is used to remove it from the registry.
 87      */
 88     public static Object registerContainer(ThreadContainer container) {
 89         expungeStaleEntries();
 90         var ref = new WeakReference<>(container, QUEUE);
 91         CONTAINER_REGISTRY.add(ref);
 92         return ref;
 93     }
 94 
 95     /**
 96      * Removes a thread container from being tracked by specifying the key
 97      * returned when the thread container was registered.
 98      */
 99     public static void deregisterContainer(Object key) {
100         assert key instanceof WeakReference;
101         CONTAINER_REGISTRY.remove(key);
102     }
103 
104     /**
105      * Returns the root thread container.
106      */
107     public static ThreadContainer root() {
108         return ROOT_CONTAINER;
109     }
110 
111     /**
112      * Returns the parent of the given thread container.
113      *
114      * If the container has an owner then its parent is the enclosing container when
115      * nested, or the container that the owner is in, when not nested.
116      *
117      * If the container does not have an owner then the root container is returned,
118      * or null if called with the root container.
119      */
120     static ThreadContainer parent(ThreadContainer container) {
121         Thread owner = container.owner();
122         if (owner != null) {
123             ThreadContainer parent = container.enclosingScope(ThreadContainer.class);
124             if (parent != null)
125                 return parent;
126             if ((parent = container(owner)) != null)
127                 return parent;
128         }
129         ThreadContainer root = root();
130         return (container != root) ? root : null;
131     }
132 
133     /**
134      * Returns given thread container's "children".
135      */
136     static Stream<ThreadContainer> children(ThreadContainer container) {
137         // children of registered containers
138         Stream<ThreadContainer> s1 = CONTAINER_REGISTRY.stream()
139                 .map(WeakReference::get)
140                 .filter(c -> c != null && c.parent() == container);
141 
142         // container may enclose another container
143         Stream<ThreadContainer> s2 = Stream.empty();
144         if (container.owner() != null) {
145             ThreadContainer next = next(container);
146             if (next != null)
147                 s2 = Stream.of(next);
148         }
149 
150         // the top-most container owned by the threads in the container
151         Stream<ThreadContainer> s3 = container.threads()
152                 .map(t -> Optional.ofNullable(top(t)))
153                 .flatMap(Optional::stream);
154 
155         return Stream.concat(s1, Stream.concat(s2, s3));
156     }
157 
158     /**
159      * Returns the thread container that the given Thread is in or the root
160      * container if not started in a container.
161      * @throws IllegalStateException if the thread has not been started
162      */
163     public static ThreadContainer container(Thread thread) {
164         // thread container is set when the thread is started
165         if (thread.isAlive() || thread.getState() == Thread.State.TERMINATED) {
166             ThreadContainer container = JLA.threadContainer(thread);
167             return (container != null) ? container : root();
168         } else {
169             throw new IllegalStateException("Thread not started");
170         }
171     }
172 
173     /**
174      * Returns the top-most thread container owned by the given thread.
175      */
176     private static ThreadContainer top(Thread thread) {
177         StackableScope current = JLA.headStackableScope(thread);
178         ThreadContainer top = null;
179         while (current != null) {
180             if (current instanceof ThreadContainer tc) {
181                 top = tc;
182             }
183             current = current.previous();
184         }
185         return top;
186     }
187 
188     /**
189      * Returns the thread container that the given thread container encloses.
190      */
191     private static ThreadContainer next(ThreadContainer container) {
192         StackableScope current = JLA.headStackableScope(container.owner());
193         if (current != null) {
194             ThreadContainer next = null;
195             while (current != null) {
196                 if (current == container) {
197                     return next;
198                 } else if (current instanceof ThreadContainer tc) {
199                     next = tc;
200                 }
201                 current = current.previous();
202             }
203         }
204         return null;
205     }
206 
207     /**
208      * Root container that "contains" all platform threads not started in a
209      * container plus some (or all) virtual threads that are started directly
210      * with the Thread API.
211      */
212     private static abstract class RootContainer extends ThreadContainer {
213         protected RootContainer() {
214             super(true);
215         }
216         @Override
217         public ThreadContainer parent() {
218             return null;
219         }
220         @Override
221         public String name() {
222             return "<root>";
223         }
224         @Override
225         public StackableScope previous() {
226             return null;
227         }
228         @Override
229         public String toString() {
230             return name();
231         }
232 
233         /**
234          * Returns the platform threads that are not in the container as these
235          * threads are considered to be in the root container.
236          */
237         protected Stream<Thread> platformThreads() {
238             return Stream.of(JLA.getAllThreads())
239                     .filter(t -> JLA.threadContainer(t) == null);
240         }
241 
242         /**
243          * Root container that tracks all threads.
244          */
245         private static class TrackingRootContainer extends RootContainer {
246             private static final Set<Thread> VTHREADS = ConcurrentHashMap.newKeySet();
247             @Override
248             public void onStart(Thread thread) {
249                 assert thread.isVirtual();
250                 VTHREADS.add(thread);
251             }
252             @Override
253             public void onExit(Thread thread) {
254                 assert thread.isVirtual();
255                 VTHREADS.remove(thread);
256             }
257             @Override
258             public long threadCount() {
259                 return platformThreads().count() + VTHREADS.size();
260             }
261             @Override
262             public Stream<Thread> threads() {
263                 return Stream.concat(platformThreads(),
264                                      VTHREADS.stream().filter(Thread::isAlive));
265             }
266         }
267 
268         /**
269          * Root container that tracks all platform threads and just keeps a
270          * count of the virtual threads.
271          */
272         private static class CountingRootContainer extends RootContainer {
273             private static final LongAdder VTHREAD_COUNT = new LongAdder();
274             @Override
275             public void onStart(Thread thread) {
276                 assert thread.isVirtual();
277                 VTHREAD_COUNT.add(1L);
278             }
279             @Override
280             public void onExit(Thread thread) {
281                 assert thread.isVirtual();
282                 VTHREAD_COUNT.add(-1L);
283             }
284             @Override
285             public long threadCount() {
286                 return platformThreads().count() + VTHREAD_COUNT.sum();
287             }
288             @Override
289             public Stream<Thread> threads() {
290                 return platformThreads();
291             }
292         }
293     }
294 }