1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 #include <jvmti.h>
 25 #include "jvmti_common.hpp"
 26 
 27 static jvmtiEnv *jvmti = nullptr;
 28 
 29 extern "C" JNIEXPORT void JNICALL
 30 Java_ValueHeapwalkingTest_setTag(JNIEnv* jni_env, jclass clazz, jobject object, jlong tag) {
 31   jvmtiError err = jvmti->SetTag(object, tag);
 32   check_jvmti_error(err, "could not set tag");
 33 }
 34 
 35 extern "C" JNIEXPORT jlong JNICALL
 36 Java_ValueHeapwalkingTest_getTag(JNIEnv* jni_env, jclass clazz, jobject object) {
 37   jlong tag;
 38   check_jvmti_error(jvmti->GetTag(object, &tag), "could not get tag");
 39   return tag;
 40 }
 41 
 42 const int TAG_VALUE_CLASS = 1;
 43 const int TAG_VALUE2_CLASS = 2;
 44 const int TAG_HOLDER_CLASS = 3;
 45 const int TAG_VALUE_ARRAY = 4;
 46 const int TAG_VALUE3_ARRAY = 5;
 47 const int MAX_TAG = 5;
 48 const int START_TAG = 10; // start value for tagging objects
 49 
 50 static const char* tag_str(jlong tag) {
 51   switch (tag) {
 52   case 0: return "None";
 53   case TAG_VALUE_CLASS: return "Value class";
 54   case TAG_VALUE2_CLASS: return "Value2 class";
 55   case TAG_HOLDER_CLASS: return "ValueHolder class";
 56   case TAG_VALUE_ARRAY: return "Value[] object";
 57   case TAG_VALUE3_ARRAY: return "Value2[] object";
 58   }
 59   return "Unknown";
 60 }
 61 
 62 struct Callback_Data  {
 63   // Updated by heap_iteration_callback.
 64   jint counters[MAX_TAG + 1];
 65   // Updated by heap_reference_callback.
 66   jint ref_counters[MAX_TAG + 1][MAX_TAG + 1];
 67   // Updated by primitive_field_callback.
 68   jint primitive_counters[MAX_TAG + 1];
 69   jlong tag_counter;
 70 };
 71 
 72 static Callback_Data callbackData;
 73 
 74 extern "C" JNIEXPORT void JNICALL
 75 Java_ValueHeapwalkingTest_reset(JNIEnv* jni_env, jclass clazz) {
 76   memset(&callbackData, 0, sizeof(callbackData));
 77   callbackData.tag_counter = START_TAG;
 78 }
 79 
 80 extern "C" JNIEXPORT jint JNICALL
 81 Java_ValueHeapwalkingTest_count(JNIEnv* jni_env, jclass clazz, jint tag) {
 82   return callbackData.counters[tag];
 83 }
 84 
 85 extern "C" JNIEXPORT jint JNICALL
 86 Java_ValueHeapwalkingTest_refCount(JNIEnv* jni_env, jclass clazz, jint fromTag, jint toTag) {
 87   return callbackData.ref_counters[fromTag][toTag];
 88 }
 89 
 90 extern "C" JNIEXPORT jint JNICALL
 91 Java_ValueHeapwalkingTest_primitiveFieldCount(JNIEnv* jni_env, jclass clazz, jint tag) {
 92   return callbackData.primitive_counters[tag];
 93 }
 94 
 95 extern "C" JNIEXPORT jlong JNICALL
 96 Java_ValueHeapwalkingTest_getMaxTag(JNIEnv* jni_env, jclass clazz) {
 97   return callbackData.tag_counter;
 98 }
 99 
100 static jlong safe_deref(jlong* ref) {
101     return ref == nullptr ? 0 : *ref;
102 }
103 
104 static jint JNICALL
105 heap_iteration_callback(jlong class_tag,
106                         jlong size,
107                         jlong* tag_ptr,
108                         jint length,
109                         void* user_data) {
110   Callback_Data* data = (Callback_Data*)user_data;
111 
112   if (class_tag != 0 && class_tag <= MAX_TAG) {
113     data->counters[class_tag]++;
114     printf("heap_iteration_callback: class_tag = %d (%s), tag = %d (%s), length = %d\n",
115            (int)class_tag, tag_str(class_tag), (int)*tag_ptr, tag_str(*tag_ptr), length);
116     fflush(nullptr);
117   }
118   return 0;
119 }
120 
121 static jint JNICALL
122 heap_reference_callback(jvmtiHeapReferenceKind reference_kind,
123                         const jvmtiHeapReferenceInfo* reference_info,
124                         jlong class_tag,
125                         jlong referrer_class_tag,
126                         jlong size,
127                         jlong* tag_ptr,
128                         jlong* referrer_tag_ptr,
129                         jint length,
130                         void* user_data) {
131   Callback_Data* data = (Callback_Data*)user_data;
132 
133   jlong tag = class_tag;
134   if (tag == 0 && *tag_ptr != 0 && *tag_ptr <= MAX_TAG) {
135     tag = *tag_ptr;
136   }
137   jlong referrer_tag = referrer_class_tag;
138   if (referrer_tag == 0 && safe_deref(referrer_tag_ptr) != 0 && safe_deref(referrer_tag_ptr) <= MAX_TAG) {
139     referrer_tag = *referrer_tag_ptr;
140   }
141 
142   if (tag != 0 && referrer_tag != 0) {
143     // For testing we count only JVMTI_HEAP_REFERENCE_FIELD and JVMTI_HEAP_REFERENCE_ARRAY_ELEMENT references.
144     if (reference_kind == JVMTI_HEAP_REFERENCE_FIELD || reference_kind == JVMTI_HEAP_REFERENCE_ARRAY_ELEMENT) {
145       data->ref_counters[referrer_tag][tag]++;
146     }
147 
148     jlong cur_tag = *tag_ptr;
149     char new_tag_str[64] = {};
150     if (*tag_ptr == 0) { // i.e. class_tag != 0, but the object is untagged
151       *tag_ptr = ++data->tag_counter;
152       snprintf(new_tag_str, sizeof(new_tag_str), ", set tag to %d", (int)*tag_ptr);
153     }
154     printf("heap_reference_callback: kind = %d, class_tag = %d (%s), tag = %d (%s), referrer_tag = %d (%s) %s\n",
155            (int)reference_kind, (int)class_tag, tag_str(class_tag), (int)cur_tag, tag_str(*tag_ptr),
156            (int)referrer_tag, tag_str(referrer_tag), new_tag_str);
157     fflush(nullptr);
158   }
159 
160   return JVMTI_VISIT_OBJECTS;
161 }
162 
163 static jint JNICALL
164 primitive_field_callback(jvmtiHeapReferenceKind kind,
165                          const jvmtiHeapReferenceInfo* info,
166                          jlong object_class_tag,
167                          jlong* object_tag_ptr,
168                          jvalue value,
169                          jvmtiPrimitiveType value_type,
170                          void* user_data) {
171   Callback_Data* data = (Callback_Data*)user_data;
172   if (object_class_tag != 0) {
173     char value_str[64] = {};
174     switch (value_type) {
175     case JVMTI_PRIMITIVE_TYPE_BOOLEAN: snprintf(value_str, sizeof(value_str), "(boolean) %s", value.z ? "true" : "false"); break;
176     case JVMTI_PRIMITIVE_TYPE_BYTE:    snprintf(value_str, sizeof(value_str), "(byte) %d", value.b); break;
177     case JVMTI_PRIMITIVE_TYPE_CHAR:    snprintf(value_str, sizeof(value_str), "(char) %c", value.c); break;
178     case JVMTI_PRIMITIVE_TYPE_SHORT:   snprintf(value_str, sizeof(value_str), "(short): %d", value.s); break;
179     case JVMTI_PRIMITIVE_TYPE_INT:     snprintf(value_str, sizeof(value_str), "(int): %d", value.i); break;
180     case JVMTI_PRIMITIVE_TYPE_LONG:    snprintf(value_str, sizeof(value_str), "(long): %lld", (long long)value.j); break;
181     case JVMTI_PRIMITIVE_TYPE_FLOAT:   snprintf(value_str, sizeof(value_str), "(float): %f", value.f); break;
182     case JVMTI_PRIMITIVE_TYPE_DOUBLE:  snprintf(value_str, sizeof(value_str), "(double): %f", value.d);  break;
183     default: snprintf(value_str, sizeof(value_str), "invalid_type %d (%c)", (int)value_type, (char)value_type);
184     }
185 
186     if (object_class_tag != 0 && object_class_tag <= MAX_TAG) {
187       data->primitive_counters[object_class_tag]++;
188       if (*object_tag_ptr != 0) {
189         *object_tag_ptr = *object_tag_ptr;
190       }
191     }
192 
193     printf("primitive_field_callback: kind = %d, class_tag = %d (%s), tag = %d (%s), value = %s\n",
194            (int)kind, (int)object_class_tag, tag_str(object_class_tag),
195            (int)*object_tag_ptr, tag_str(*object_tag_ptr), value_str);
196     fflush(nullptr);
197   }
198   return 0;
199 }
200 
201 static jint JNICALL
202 array_primitive_value_callback(jlong class_tag,
203                                jlong size,
204                                jlong* tag_ptr,
205                                jint element_count,
206                                jvmtiPrimitiveType element_type,
207                                const void* elements,
208                                void* user_data) {
209   Callback_Data* data = (Callback_Data*)user_data;
210   if (class_tag != 0 || *tag_ptr != 0) {
211     printf("array_primitive_value_callback: class_tag = %d (%s), tag = %d (%s), element_count = %d, element_type = %c\n",
212            (int)class_tag, tag_str(class_tag), (int)*tag_ptr, tag_str(*tag_ptr), element_count, (char)element_type);
213     fflush(nullptr);
214   }
215   return 0;
216 }
217 
218 static jint JNICALL
219 string_primitive_value_callback(jlong class_tag,
220                                 jlong size,
221                                 jlong* tag_ptr,
222                                 const jchar* value,
223                                 jint value_length,
224                                 void* user_data) {
225   Callback_Data* data = (Callback_Data*)user_data;
226   if (class_tag != 0 || *tag_ptr != 0) {
227     jchar value_copy[1024] = {}; // fills with 0
228     if (value_length > 1023) {
229       value_length = 1023;
230     }
231     memcpy(value_copy, value, value_length * sizeof(jchar));
232     printf("string_primitive_value_callback: class_tag = %d (%s), tag = %d (%s), value=\"%ls\"\n",
233            (int)class_tag, tag_str(class_tag), (int)*tag_ptr, tag_str(*tag_ptr), (wchar_t*)value_copy);
234     fflush(nullptr);
235   }
236   return 0;
237 }
238 
239 extern "C" JNIEXPORT void JNICALL
240 Java_ValueHeapwalkingTest_followReferences(JNIEnv* jni_env, jclass clazz) {
241   jvmtiHeapCallbacks callbacks = {};
242   callbacks.heap_iteration_callback = heap_iteration_callback;
243   callbacks.heap_reference_callback = heap_reference_callback;
244   callbacks.primitive_field_callback = primitive_field_callback;
245   callbacks.array_primitive_value_callback = array_primitive_value_callback;
246   callbacks.string_primitive_value_callback = string_primitive_value_callback;
247 
248   jvmtiError err = jvmti->FollowReferences(0 /* filter nothing */,
249                                            nullptr /* no class filter */,
250                                            nullptr /* no initial object, follow roots */,
251                                            &callbacks,
252                                            &callbackData);
253   check_jvmti_error(err, "FollowReferences failed");
254 }
255 
256 extern "C" JNIEXPORT void JNICALL
257 Java_ValueHeapwalkingTest_iterateThroughHeap(JNIEnv* jni_env, jclass clazz) {
258   jvmtiHeapCallbacks callbacks = {};
259   callbacks.heap_iteration_callback = heap_iteration_callback;
260   callbacks.heap_reference_callback = heap_reference_callback;
261   callbacks.primitive_field_callback = primitive_field_callback;
262   callbacks.array_primitive_value_callback = array_primitive_value_callback;
263   callbacks.string_primitive_value_callback = string_primitive_value_callback;
264 
265   jvmtiError err = jvmti->IterateThroughHeap(0 /* filter nothing */,
266                                              nullptr /* no class filter */,
267                                              &callbacks,
268                                              &callbackData);
269   check_jvmti_error(err, "IterateThroughHeap failed");
270 }
271 
272 extern "C" JNIEXPORT jint JNICALL
273 Java_ValueHeapwalkingTest_getObjectWithTags(JNIEnv* jni_env, jclass clazz, jlong minTag, jlong maxTag, jobjectArray objects, jlongArray tags) {
274   jsize len = jni_env->GetArrayLength(objects);
275 
276   jint tag_count = (jint)(maxTag - minTag + 1);
277   jlong* scan_tags = nullptr;
278   check_jvmti_error(jvmti->Allocate(tag_count * sizeof(jlong), (unsigned char**)&scan_tags),
279                     "Allocate failed");
280 
281   for (jlong i = 0; i < tag_count; i++) {
282       scan_tags[i] = i + minTag;
283   }
284 
285   jint count = 0;
286   jobject* object_result = nullptr;
287   jlong* tag_result = nullptr;
288 
289   check_jvmti_error(jvmti->GetObjectsWithTags(tag_count, scan_tags, &count, &object_result, &tag_result),
290                     "GetObjectsWithTags failed");
291 
292   if (count > len) {
293     printf("GetObjectsWithTags returned too many entries: %d (object length is %d)\n", count, (int)len);
294     fflush(nullptr);
295     abort();
296   }
297 
298   for (jint i = 0; i < count; i++) {
299     jni_env->SetObjectArrayElement(objects, i, object_result[i]);
300   }
301   jni_env->SetLongArrayRegion(tags, 0, count, tag_result);
302 
303   jvmti->Deallocate((unsigned char*)scan_tags);
304   jvmti->Deallocate((unsigned char*)object_result);
305   jvmti->Deallocate((unsigned char*)tag_result);
306 
307   return count;
308 }
309 
310 extern "C" JNIEXPORT jint JNICALL Agent_OnLoad(JavaVM *vm, char *options, void *reserved) {
311   if (vm->GetEnv(reinterpret_cast<void **>(&jvmti), JVMTI_VERSION) != JNI_OK || !jvmti) {
312     LOG("Could not initialize JVMTI\n");
313     abort();
314   }
315   jvmtiCapabilities capabilities;
316   memset(&capabilities, 0, sizeof(capabilities));
317   capabilities.can_tag_objects = 1;
318   check_jvmti_error(jvmti->AddCapabilities(&capabilities), "adding capabilities");
319   return JVMTI_ERROR_NONE;
320 }
321