1 /*
  2  * Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
import jdk.internal.misc.TerminatingThreadLocal;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.ThreadFactory;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;

import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import static org.testng.Assert.*;
 41 
 42 /*
 43  * @test
 44  * @bug 8202788 8291897
 45  * @summary TerminatingThreadLocal unit test
 46  * @modules java.base/java.lang:+open java.base/jdk.internal.misc
 47  * @requires vm.continuations
 48  * @enablePreview
 49  * @run testng/othervm TestTerminatingThreadLocal
 50  */
 51 public class TestTerminatingThreadLocal {
 52 
 53     @SafeVarargs
 54     static <T> Object[] testCase(T initialValue,
 55                                  Consumer<? super TerminatingThreadLocal<T>> ttlOps,
 56                                  T... expectedTerminatedValues) {
 57         return new Object[] {initialValue, ttlOps, Arrays.asList(expectedTerminatedValues)};
 58     }
 59 
 60     static <T> Stream<Object[]> testCases(T v0, T v1) {
 61         return Stream.of(
 62             testCase(v0, ttl -> {                                         }    ),
 63             testCase(v0, ttl -> { ttl.get();                              }, v0),
 64             testCase(v0, ttl -> { ttl.get();   ttl.remove();              }    ),
 65             testCase(v0, ttl -> { ttl.get();   ttl.set(v1);               }, v1),
 66             testCase(v0, ttl -> { ttl.set(v1);                            }, v1),
 67             testCase(v0, ttl -> { ttl.set(v1); ttl.remove();              }    ),
 68             testCase(v0, ttl -> { ttl.set(v1); ttl.remove(); ttl.get();   }, v0),
 69             testCase(v0, ttl -> { ttl.get();   ttl.remove(); ttl.set(v1); }, v1)
 70         );
 71     }
 72 
 73     @DataProvider
 74     public Object[][] testCases() {
 75         return Stream.of(
 76             testCases(42, 112),
 77             testCases(null, new Object()),
 78             testCases("abc", null)
 79         ).flatMap(Function.identity()).toArray(Object[][]::new);
 80     }
 81 
 82     /**
 83      * Test TerminatingThreadLocal with a platform thread.
 84      */
 85     @Test(dataProvider = "testCases")
 86     public <T> void ttlTestPlatform(T initialValue,
 87                                     Consumer<? super TerminatingThreadLocal<T>> ttlOps,
 88                                     List<T> expectedTerminatedValues) throws Exception {
 89         List<T> terminatedValues = new CopyOnWriteArrayList<>();
 90 
 91         TerminatingThreadLocal<T> ttl = new TerminatingThreadLocal<>() {
 92             @Override
 93             protected void threadTerminated(T value) {
 94                 terminatedValues.add(value);
 95             }
 96 
 97             @Override
 98             protected T initialValue() {
 99                 return initialValue;
100             }
101         };
102 
103         Thread thread = new Thread(() -> ttlOps.accept(ttl), "ttl-test-platform");
104         thread.start();
105         thread.join();
106 
107         assertEquals(terminatedValues, expectedTerminatedValues);
108     }
109 
110     /**
111      * Test TerminatingThreadLocal with a virtual thread. The thread local should be
112      * carrier thread local but accessible to the virtual thread. The threadTerminated
113      * method should be invoked when the carrier thread terminates.
114      */
115     @Test(dataProvider = "testCases")
116     public <T> void ttlTestVirtual(T initialValue,
117                                    Consumer<? super TerminatingThreadLocal<T>> ttlOps,
118                                    List<T> expectedTerminatedValues) throws Exception {
119         List<T> terminatedValues = new CopyOnWriteArrayList<>();
120 
121         TerminatingThreadLocal<T> ttl = new TerminatingThreadLocal<>() {
122             @Override
123             protected void threadTerminated(T value) {
124                 terminatedValues.add(value);
125             }
126 
127             @Override
128             protected T initialValue() {
129                 return initialValue;
130             }
131         };
132 
133         Thread carrier;
134 
135         // use a single worker thread pool for the cheduler
136         try (var pool = Executors.newSingleThreadExecutor()) {
137 
138             // capture carrier Thread
139             carrier = pool.submit(Thread::currentThread).get();
140 
141             ThreadFactory factory = virtualThreadBuilder(pool)
142                     .name("ttl-test-virtual-", 0)
143                     .factory();
144             try (var executor = Executors.newThreadPerTaskExecutor(factory)) {
145                 executor.submit(() -> ttlOps.accept(ttl)).get();
146             }
147 
148             assertTrue(terminatedValues.isEmpty(),
149                        "Unexpected terminated values after virtual thread terminated");
150         }
151 
152         // wait for carrier to terminate
153         carrier.join();
154 
155         assertEquals(terminatedValues, expectedTerminatedValues);
156     }
157 
158     /**
159      * Returns a builder to create virtual threads that use the given scheduler.
160      */
161     static Thread.Builder.OfVirtual virtualThreadBuilder(Executor scheduler) {
162         try {
163             Class<?> clazz = Class.forName("java.lang.ThreadBuilders$VirtualThreadBuilder");
164             Constructor<?> ctor = clazz.getDeclaredConstructor(Executor.class);
165             ctor.setAccessible(true);
166             return (Thread.Builder.OfVirtual) ctor.newInstance(scheduler);
167         } catch (InvocationTargetException e) {
168             Throwable cause = e.getCause();
169             if (cause instanceof RuntimeException re) {
170                 throw re;
171             }
172             throw new RuntimeException(e);
173         } catch (Exception e) {
174             throw new RuntimeException(e);
175         }
176     }
177 }