1 /*
  2  * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @library /test/lib
 27  *
 28  * @requires !vm.graal.enabled
 29  * @enablePreview
 30  *
 31  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -Xint                   -DTHROW=false -Xcheck:jni ClassInitBarrier
 32  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -Xint                   -DTHROW=true  -Xcheck:jni ClassInitBarrier
 33  *
 34  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:TieredStopAtLevel=1 -DTHROW=false -Xcheck:jni ClassInitBarrier
 35  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:TieredStopAtLevel=1 -DTHROW=true  -Xcheck:jni ClassInitBarrier
 36  *
 37  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:-TieredCompilation  -DTHROW=false -Xcheck:jni ClassInitBarrier
 38  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:-TieredCompilation  -DTHROW=true  -Xcheck:jni ClassInitBarrier
 39  *
 40  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:TieredStopAtLevel=1 -DTHROW=false -XX:CompileCommand=dontinline,*::static* -Xcheck:jni ClassInitBarrier
 41  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:TieredStopAtLevel=1 -DTHROW=true  -XX:CompileCommand=dontinline,*::static* -Xcheck:jni ClassInitBarrier
 42  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:-TieredCompilation  -DTHROW=false -XX:CompileCommand=dontinline,*::static* -Xcheck:jni ClassInitBarrier
 43  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:-TieredCompilation  -DTHROW=true  -XX:CompileCommand=dontinline,*::static* -Xcheck:jni ClassInitBarrier
 44  *
 45  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:TieredStopAtLevel=1 -DTHROW=false -XX:CompileCommand=exclude,*::static* -Xcheck:jni ClassInitBarrier
 46  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:TieredStopAtLevel=1 -DTHROW=true  -XX:CompileCommand=exclude,*::static* -Xcheck:jni ClassInitBarrier
 47  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:-TieredCompilation  -DTHROW=false -XX:CompileCommand=exclude,*::static* -Xcheck:jni ClassInitBarrier
 48  * @run main/othervm/native -Xbatch -XX:CompileCommand=dontinline,*::test* -XX:-TieredCompilation  -DTHROW=true  -XX:CompileCommand=exclude,*::static* -Xcheck:jni ClassInitBarrier
 49  */
 50 
 51 import jdk.test.lib.Asserts;
 52 
 53 import java.util.*;
 54 import java.util.concurrent.atomic.AtomicBoolean;
 55 import java.util.concurrent.atomic.AtomicInteger;
 56 import java.util.function.Consumer;
 57 
 58 public class ClassInitBarrier {
 59     static {
 60         System.loadLibrary("ClassInitBarrier");
 61 
 62         if (!init()) {
 63             throw new Error("init failed");
 64         }
 65     }
 66 
 67     static native boolean init();
 68 
 69     static final boolean THROW = Boolean.getBoolean("THROW");
 70 
 71     static value class MyValue {
 72         int x = 42;
 73 
 74         void verify() {
 75             Asserts.assertEquals(x, 42);
 76         }
 77     }
 78 
 79     static class Test {
 80 
 81         static class A {
 82             static {
 83                 if (!init(B.class)) {
 84                     throw new Error("init failed");
 85                 }
 86 
 87                 changePhase(Phase.IN_PROGRESS);
 88                 runTests();      // interpreted mode
 89                 warmup();        // trigger compilation
 90                 runTests();      // compiled mode
 91 
 92                 ensureBlocked(); // ensure still blocked
 93                 maybeThrow();    // fail initialization if needed
 94 
 95                 changePhase(Phase.FINISHED);
 96             }
 97 
 98             static              void staticM(Runnable action, MyValue val) { action.run(); val.verify(); }
 99             static synchronized void staticS(Runnable action, MyValue val) { action.run(); val.verify(); }
100             static native       void staticN(Runnable action, MyValue val);
101 
102             static int staticF;
103 
104             int f;
105             void m() {}
106 
107             static native boolean init(Class<B> cls);
108         }
109 
110         static class B extends A {}
111 
112         static void testInvokeStatic(Runnable action, MyValue val)       { A.staticM(action, val); }
113         static void testInvokeStaticSync(Runnable action, MyValue val)   { A.staticS(action, val); }
114         static void testInvokeStaticNative(Runnable action, MyValue val) { A.staticN(action, val); }
115 
116         static int  testGetStatic(Runnable action, MyValue val)    { int v = A.staticF; action.run(); val.verify(); return v;   }
117         static void testPutStatic(Runnable action, MyValue val)    { A.staticF = 1;     action.run(); val.verify(); }
118         static A    testNewInstanceA(Runnable action, MyValue val) { A obj = new A();   action.run(); val.verify(); return obj; }
119         static B    testNewInstanceB(Runnable action, MyValue val) { B obj = new B();   action.run(); val.verify(); return obj; }
120 
121         static int  testGetField(A recv, Runnable action, MyValue val)      { int v = recv.f; action.run(); val.verify(); return v; }
122         static void testPutField(A recv, Runnable action, MyValue val)      { recv.f = 1;     action.run(); val.verify(); }
123         static void testInvokeVirtual(A recv, Runnable action, MyValue val) { recv.m();       action.run(); val.verify(); }
124 
125         static native void testInvokeStaticJNI(Runnable action, MyValue val);
126         static native void testInvokeStaticSyncJNI(Runnable action, MyValue val);
127         static native void testInvokeStaticNativeJNI(Runnable action, MyValue val);
128 
129         static native int  testGetStaticJNI(Runnable action, MyValue val);
130         static native void testPutStaticJNI(Runnable action, MyValue val);
131         static native A    testNewInstanceAJNI(Runnable action, MyValue val);
132         static native B    testNewInstanceBJNI(Runnable action, MyValue val);
133 
134         static native int  testGetFieldJNI(A recv, Runnable action, MyValue val);
135         static native void testPutFieldJNI(A recv, Runnable action, MyValue val);
136         static native void testInvokeVirtualJNI(A recv, Runnable action, MyValue val);
137 
138         static void runTests() {
139             checkBlockingAction(Test::testInvokeStatic);       // invokestatic
140             checkBlockingAction(Test::testInvokeStaticSync);   // invokestatic
141             checkBlockingAction(Test::testInvokeStaticNative); // invokestatic
142             checkBlockingAction(Test::testGetStatic);          // getstatic
143             checkBlockingAction(Test::testPutStatic);          // putstatic
144             checkBlockingAction(Test::testNewInstanceA);       // new
145 
146             checkNonBlockingAction(Test::testInvokeStaticJNI);       // invokestatic
147             checkNonBlockingAction(Test::testInvokeStaticSyncJNI);   // invokestatic
148             checkNonBlockingAction(Test::testInvokeStaticNativeJNI); // invokestatic
149             checkNonBlockingAction(Test::testGetStaticJNI);          // getstatic
150             checkNonBlockingAction(Test::testPutStaticJNI);          // putstatic
151             checkBlockingAction(Test::testNewInstanceAJNI);          // new
152 
153             A recv = testNewInstanceB(NON_BLOCKING.get(), new MyValue());  // trigger B initialization
154             checkNonBlockingAction(Test::testNewInstanceB); // new: NO BLOCKING: same thread: A being initialized, B fully initialized
155 
156             checkNonBlockingAction(recv, Test::testGetField);      // getfield
157             checkNonBlockingAction(recv, Test::testPutField);      // putfield
158             checkNonBlockingAction(recv, Test::testInvokeVirtual); // invokevirtual
159 
160             checkNonBlockingAction(Test::testNewInstanceBJNI);        // new: NO BLOCKING: same thread: A being initialized, B fully initialized
161             checkNonBlockingAction(recv, Test::testGetFieldJNI);      // getfield
162             checkNonBlockingAction(recv, Test::testPutFieldJNI);      // putfield
163             checkNonBlockingAction(recv, Test::testInvokeVirtualJNI); // invokevirtual
164         }
165 
166         static void warmup() {
167             MyValue val = new MyValue();
168             for (int i = 0; i < 20_000; i++) {
169                 testInvokeStatic(      NON_BLOCKING_WARMUP, val);
170                 testInvokeStaticNative(NON_BLOCKING_WARMUP, val);
171                 testInvokeStaticSync(  NON_BLOCKING_WARMUP, val);
172                 testGetStatic(         NON_BLOCKING_WARMUP, val);
173                 testPutStatic(         NON_BLOCKING_WARMUP, val);
174                 testNewInstanceA(      NON_BLOCKING_WARMUP, val);
175                 testNewInstanceB(      NON_BLOCKING_WARMUP, val);
176 
177                 testGetField(new B(),      NON_BLOCKING_WARMUP, val);
178                 testPutField(new B(),      NON_BLOCKING_WARMUP, val);
179                 testInvokeVirtual(new B(), NON_BLOCKING_WARMUP, val);
180             }
181         }
182 
183         static void run() {
184             execute(ExceptionInInitializerError.class, () -> triggerInitialization(A.class));
185             ensureFinished();
186             runTests(); // after initialization is over
187         }
188     }
189 
190     // ============================================================================================================== //
191 
192     static void execute(Class<? extends Throwable> expectedExceptionClass, Runnable action) {
193         try {
194             action.run();
195             if (THROW) throw failure("no exception thrown");
196         } catch (Throwable e) {
197             if (THROW) {
198                 if (e.getClass() == expectedExceptionClass) {
199                     // expected
200                 } else {
201                     String msg = String.format("unexpected exception thrown: expected %s, caught %s",
202                             expectedExceptionClass.getName(), e);
203                     throw failure(msg, e);
204                 }
205             } else {
206                 throw failure("no exception expected", e);
207             }
208         }
209     }
210 
211     private static AssertionError failure(String msg) {
212         return new AssertionError(phase + ": " + msg);
213     }
214 
215     private static AssertionError failure(String msg, Throwable e) {
216         return new AssertionError(phase + ": " + msg, e);
217     }
218 
219     static final List<Thread> BLOCKED_THREADS = Collections.synchronizedList(new ArrayList<>());
220     static final Consumer<Thread> ON_BLOCK = BLOCKED_THREADS::add;
221 
222     static final Map<Thread,Throwable> FAILED_THREADS = Collections.synchronizedMap(new HashMap<>());
223     static final Thread.UncaughtExceptionHandler ON_FAILURE = FAILED_THREADS::put;
224 
225     private static void ensureBlocked() {
226         for (Thread thr : BLOCKED_THREADS) {
227             try {
228                 thr.join(100);
229                 if (!thr.isAlive()) {
230                     dump(thr);
231                     throw new AssertionError("not blocked");
232                 }
233             } catch (InterruptedException e) {
234                 throw new Error(e);
235             }
236         }
237     }
238 
239 
240     private static void ensureFinished() {
241         for (Thread thr : BLOCKED_THREADS) {
242             try {
243                 thr.join(15_000);
244             } catch (InterruptedException e) {
245                 throw new Error(e);
246             }
247             if (thr.isAlive()) {
248                 dump(thr);
249                 throw new AssertionError(thr + ": still blocked");
250             }
251         }
252         for (Thread thr : BLOCKED_THREADS) {
253             if (THROW) {
254                 if (!FAILED_THREADS.containsKey(thr)) {
255                     throw new AssertionError(thr + ": exception not thrown");
256                 }
257 
258                 Throwable ex = FAILED_THREADS.get(thr);
259                 if (ex.getClass() != NoClassDefFoundError.class) {
260                     throw new AssertionError(thr + ": wrong exception thrown", ex);
261                 }
262             } else {
263                 if (FAILED_THREADS.containsKey(thr)) {
264                     Throwable ex = FAILED_THREADS.get(thr);
265                     throw new AssertionError(thr + ": exception thrown", ex);
266                 }
267             }
268         }
269         if (THROW) {
270             Asserts.assertEquals(BLOCKING_COUNTER.get(), 0);
271         } else {
272             Asserts.assertEquals(BLOCKING_COUNTER.get(), BLOCKING_ACTIONS.get());
273         }
274 
275         dumpInfo();
276     }
277 
278     interface TestCase0 {
279         void run(Runnable runnable, MyValue val);
280     }
281 
282     interface TestCase1<T> {
283         void run(T arg, Runnable runnable, MyValue val);
284     }
285 
286     enum Phase { BEFORE_INIT, IN_PROGRESS, FINISHED, INIT_FAILURE }
287 
288     static volatile Phase phase = Phase.BEFORE_INIT;
289 
290     static void changePhase(Phase newPhase) {
291         dumpInfo();
292 
293         Phase oldPhase = phase;
294         switch (oldPhase) {
295             case BEFORE_INIT:
296                 Asserts.assertEquals(NON_BLOCKING_ACTIONS.get(), 0);
297                 Asserts.assertEquals(NON_BLOCKING_COUNTER.get(), 0);
298 
299                 Asserts.assertEquals(BLOCKING_ACTIONS.get(),     0);
300                 Asserts.assertEquals(BLOCKING_COUNTER.get(),     0);
301                 break;
302             case IN_PROGRESS:
303                 Asserts.assertEquals(NON_BLOCKING_COUNTER.get(), NON_BLOCKING_ACTIONS.get());
304 
305                 Asserts.assertEquals(BLOCKING_COUNTER.get(), 0);
306                 break;
307             default: throw new Error("wrong phase transition " + oldPhase);
308         }
309         phase = newPhase;
310     }
311 
312     static void dumpInfo() {
313         System.out.println("Phase: " + phase);
314         System.out.println("Non-blocking actions: " + NON_BLOCKING_COUNTER.get() + " / " + NON_BLOCKING_ACTIONS.get());
315         System.out.println("Blocking actions:     " + BLOCKING_COUNTER.get()     + " / " + BLOCKING_ACTIONS.get());
316     }
317 
318     static final Runnable NON_BLOCKING_WARMUP = () -> {
319         if (phase != Phase.IN_PROGRESS) {
320             throw new AssertionError("NON_BLOCKING: wrong phase: " + phase);
321         }
322     };
323 
324     static Runnable disposableAction(final Phase validPhase, final AtomicInteger invocationCounter, final AtomicInteger actionCounter) {
325         actionCounter.incrementAndGet();
326 
327         final AtomicBoolean cnt = new AtomicBoolean(false);
328         return () -> {
329             if (cnt.getAndSet(true)) {
330                 throw new Error("repeated invocation");
331             }
332             invocationCounter.incrementAndGet();
333             if (phase != validPhase) {
334                 throw new AssertionError("NON_BLOCKING: wrong phase: " + phase);
335             }
336         };
337     }
338 
339     @FunctionalInterface
340     interface Factory<V> {
341         V get();
342     }
343 
344     static final AtomicInteger NON_BLOCKING_COUNTER = new AtomicInteger(0);
345     static final AtomicInteger NON_BLOCKING_ACTIONS = new AtomicInteger(0);
346     static final Factory<Runnable> NON_BLOCKING = () -> disposableAction(phase, NON_BLOCKING_COUNTER, NON_BLOCKING_ACTIONS);
347 
348     static final AtomicInteger BLOCKING_COUNTER = new AtomicInteger(0);
349     static final AtomicInteger BLOCKING_ACTIONS = new AtomicInteger(0);
350     static final Factory<Runnable> BLOCKING     = () -> disposableAction(Phase.FINISHED, BLOCKING_COUNTER, BLOCKING_ACTIONS);
351 
352     static void checkBlockingAction(TestCase0 r) {
353         MyValue val = new MyValue();
354         switch (phase) {
355             case IN_PROGRESS: {
356                 // Barrier during class initalization.
357                 r.run(NON_BLOCKING.get(), val);             // initializing thread
358                 checkBlocked(ON_BLOCK, ON_FAILURE, r); // different thread
359                 break;
360             }
361             case FINISHED: {
362                 // No barrier after class initalization is over.
363                 r.run(NON_BLOCKING.get(), val); // initializing thread
364                 checkNotBlocked(r);        // different thread
365                 break;
366             }
367             case INIT_FAILURE: {
368                 // Exception is thrown after class initialization failed.
369                 TestCase0 test = (action, valarg) -> execute(NoClassDefFoundError.class, () -> r.run(action, valarg));
370 
371                 test.run(NON_BLOCKING.get(), val); // initializing thread
372                 checkNotBlocked(test);        // different thread
373                 break;
374             }
375             default: throw new Error("wrong phase: " + phase);
376         }
377     }
378 
379     static void checkNonBlockingAction(TestCase0 r) {
380         r.run(NON_BLOCKING.get(), new MyValue()); // initializing thread
381         checkNotBlocked(r);        // different thread
382     }
383 
384     static <T> void checkNonBlockingAction(T recv, TestCase1<T> r) {
385         r.run(recv, NON_BLOCKING.get(), new MyValue());                  // initializing thread
386         checkNotBlocked((action, val) -> r.run(recv, action, val)); // different thread
387     }
388 
389     static void checkFailingAction(TestCase0 r) {
390         r.run(NON_BLOCKING.get(), new MyValue()); // initializing thread
391         checkNotBlocked(r);        // different thread
392     }
393 
394     static void triggerInitialization(Class<?> cls) {
395         try {
396             Class<?> loadedClass = Class.forName(cls.getName(), true, cls.getClassLoader());
397             if (loadedClass != cls) {
398                 throw new Error("wrong class");
399             }
400         } catch (ClassNotFoundException e) {
401             throw new Error(e);
402         }
403     }
404 
405     static void checkBlocked(Consumer<Thread> onBlockHandler, Thread.UncaughtExceptionHandler onException, TestCase0 r) {
406         Thread thr = new Thread(() -> {
407             try {
408                 r.run(BLOCKING.get(), new MyValue());
409                 System.out.println("Thread " + Thread.currentThread() + ": Finished successfully");
410             } catch(Throwable e) {
411                 System.out.println("Thread " + Thread.currentThread() + ": Exception thrown: " + e);
412                 if (!THROW) {
413                     e.printStackTrace();
414                 }
415                 throw e;
416             }
417         } );
418         thr.setUncaughtExceptionHandler(onException);
419 
420         thr.start();
421         try {
422             thr.join(100);
423 
424             dump(thr);
425             if (thr.isAlive()) {
426                 onBlockHandler.accept(thr); // blocked
427             } else {
428                 throw new AssertionError("not blocked");
429             }
430         } catch (InterruptedException e) {
431             throw new Error(e);
432         }
433     }
434 
435     static void checkNotBlocked(TestCase0 r) {
436         final Thread thr = new Thread(() -> r.run(NON_BLOCKING.get(), new MyValue()));
437         final Throwable[] ex = new Throwable[1];
438         thr.setUncaughtExceptionHandler((t, e) -> {
439             if (thr != t) {
440                 ex[0] = new Error("wrong thread: " + thr + " vs " + t);
441             } else {
442                 ex[0] = e;
443             }
444         });
445 
446         thr.start();
447         try {
448             thr.join(15_000);
449             if (thr.isAlive()) {
450                 dump(thr);
451                 throw new AssertionError("blocked");
452             }
453         } catch (InterruptedException e) {
454             throw new Error(e);
455         }
456 
457         if (ex[0] != null) {
458             throw new AssertionError("no exception expected", ex[0]);
459         }
460     }
461 
462     static void maybeThrow() {
463         if (THROW) {
464             changePhase(Phase.INIT_FAILURE);
465             throw new RuntimeException("failed class initialization");
466         }
467     }
468 
469     private static void dump(Thread thr) {
470         System.out.println("Thread: " + thr);
471         System.out.println("Thread state: " + thr.getState());
472         if (thr.isAlive()) {
473             for (StackTraceElement frame : thr.getStackTrace()) {
474                 System.out.println(frame);
475             }
476         } else {
477             if (FAILED_THREADS.containsKey(thr)) {
478                 System.out.println("Failed with an exception: ");
479                 FAILED_THREADS.get(thr).toString();
480             } else {
481                 System.out.println("Finished successfully");
482             }
483         }
484     }
485 
486     public static void main(String[] args) throws Exception {
487         Test.run();
488         System.out.println("TEST PASSED");
489     }
490 }