< prev index next >

src/jdk.management/share/classes/com/sun/management/internal/VirtualThreadSchedulerImpls.java

Print this page
@@ -22,10 +22,12 @@
   * or visit www.oracle.com if you need additional information or have any
   * questions.
   */
  package com.sun.management.internal;
  
+ import java.lang.reflect.InvocationTargetException;
+ import java.lang.reflect.Method;
  import java.util.concurrent.Executor;
  import java.util.concurrent.ForkJoinPool;
  import javax.management.ObjectName;
  import jdk.management.VirtualThreadSchedulerMXBean;
  import jdk.internal.access.JavaLangAccess;

@@ -88,61 +90,146 @@
      private static final class VirtualThreadSchedulerImpl extends BaseVirtualThreadSchedulerImpl {
          /**
           * Holder class for scheduler.
           */
          private static class Scheduler {
-             private static final Executor scheduler =
+             private static final Executor SCHEDULER =
                  SharedSecrets.getJavaLangAccess().virtualThreadDefaultScheduler();
              static Executor instance() {
-                 return scheduler;
+                 return SCHEDULER;
+             }
+         }
+ 
+         /**
+          * Reflective access to the optional methods of a custom scheduler. This is a
+          * holder class so the lookups only run once a non-ForkJoinPool scheduler is seen.
+          * Each constant is the public method of that name on the scheduler's class,
+          * or {@code null} if the class has no such method.
+          */
+         private static class SchedulerMethods {
+             private static final Method GET_PARALLELISM = findMethod("getParallelism");
+             private static final Method SET_PARALLELISM = findMethod("setParallelism", int.class);
+             private static final Method GET_POOL_SIZE = findMethod("getPoolSize");
+             private static final Method GET_MOUNTED_VTHREAD_COUNT = findMethod("getMountedVirtualThreadCount");
+             private static final Method GET_QUEUED_VTHREAD_COUNT = findMethod("getQueuedVirtualThreadCount");
+ 
+             /**
+              * Returns the public method with the given name and parameter types on
+              * the scheduler's class, or {@code null} if there is no such method.
+              */
+             static Method findMethod(String name, Class<?>... params) {
+                 try {
+                     return Scheduler.instance().getClass().getMethod(name, params);
+                 } catch (NoSuchMethodException e) {
+                     return null;  // operation not supported by this scheduler
+                 }
              }
          }
  
          @Override
          public int getParallelism() {
-             if (Scheduler.instance() instanceof ForkJoinPool pool) {
+             Executor scheduler = Scheduler.instance();
+             if (scheduler instanceof ForkJoinPool pool) {
                  return pool.getParallelism();
              }
-             throw new InternalError();  // should not get here
+ 
+             // custom scheduler
+             if (SchedulerMethods.GET_PARALLELISM instanceof Method m) {
+                 return ((Number) invokeSchedulerMethod(m)).intValue();
+             }
+ 
+             return -1;  // unknown
          }
  
          @Override
          public void setParallelism(int size) {
-             if (Scheduler.instance() instanceof ForkJoinPool pool) {
+             Executor scheduler = Scheduler.instance();
+             if (scheduler instanceof ForkJoinPool pool) {
                  pool.setParallelism(size);
                  if (pool.getPoolSize() < size) {
                      // FJ worker thread creation is on-demand
                      Thread.startVirtualThread(() -> { });
                  }
- 
                  return;
              }
-             throw new UnsupportedOperationException();  // should not get here
+ 
+             // custom scheduler
+             if (SchedulerMethods.SET_PARALLELISM instanceof Method m) {
+                 invokeSchedulerMethod(m, size);
+                 return;
+             }
+ 
+             throw new UnsupportedOperationException();
          }
  
          @Override
          public int getPoolSize() {
-             if (Scheduler.instance() instanceof ForkJoinPool pool) {
+             Executor scheduler = Scheduler.instance();
+             if (scheduler instanceof ForkJoinPool pool) {
                  return pool.getPoolSize();
              }
-             return -1;  // should not get here
+ 
+             // custom scheduler
+             if (SchedulerMethods.GET_POOL_SIZE instanceof Method m) {
+                 return ((Number) invokeSchedulerMethod(m)).intValue();
+             }
+ 
+             return -1;  // unknown
          }
  
          @Override
          public int getMountedVirtualThreadCount() {
-             if (Scheduler.instance() instanceof ForkJoinPool pool) {
+             Executor scheduler = Scheduler.instance();
+             if (scheduler instanceof ForkJoinPool pool) {
                  return pool.getActiveThreadCount();
              }
-             return -1;  // should not get here
+ 
+             // custom scheduler
+             if (SchedulerMethods.GET_MOUNTED_VTHREAD_COUNT instanceof Method m) {
+                 return ((Number) invokeSchedulerMethod(m)).intValue();
+             }
+ 
+             return -1;  // unknown
          }
  
          @Override
          public long getQueuedVirtualThreadCount() {
-             if (Scheduler.instance() instanceof ForkJoinPool pool) {
+             Executor scheduler = Scheduler.instance();
+             if (scheduler instanceof ForkJoinPool pool) {
                  return pool.getQueuedTaskCount() + pool.getQueuedSubmissionCount();
              }
-             return -1L;  // should not get here
+ 
+             // custom scheduler; its method may return int or long, so widen via Number
+             if (SchedulerMethods.GET_QUEUED_VTHREAD_COUNT instanceof Method m) {
+                 return ((Number) invokeSchedulerMethod(m)).longValue();
+             }
+ 
+             return -1L;  // unknown
+         }
+ 
+         /**
+          * Invokes a method of the custom scheduler with the given arguments. An
+          * unchecked exception or error thrown by the method is rethrown as is; any
+          * other exception is wrapped in a RuntimeException.
+          */
+         private static Object invokeSchedulerMethod(Method m, Object... args) {
+             try {
+                 return m.invoke(Scheduler.instance(), args);
+             } catch (InvocationTargetException e) {
+                 Throwable cause = e.getCause();
+                 if (cause instanceof RuntimeException re) {
+                     throw re;  // keep the original type, e.g. IllegalArgumentException
+                 }
+                 if (cause instanceof Error err) {
+                     throw err;
+                 }
+                 throw new RuntimeException(cause);
+             } catch (IllegalAccessException e) {
+                 // NOTE(review): thrown if the method is not accessible from this module,
+                 // e.g. a public method declared in a non-public scheduler class
+                 throw new RuntimeException(e);
+             }
          }
      }
  
      /**
       * Implementation of VirtualThreadSchedulerMXBean when virtual threads are backed
< prev index next >