1 /*
  2  * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package sun.nio.ch;
 26 
 27 import java.io.FileDescriptor;
 28 import java.io.IOException;
 29 import java.nio.channels.Pipe;
 30 import java.util.ArrayDeque;
 31 import java.util.Deque;
 32 import java.util.HashMap;
 33 import java.util.Map;
 34 import jdk.internal.misc.Unsafe;
 35 import jdk.internal.access.JavaLangAccess;
 36 import jdk.internal.access.SharedSecrets;
 37 
 38 /**
 39  * Implementation of Poller based on WSAPoll.
 40  *
 41  * KB4550945 needs to be installed, otherwise a socket registered to poll for
 42  * a connect to complete will not be polled when the connection cannot be
 43  * established.
 44  */
 45 
 46 class WSAPollPoller extends Poller {
 47     private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();
 48     private static final Unsafe UNSAFE = Unsafe.getUnsafe();
 49     private static final long TEMP_BUF = UNSAFE.allocateMemory(1);
 50     private static final NativeDispatcher ND = new SocketDispatcher();
 51 
 52     // initial capacity of poll array
 53     private static final int INITIAL_CAPACITY = 16;
 54 
 55     // poll array, grows as needed
 56     private long pollArrayAddress;
 57     private int pollArrayCapacity;  // allocated
 58     private int pollArraySize;      // in use
 59 
 60     // maps file descriptor to index in poll array
 61     private final Map<Integer, Integer> fdToIndex = new HashMap<>();
 62 
 63     // pipe and file descriptors used for wakeup
 64     private final Object wakeupLock = new Object();
 65     private boolean wakeupTriggered;
 66     private final Pipe pipe;
 67     private final FileDescriptor fd0, fd1;
 68 
 69     // registration updates
 70     private final Object updateLock = new Object();
 71     private final Deque<Integer> registerQueue = new ArrayDeque<>();
 72 
 73     // deregistration (stop) requests
 74     private static class DeregisterRequest {
 75         final int fdVal;
 76         DeregisterRequest(int fdVal) {
 77             this.fdVal = fdVal;
 78         }
 79         int fdVal() {
 80             return fdVal;
 81         }
 82     }
 83     private final Deque<DeregisterRequest> deregisterQueue = new ArrayDeque<>();
 84 
 85     /**
 86      * Creates a poller to support reading (POLLIN) or writing (POLLOUT)
 87      * operations.
 88      */
 89     WSAPollPoller(boolean read) throws IOException {
 90         super(read);
 91 
 92         this.pollArrayAddress = WSAPoll.allocatePollArray(INITIAL_CAPACITY);
 93         this.pollArrayCapacity = INITIAL_CAPACITY;
 94 
 95         // wakeup support
 96         this.pipe = makePipe();
 97         SourceChannelImpl source = (SourceChannelImpl) pipe.source();
 98         SinkChannelImpl sink = (SinkChannelImpl) pipe.sink();
 99         this.fd0 = source.getFD();
100         this.fd1 = sink.getFD();
101 
102         // element 0 in poll array is for wakeup.
103         putDescriptor(0, source.getFDVal());
104         putEvents(0, Net.POLLIN);
105         putRevents(0, (short) 0);
106         pollArraySize = 1;
107     }
108 
109     /**
110      * Register the file descriptor.
111      */
112     @Override
113     void implRegister(int fdVal) {
114         Integer fd = Integer.valueOf(fdVal);
115         synchronized (updateLock) {
116             registerQueue.add(fd);
117         }
118         wakeup();
119     }
120 
121     /**
122      * Deregister the file descriptor. This method waits until the poller thread
123      * has removed the file descriptor from the poll array.
124      */
125     @Override
126     void implDeregister(int fdVal) {
127         boolean interrupted = false;
128         var request = new DeregisterRequest(fdVal);
129         synchronized (request) {
130             synchronized (updateLock) {
131                 deregisterQueue.add(request);
132             }
133             wakeup();
134             try {
135                 request.wait();
136             } catch (InterruptedException e) {
137                 interrupted = true;
138             }
139         }
140         if (interrupted) {
141             Thread.currentThread().interrupt();
142         }
143     }
144 
145     @Override
146     int poll(int timeout) throws IOException {
147         // process any updates
148         synchronized (updateLock) {
149             processRegisterQueue();
150             processDeregisterQueue();
151         }
152 
153         // poll for wakeup and/or events
154         int numPolled = WSAPoll.poll(pollArrayAddress, pollArraySize, timeout);
155         boolean polledWakeup = (getRevents(0) != 0);
156         if (polledWakeup) {
157             numPolled--;
158         }
159         processEvents(numPolled);
160 
161         // clear wakeup
162         if (polledWakeup) {
163             clearWakeup();
164         }
165 
166         return numPolled;
167     }
168 
169     /**
170      * Process the queue of file descriptors to poll
171      */
172     private void processRegisterQueue() {
173         assert Thread.holdsLock(updateLock);
174         Integer fd;
175         while ((fd = registerQueue.pollFirst()) != null) {
176             short events = (reading()) ? Net.POLLIN : Net.POLLOUT;
177             int index = add(fd, events);
178             Integer previous = fdToIndex.put(fd, index);
179             assert previous == null;
180         }
181     }
182 
183     /**
184      * Process the queue of file descriptors to stop polling
185      */
186     private void processDeregisterQueue() {
187         assert Thread.holdsLock(updateLock);
188         DeregisterRequest request;
189         while ((request = deregisterQueue.pollFirst()) != null) {
190             Integer index = fdToIndex.remove(request.fdVal);
191             if (index != null) {
192                 remove(index);
193             }
194             synchronized (request) {
195                 request.notifyAll();
196             }
197         }
198     }
199 
200     /**
201      * Process the polled events, skipping the first (0) entry in the poll array
202      * as that is used by the wakeup mechanism.
203      *
204      * @param numPolled the number of polled sockets in the array (from index 1)
205      */
206     private void processEvents(int numPolled) {
207         int index = 1;
208         int remaining = numPolled;
209         while (index < pollArraySize && remaining > 0) {
210             short revents = getRevents(index);
211             if (revents != 0) {
212                 int fd = getDescriptor(index);
213                 assert fdToIndex.get(fd) == index;
214                 polled(fd);
215                 remove(index);
216                 fdToIndex.remove(fd);
217                 remaining--;
218             } else {
219                 index++;
220             }
221         }
222     }
223 
224     /**
225      * Wake up the poller thread
226      */
227     private void wakeup() {
228         synchronized (wakeupLock) {
229             if (!wakeupTriggered) {
230                 try {
231                     ND.write(fd1, TEMP_BUF, 1);
232                 } catch (IOException ioe) {
233                     throw new InternalError(ioe);
234                 }
235                 wakeupTriggered = true;
236             }
237         }
238     }
239 
240     /**
241      * Clear the wakeup event
242      */
243     private void clearWakeup() throws IOException {
244         synchronized (wakeupLock) {
245             ND.read(fd0, TEMP_BUF, 1);
246             putRevents(0, (short) 0);
247             wakeupTriggered = false;
248         }
249     }
250 
251     /**
252      * Add a pollfd entry to the poll array.
253      *
254      * @return the index of the pollfd entry in the poll array
255      */
256     private int add(int fd, short events) {
257         expandIfNeeded();
258         int index = pollArraySize;
259         assert index > 0;
260         putDescriptor(index, fd);
261         putEvents(index, events);
262         putRevents(index, (short) 0);
263         pollArraySize++;
264         return index;
265     }
266 
267     /**
268      * Removes a pollfd entry from the poll array.
269      */
270     private void remove(int index) {
271         assert index > 0 && index < pollArraySize;
272 
273         // replace pollfd at index with the last pollfd in array
274         int lastIndex = pollArraySize - 1;
275         if (lastIndex != index) {
276             int lastFd = getDescriptor(lastIndex);
277             short lastEvents = getEvents(lastIndex);
278             short lastRevents = getRevents(lastIndex);
279             putDescriptor(index, lastFd);
280             putEvents(index, lastEvents);
281             putRevents(index, lastRevents);
282 
283             assert fdToIndex.get(lastFd) == lastIndex;
284             fdToIndex.put(lastFd, index);
285         }
286         pollArraySize--;
287     }
288 
289     /**
290      * Expand poll array if at capacity.
291      */
292     private void expandIfNeeded() {
293         if (pollArraySize == pollArrayCapacity) {
294             int newCapacity = pollArrayCapacity + INITIAL_CAPACITY;
295             pollArrayAddress = WSAPoll.reallocatePollArray(pollArrayAddress, pollArrayCapacity, newCapacity);
296             pollArrayCapacity = newCapacity;
297         }
298     }
299 
300     /**
301      * Returns a PipeImpl. The creation is done on the carrier thread to avoid
302      * recursive parking when the loopback connection is created.
303      */
304     private static PipeImpl makePipe() throws IOException {
305         try {
306             return JLA.executeOnCarrierThread(() -> new PipeImpl(null, true, false));
307         } catch (IOException ioe) {
308             throw ioe;
309         } catch (Throwable e) {
310             throw new InternalError(e);
311         }
312     }
313 
314     private void putDescriptor(int i, int fd) {
315         WSAPoll.putDescriptor(pollArrayAddress, i, fd);
316     }
317 
318     private int getDescriptor(int i) {
319         return WSAPoll.getDescriptor(pollArrayAddress, i);
320     }
321 
322     private void putEvents(int i, short events) {
323         WSAPoll.putEvents(pollArrayAddress, i, events);
324     }
325 
326     private short getEvents(int i) {
327         return WSAPoll.getEvents(pollArrayAddress, i);
328     }
329 
330     private void putRevents(int i, short revents) {
331         WSAPoll.putRevents(pollArrayAddress, i, revents);
332     }
333 
334     private short getRevents(int i) {
335         return WSAPoll.getRevents(pollArrayAddress, i);
336     }
337 }