1 /*
  2  * Copyright (c) 2019, 2021, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package sun.nio.ch;
 26 
 27 import java.io.FileDescriptor;
 28 import java.io.IOException;
 29 import java.nio.channels.Pipe;
 30 import java.util.ArrayDeque;
 31 import java.util.Deque;
 32 import java.util.HashMap;
 33 import java.util.Map;
 34 import jdk.internal.misc.Unsafe;
 35 import jdk.internal.access.JavaLangAccess;
 36 import jdk.internal.access.SharedSecrets;
 37 
 38 /**
 39  * Implementation of Poller based on WSAPoll.
 40  *
 41  * KB4550945 needs to be installed, otherwise a socket registered to poll for
 42  * a connect to complete will not be polled when the connection cannot be
 43  * established.
 44  */
 45 
 46 class WSAPollPoller extends Poller {
 47     private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();
 48     private static final Unsafe UNSAFE = Unsafe.getUnsafe();
 49     private static final long TEMP_BUF = UNSAFE.allocateMemory(1);
 50     private static final NativeDispatcher ND = new SocketDispatcher();
 51 
 52     // initial capacity of poll array
 53     private static final int INITIAL_CAPACITY = 16;
 54 
 55     // true if this is a poller for reading (POLLIN), false for writing (POLLOUT)
 56     private final boolean read;
 57 
 58     // poll array, grows as needed
 59     private long pollArrayAddress;
 60     private int pollArrayCapacity;  // allocated
 61     private int pollArraySize;      // in use
 62 
 63     // maps file descriptor to index in poll array
 64     private final Map<Integer, Integer> fdToIndex = new HashMap<>();
 65 
 66     // pipe and file descriptors used for wakeup
 67     private final Object wakeupLock = new Object();
 68     private boolean wakeupTriggered;
 69     private final Pipe pipe;
 70     private final FileDescriptor fd0, fd1;
 71 
 72     // registration updates
 73     private final Object updateLock = new Object();
 74     private final Deque<Integer> registerQueue = new ArrayDeque<>();
 75 
 76     // deregistration (stop) requests
 77     private static class DeregisterRequest {
 78         final int fdVal;
 79         DeregisterRequest(int fdVal) {
 80             this.fdVal = fdVal;
 81         }
 82         int fdVal() {
 83             return fdVal;
 84         }
 85     }
 86     private final Deque<DeregisterRequest> deregisterQueue = new ArrayDeque<>();
 87 
 88     /**
 89      * Creates a poller to support reading (POLLIN) or writing (POLLOUT)
 90      * operations.
 91      */
 92     WSAPollPoller(boolean read) throws IOException {
 93         this.read = read;
 94 
 95         this.pollArrayAddress = WSAPoll.allocatePollArray(INITIAL_CAPACITY);
 96         this.pollArrayCapacity = INITIAL_CAPACITY;
 97 
 98         // wakeup support
 99         this.pipe = makePipe();
100         SourceChannelImpl source = (SourceChannelImpl) pipe.source();
101         SinkChannelImpl sink = (SinkChannelImpl) pipe.sink();
102         this.fd0 = source.getFD();
103         this.fd1 = sink.getFD();
104 
105         // element 0 in poll array is for wakeup.
106         putDescriptor(0, source.getFDVal());
107         putEvents(0, Net.POLLIN);
108         putRevents(0, (short) 0);
109         pollArraySize = 1;
110     }
111 
112     /**
113      * Register the file descriptor.
114      */
115     @Override
116     void implRegister(int fdVal) {
117         Integer fd = Integer.valueOf(fdVal);
118         synchronized (updateLock) {
119             registerQueue.add(fd);
120         }
121         wakeup();
122     }
123 
124     /**
125      * Deregister the file descriptor. This method waits until the poller thread
126      * has removed the file descriptor from the poll array.
127      */
128     @Override
129     void implDeregister(int fdVal) {
130         boolean interrupted = false;
131         var request = new DeregisterRequest(fdVal);
132         synchronized (request) {
133             synchronized (updateLock) {
134                 deregisterQueue.add(request);
135             }
136             wakeup();
137             try {
138                 request.wait();
139             } catch (InterruptedException e) {
140                 interrupted = true;
141             }
142         }
143         if (interrupted) {
144             Thread.currentThread().interrupt();
145         }
146     }
147 
148     @Override
149     int poll(int timeout) throws IOException {
150         // process any updates
151         synchronized (updateLock) {
152             processRegisterQueue();
153             processDeregisterQueue();
154         }
155 
156         // poll for wakeup and/or events
157         int numPolled = WSAPoll.poll(pollArrayAddress, pollArraySize, timeout);
158         boolean polledWakeup = (getRevents(0) != 0);
159         if (polledWakeup) {
160             numPolled--;
161         }
162         processEvents(numPolled);
163 
164         // clear wakeup
165         if (polledWakeup) {
166             clearWakeup();
167         }
168 
169         return numPolled;
170     }
171 
172     /**
173      * Process the queue of file descriptors to poll
174      */
175     private void processRegisterQueue() {
176         assert Thread.holdsLock(updateLock);
177         Integer fd;
178         while ((fd = registerQueue.pollFirst()) != null) {
179             short events = (read) ? Net.POLLIN : Net.POLLOUT;
180             int index = add(fd, events);
181             fdToIndex.put(fd, index);
182         }
183     }
184 
185     /**
186      * Process the queue of file descriptors to stop polling
187      */
188     private void processDeregisterQueue() {
189         assert Thread.holdsLock(updateLock);
190         DeregisterRequest request;
191         while ((request = deregisterQueue.pollFirst()) != null) {
192             Integer index = fdToIndex.remove(request.fdVal);
193             if (index != null) {
194                 remove(index);
195             }
196             synchronized (request) {
197                 request.notifyAll();
198             }
199         }
200     }
201 
202     /**
203      * Process the polled events, skipping the first (0) entry in the poll array
204      * as that is used by the wakeup mechanism.
205      *
206      * @param numPolled the number of polled sockets in the array (from index 1)
207      */
208     private void processEvents(int numPolled) {
209         int index = 1;
210         int remaining = numPolled;
211         while (index < pollArraySize && remaining > 0) {
212             short revents = getRevents(index);
213             if (revents != 0) {
214                 int fd = getDescriptor(index);
215                 assert fdToIndex.get(fd) == index;
216                 polled(fd);
217                 remove(index);
218                 fdToIndex.remove(fd);
219                 remaining--;
220             } else {
221                 index++;
222             }
223         }
224     }
225 
226     /**
227      * Wake up the poller thread
228      */
229     private void wakeup() {
230         synchronized (wakeupLock) {
231             if (!wakeupTriggered) {
232                 try {
233                     ND.write(fd1, TEMP_BUF, 1);
234                 } catch (IOException ioe) {
235                     throw new InternalError(ioe);
236                 }
237                 wakeupTriggered = true;
238             }
239         }
240     }
241 
242     /**
243      * Clear the wakeup event
244      */
245     private void clearWakeup() throws IOException {
246         synchronized (wakeupLock) {
247             ND.read(fd0, TEMP_BUF, 1);
248             putRevents(0, (short) 0);
249             wakeupTriggered = false;
250         }
251     }
252 
253     /**
254      * Add a pollfd entry to the poll array.
255      * 
256      * @return the index of the pollfd entry in the poll array
257      */
258     private int add(int fd, short events) {
259         expandIfNeeded();
260         int index = pollArraySize;
261         assert index > 0;
262         putDescriptor(index, fd);
263         putEvents(index, events);
264         putRevents(index, (short) 0);
265         pollArraySize++;
266         return index;
267     }
268 
269     /**
270      * Removes a pollfd entry from the poll array.
271      */
272     private void remove(int index) {
273         assert index > 0 && index < pollArraySize;
274 
275         // replace pollfd at index with the last pollfd in array
276         int lastIndex = pollArraySize - 1;
277         if (lastIndex != index) {
278             int lastFd = getDescriptor(lastIndex);
279             short lastEvents = getEvents(lastIndex);
280             short lastRevents = getRevents(lastIndex);
281             putDescriptor(index, lastFd);
282             putEvents(index, lastEvents);
283             putRevents(index, lastRevents);
284 
285             assert fdToIndex.get(lastFd) == lastIndex;
286             fdToIndex.put(lastFd, index);
287         }
288         pollArraySize--;
289     }
290 
291     /**
292      * Expand poll array if at capacity.
293      */
294     private void expandIfNeeded() {
295         if (pollArraySize == pollArrayCapacity) {
296             int newCapacity = pollArrayCapacity + INITIAL_CAPACITY;
297             pollArrayAddress = WSAPoll.reallocatePollArray(pollArrayAddress, pollArrayCapacity, newCapacity);
298             pollArrayCapacity = newCapacity;
299         }
300     }
301 
302     /**
303      * Returns a PipeImpl. The creation is done on the carrier thread to avoid
304      * recursive parking when the loopback connection is created.
305      */
306     private static PipeImpl makePipe() throws IOException {
307         try {
308             return JLA.executeOnCarrierThread(() -> new PipeImpl(null, false));
309         } catch (IOException ioe) {
310             throw ioe;
311         } catch (Throwable e) {
312             throw new InternalError(e);
313         }
314     }
315 
316     private void putDescriptor(int i, int fd) {
317         WSAPoll.putDescriptor(pollArrayAddress, i, fd);
318     }
319 
320     private int getDescriptor(int i) {
321         return WSAPoll.getDescriptor(pollArrayAddress, i);
322     }
323 
324     private void putEvents(int i, short events) {
325         WSAPoll.putEvents(pollArrayAddress, i, events);
326     }
327 
328     private short getEvents(int i) {
329         return WSAPoll.getEvents(pollArrayAddress, i);
330     }
331 
332     private void putRevents(int i, short revents) {
333         WSAPoll.putRevents(pollArrayAddress, i, revents);
334     }
335 
336     private short getRevents(int i) {
337         return WSAPoll.getRevents(pollArrayAddress, i);
338     }
339 }