1 /*
  2  * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 #include <jni.h>
 25 #include <jvmti.h>
 26 #include <jvmti_common.hpp>
 27 #include <atomic>
 28 
 29 static jvmtiEnv *jvmti = nullptr;
 30 static jint error_count = 0;
 31 
 32 extern "C" JNIEXPORT jint JNICALL
 33 Agent_OnLoad(JavaVM *vm, char *options, void *reserved) {
 34   if (vm->GetEnv((void **)&jvmti, JVMTI_VERSION) != JNI_OK) {
 35     LOG("Could not initialize JVMTI\n");
 36     return JNI_ERR;
 37   }
 38   jvmtiCapabilities caps;
 39   memset(&caps, 0, sizeof(caps));
 40   caps.can_support_virtual_threads = 1;
 41   caps.can_suspend = 1;
 42   caps.can_signal_thread = 1;
 43   jvmtiError err = jvmti->AddCapabilities(&caps);
 44   if (err != JVMTI_ERROR_NONE) {
 45     LOG("JVMTI AddCapabilities error: %d\n", err);
 46     return JNI_ERR;
 47   }
 48 
 49   return JNI_OK;
 50 }
 51 
 52 
 53 extern "C" JNIEXPORT jboolean JNICALL
 54 Java_GetThreadStateMountedTest_trySuspendInWaitingState(JNIEnv* jni, jclass clazz, jthread vthread) {
 55   const int max_retries = 10;
 56   for (int i = 0; i < max_retries; i++) {
 57     // wait a bit
 58     sleep_ms(100);
 59 
 60     // suspend the thread
 61     LOG("suspend vthread (%d)\n", i);
 62     suspend_thread(jvmti, jni, vthread);
 63 
 64     jint state = get_thread_state(jvmti, jni, vthread);
 65     if ((state & JVMTI_THREAD_STATE_WAITING) != 0) {
 66       LOG("suspended in WAITING state\n");
 67       return JNI_TRUE;
 68     }
 69     LOG("suspended vthread is not waiting: state = %x (%s)\n", state, TranslateState(state));
 70     LOG("resume vthread\n");
 71     resume_thread(jvmti, jni, vthread);
 72   }
 73   LOG("ERROR: failed to suspend in WAITING state in %d tries\n", max_retries);
 74   return JNI_FALSE;
 75 
 76 }
 77 
 78 static void verify_thread_state(const char *name, JNIEnv* jni,
 79   jthread thread, jint expected_strong, jint expected_weak)
 80 {
 81   jint state = get_thread_state(jvmti, jni, thread);
 82   LOG("%s state(%x): %s\n", name, state, TranslateState(state));
 83   bool failed = false;
 84   // check 1: all expected_strong bits are set
 85   jint actual_strong = state & expected_strong;
 86   if (actual_strong != expected_strong) {
 87     failed = true;
 88     jint missed = expected_strong - actual_strong;
 89     LOG("  ERROR: some mandatory bits are not set (%x): %s\n",
 90         missed, TranslateState(missed));
 91   }
 92   // check 2: no bits other than (expected_strong | expected_weak) are set
 93   jint actual_full = state & (expected_strong | expected_weak);
 94   if (actual_full != state) {
 95     failed = true;
 96     jint unexpected = state - actual_full;
 97     LOG("  ERROR: some unexpected bits are set (%x): %s\n",
 98         unexpected, TranslateState(unexpected));
 99   }
100   // check 3: expected_weak checks
101   if (expected_weak != 0) {
102     // check 3a: at least 1 bit from expected_weak is set
103     if ((state & expected_weak) == 0) {
104       failed = true;
105       LOG("  ERROR: no expected 'weak' bits are set\n");
106     }
107     // check 3b: not all expected_weak bits are set
108     if ((state & expected_weak) == expected_weak) {
109       failed = true;
110       LOG("  ERROR: all expected 'weak' bits are set\n");
111     }
112   }
113 
114   if (failed) {
115     LOG("  expected 'strong' state (%x): %s\n", expected_strong, TranslateState(expected_strong));
116     LOG("  expected 'weak' state (%x): %s\n", expected_weak, TranslateState(expected_weak));
117     error_count++;
118   }
119 }
120 
121 extern "C" JNIEXPORT void JNICALL
122 Java_GetThreadStateMountedTest_testThread(
123   JNIEnv* jni, jclass clazz, jthread vthread, jboolean is_vthread_suspended,
124   jboolean test_interrupt,
125   jint expected_strong, jint expected_weak)
126 {
127   jint exp_ct_state = JVMTI_THREAD_STATE_ALIVE
128                       | JVMTI_THREAD_STATE_WAITING
129                       | JVMTI_THREAD_STATE_WAITING_INDEFINITELY;
130   jint exp_vt_state = expected_strong
131                       | JVMTI_THREAD_STATE_ALIVE;
132 
133   jthread cthread = get_carrier_thread(jvmti, jni, vthread);
134 
135   verify_thread_state("cthread", jni, cthread,
136                       exp_ct_state, 0);
137   verify_thread_state("vthread", jni, vthread,
138                       exp_vt_state | (is_vthread_suspended ? JVMTI_THREAD_STATE_SUSPENDED : 0),
139                       expected_weak);
140 
141   // suspend ctread and verify
142   LOG("suspend cthread\n");
143   suspend_thread(jvmti, jni, cthread);
144   verify_thread_state("cthread", jni, cthread,
145                       exp_ct_state | JVMTI_THREAD_STATE_SUSPENDED, 0);
146   verify_thread_state("vthread", jni, vthread,
147                       exp_vt_state | (is_vthread_suspended ? JVMTI_THREAD_STATE_SUSPENDED : 0),
148                       expected_weak);
149 
150   // suspend vthread and verify
151   if (!is_vthread_suspended) {
152     LOG("suspend vthread\n");
153     suspend_thread(jvmti, jni, vthread);
154     verify_thread_state("cthread", jni, cthread,
155                         exp_ct_state | JVMTI_THREAD_STATE_SUSPENDED, 0);
156     verify_thread_state("vthread", jni, vthread,
157                         exp_vt_state | JVMTI_THREAD_STATE_SUSPENDED, expected_weak);
158   }
159 
160   // resume cthread and verify
161   LOG("resume cthread\n");
162   resume_thread(jvmti, jni, cthread);
163   verify_thread_state("cthread", jni, cthread,
164                       exp_ct_state, 0);
165   verify_thread_state("vthread", jni, vthread,
166                       exp_vt_state | JVMTI_THREAD_STATE_SUSPENDED, expected_weak);
167 
168   if (test_interrupt) {
169     // interrupt vthread (while it's suspended)
170     LOG("interrupt vthread\n");
171     check_jvmti_status(jni, jvmti->InterruptThread(vthread), "error in JVMTI InterruptThread");
172     verify_thread_state("cthread", jni, cthread,
173                         exp_ct_state, 0);
174     verify_thread_state("vthread", jni, vthread,
175                         exp_vt_state | JVMTI_THREAD_STATE_SUSPENDED | JVMTI_THREAD_STATE_INTERRUPTED,
176                         expected_weak);
177   }
178 
179   // resume vthread
180   LOG("resume vthread\n");
181   resume_thread(jvmti, jni, vthread);
182 
183   // don't verify thread state after InterruptThread and ResumeThread
184 }
185 
186 extern "C" JNIEXPORT int JNICALL
187 Java_GetThreadStateMountedTest_getErrorCount(JNIEnv* jni, jclass clazz) {
188   return error_count;
189 }
190 
191 
// Set by endWait() to release the polling loop in waitInNative().
static std::atomic<bool> time_to_exit{false};
193 
194 extern "C" JNIEXPORT void JNICALL
195 Java_GetThreadStateMountedTest_runFromNative(JNIEnv* jni, jclass clazz, jobject runnable) {
196   jmethodID mid = jni->GetStaticMethodID(clazz, "runUpcall", "(Ljava/lang/Runnable;)V");
197   if (mid == nullptr) {
198     jni->FatalError("failed to get runUpcall method");
199     return;
200   }
201   jni->CallStaticVoidMethod(clazz, mid, runnable);
202 }
203 
204 extern "C" JNIEXPORT void JNICALL
205 Java_GetThreadStateMountedTest_waitInNative(JNIEnv* jni, jclass clazz) {
206   // Notify main thread that we are ready
207   jfieldID fid = jni->GetStaticFieldID(clazz, "waitInNativeReady", "Z");
208   if (fid == nullptr) {
209     jni->FatalError("cannot get waitInNativeReady field");
210     return;
211   }
212   jni->SetStaticBooleanField(clazz, fid, JNI_TRUE);
213 
214   while (!time_to_exit) {
215     sleep_ms(100);
216   }
217 }
218 
219 extern "C" JNIEXPORT void JNICALL
220 Java_GetThreadStateMountedTest_endWait(JNIEnv* jni, jclass clazz) {
221   time_to_exit = true;
222 }