1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @summary Stress test virtual threads with a variation of the Skynet 1M benchmark that uses
 27  *   a channel implementation based on object monitors. This variant uses a reduced number of
 28  *   100k virtual threads at the final level.
 29  * @requires vm.debug != true & vm.continuations
 30  * @run main/othervm/timeout=300 Skynet100kWithMonitors 50
 31  */
 32 
 33 /*
 34  * @test
 35  * @requires vm.debug == true & vm.continuations
 36  * @run main/othervm/timeout=300 Skynet100kWithMonitors 10
 37  */
 38 
 39 public class Skynet100kWithMonitors {
 40 
 41     public static void main(String[] args) {
 42         int iterations = (args.length) > 0 ? Integer.parseInt(args[0]) : 10;
 43         for (int i = 0; i < iterations; i++) {
 44             skynet(100_000, 4999950000L);
 45         }
 46     }
 47 
 48     static void skynet(int num, long expected) {
 49         long start = System.currentTimeMillis();
 50         var chan = new Channel<Long>();
 51 
 52         Thread.startVirtualThread(() -> skynet(chan, 0, num, 10));
 53 
 54         long sum = chan.receive();
 55         long end = System.currentTimeMillis();
 56         System.out.format("Result: %d in %s ms%n", sum, (end-start));
 57         if (sum != expected)
 58             throw new RuntimeException("Expected " + expected);
 59     }
 60 
 61     static void skynet(Channel<Long> result, int num, int size, int div) {
 62         if (size == 1) {
 63             result.send((long)num);
 64         } else {
 65             var chan = new Channel<Long>();
 66             for (int i = 0; i < div; i++) {
 67                 int subNum = num + i * (size / div);
 68                 Thread.startVirtualThread(() -> skynet(chan, subNum, size / div, div));
 69             }
 70             long sum = 0;
 71             for (int i = 0; i < div; i++) {
 72                 sum += chan.receive();
 73             }
 74             result.send(sum);
 75         }
 76     }
 77 
 78     static class Channel<T> {
 79         private final Object lock = new Object();
 80         private T element;
 81 
 82         Channel() {
 83         }
 84 
 85         void send(T e) {
 86             boolean interrupted = false;
 87             synchronized (lock) {
 88                 while (element != null) {
 89                     try {
 90                         lock.wait();
 91                     } catch (InterruptedException x) {
 92                         interrupted = true;
 93                     }
 94                 }
 95                 element = e;
 96                 lock.notifyAll();
 97             }
 98             if (interrupted)
 99                 Thread.currentThread().interrupt();
100         }
101 
102         T receive() {
103             T e;
104             boolean interrupted = false;
105             synchronized (lock) {
106                 while ((e = element) == null) {
107                     try {
108                         lock.wait();
109                     } catch (InterruptedException x) {
110                         interrupted = true;
111                     }
112                 }
113                 element = null;
114                 lock.notifyAll();
115             }
116             if (interrupted)
117                 Thread.currentThread().interrupt();
118             return e;
119         }
120     }
121 }