1 /*
  2  * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @library /test/lib
 27  * @modules java.base/sun.nio.ch
 28  * @key randomness
 29  * @run testng/othervm TestSocketChannels
 30  */
 31 
 32 import java.lang.foreign.Arena;
 33 import java.net.InetAddress;
 34 import java.net.InetSocketAddress;
 35 import java.nio.ByteBuffer;
 36 import java.nio.channels.ServerSocketChannel;
 37 import java.nio.channels.SocketChannel;
 38 import java.util.Arrays;
 39 import java.util.List;
 40 import java.util.concurrent.atomic.AtomicReference;
 41 import java.util.function.Supplier;
 42 import java.util.stream.Stream;
 43 
 44 import java.lang.foreign.MemorySegment;
 45 
 46 import org.testng.annotations.*;
 47 
 48 import static java.lang.foreign.ValueLayout.JAVA_BYTE;
 49 import static org.testng.Assert.*;
 50 
 51 /**
 52  * Tests consisting of buffer views with synchronous NIO network channels.
 53  */
 54 public class TestSocketChannels extends AbstractChannelsTest {
 55 
 56     static final Class<IllegalStateException> ISE = IllegalStateException.class;
 57     static final Class<WrongThreadException> WTE = WrongThreadException.class;
 58 
 59     @Test(dataProvider = "closeableArenas")
 60     public void testBasicIOWithClosedSegment(Supplier<Arena> arenaSupplier)
 61         throws Exception
 62     {
 63         try (var channel = SocketChannel.open();
 64              var server = ServerSocketChannel.open();
 65              var connectedChannel = connectChannels(server, channel)) {
 66             Arena drop = arenaSupplier.get();
 67             ByteBuffer bb = segmentBufferOfSize(drop, 16);
 68             drop.close();
 69             assertMessage(expectThrows(ISE, () -> channel.read(bb)),                           "Already closed");
 70             assertMessage(expectThrows(ISE, () -> channel.read(new ByteBuffer[] {bb})),        "Already closed");
 71             assertMessage(expectThrows(ISE, () -> channel.read(new ByteBuffer[] {bb}, 0, 1)),  "Already closed");
 72             assertMessage(expectThrows(ISE, () -> channel.write(bb)),                          "Already closed");
 73             assertMessage(expectThrows(ISE, () -> channel.write(new ByteBuffer[] {bb})),       "Already closed");
 74             assertMessage(expectThrows(ISE, () -> channel.write(new ByteBuffer[] {bb}, 0 ,1)), "Already closed");
 75         }
 76     }
 77 
 78     @Test(dataProvider = "closeableArenas")
 79     public void testScatterGatherWithClosedSegment(Supplier<Arena> arenaSupplier)
 80         throws Exception
 81     {
 82         try (var channel = SocketChannel.open();
 83              var server = ServerSocketChannel.open();
 84              var connectedChannel = connectChannels(server, channel)) {
 85             Arena drop = arenaSupplier.get();
 86             ByteBuffer[] buffers = segmentBuffersOfSize(8, drop, 16);
 87             drop.close();
 88             assertMessage(expectThrows(ISE, () -> channel.write(buffers)),       "Already closed");
 89             assertMessage(expectThrows(ISE, () -> channel.read(buffers)),        "Already closed");
 90             assertMessage(expectThrows(ISE, () -> channel.write(buffers, 0 ,8)), "Already closed");
 91             assertMessage(expectThrows(ISE, () -> channel.read(buffers, 0, 8)),  "Already closed");
 92         }
 93     }
 94 
 95     @Test(dataProvider = "closeableArenas")
 96     public void testBasicIO(Supplier<Arena> arenaSupplier)
 97         throws Exception
 98     {
 99         Arena drop;
100         try (var sc1 = SocketChannel.open();
101              var ssc = ServerSocketChannel.open();
102              var sc2 = connectChannels(ssc, sc1);
103              var scp = drop = arenaSupplier.get()) {
104             Arena scope1 = drop;
105             MemorySegment segment1 = scope1.allocate(10, 1);
106             Arena scope = drop;
107             MemorySegment segment2 = scope.allocate(10, 1);
108             for (int i = 0; i < 10; i++) {
109                 segment1.set(JAVA_BYTE, i, (byte) i);
110             }
111             ByteBuffer bb1 = segment1.asByteBuffer();
112             ByteBuffer bb2 = segment2.asByteBuffer();
113             assertEquals(sc1.write(bb1), 10);
114             assertEquals(sc2.read(bb2), 10);
115             assertEquals(bb2.flip(), ByteBuffer.wrap(new byte[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
116         }
117     }
118 
119     @Test
120     public void testBasicHeapIOWithGlobalSession() throws Exception {
121         try (var sc1 = SocketChannel.open();
122              var ssc = ServerSocketChannel.open();
123              var sc2 = connectChannels(ssc, sc1)) {
124             var segment1 = MemorySegment.ofArray(new byte[10]);
125             var segment2 = MemorySegment.ofArray(new byte[10]);
126             for (int i = 0; i < 10; i++) {
127                 segment1.set(JAVA_BYTE, i, (byte) i);
128             }
129             ByteBuffer bb1 = segment1.asByteBuffer();
130             ByteBuffer bb2 = segment2.asByteBuffer();
131             assertEquals(sc1.write(bb1), 10);
132             assertEquals(sc2.read(bb2), 10);
133             assertEquals(bb2.flip(), ByteBuffer.wrap(new byte[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
134         }
135     }
136 
137     @Test(dataProvider = "confinedArenas")
138     public void testIOOnConfinedFromAnotherThread(Supplier<Arena> arenaSupplier)
139         throws Exception
140     {
141         try (var channel = SocketChannel.open();
142              var server = ServerSocketChannel.open();
143              var connected = connectChannels(server, channel);
144              var drop = arenaSupplier.get()) {
145             Arena scope = drop;
146             var segment = scope.allocate(10, 1);
147             ByteBuffer bb = segment.asByteBuffer();
148             List<ThrowingRunnable> ioOps = List.of(
149                     () -> channel.write(bb),
150                     () -> channel.read(bb),
151                     () -> channel.write(new ByteBuffer[] {bb}),
152                     () -> channel.read(new ByteBuffer[] {bb}),
153                     () -> channel.write(new ByteBuffer[] {bb}, 0, 1),
154                     () -> channel.read(new ByteBuffer[] {bb}, 0, 1)
155             );
156             for (var ioOp : ioOps) {
157                 AtomicReference<Exception> exception = new AtomicReference<>();
158                 Runnable task = () -> exception.set(expectThrows(WTE, ioOp));
159                 var t = new Thread(task);
160                 t.start();
161                 t.join();
162                 assertMessage(exception.get(), "Attempted access outside owning thread");
163             }
164         }
165     }
166 
167     @Test(dataProvider = "closeableArenas")
168     public void testScatterGatherIO(Supplier<Arena> arenaSupplier)
169         throws Exception
170     {
171         Arena drop;
172         try (var sc1 = SocketChannel.open();
173              var ssc = ServerSocketChannel.open();
174              var sc2 = connectChannels(ssc, sc1);
175              var scp = drop = arenaSupplier.get()) {
176             var writeBuffers = mixedBuffersOfSize(32, drop, 64);
177             var readBuffers = mixedBuffersOfSize(32, drop, 64);
178             long expectedCount = remaining(writeBuffers);
179             assertEquals(writeNBytes(sc1, writeBuffers, 0, 32, expectedCount), expectedCount);
180             assertEquals(readNBytes(sc2, readBuffers, 0, 32, expectedCount), expectedCount);
181             assertEquals(flip(readBuffers), clear(writeBuffers));
182         }
183     }
184 
185     @Test(dataProvider = "closeableArenas")
186     public void testBasicIOWithDifferentSessions(Supplier<Arena> arenaSupplier)
187          throws Exception
188     {
189         try (var sc1 = SocketChannel.open();
190              var ssc = ServerSocketChannel.open();
191              var sc2 = connectChannels(ssc, sc1);
192              var drop1 = arenaSupplier.get();
193              var drop2 = arenaSupplier.get()) {
194             var writeBuffers = Stream.of(mixedBuffersOfSize(16, drop1, 64), mixedBuffersOfSize(16, drop2, 64))
195                                      .flatMap(Arrays::stream)
196                                      .toArray(ByteBuffer[]::new);
197             var readBuffers = Stream.of(mixedBuffersOfSize(16, drop1, 64), mixedBuffersOfSize(16, drop2, 64))
198                                     .flatMap(Arrays::stream)
199                                     .toArray(ByteBuffer[]::new);
200 
201             long expectedCount = remaining(writeBuffers);
202             assertEquals(writeNBytes(sc1, writeBuffers, 0, 32, expectedCount), expectedCount);
203             assertEquals(readNBytes(sc2, readBuffers, 0, 32, expectedCount), expectedCount);
204             assertEquals(flip(readBuffers), clear(writeBuffers));
205         }
206     }
207 
208     static SocketChannel connectChannels(ServerSocketChannel ssc, SocketChannel sc)
209         throws Exception
210     {
211         ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
212         sc.connect(ssc.getLocalAddress());
213         return ssc.accept();
214     }
215 
216     static long writeNBytes(SocketChannel channel,
217                             ByteBuffer[] buffers, int offset, int len,
218                             long bytes)
219         throws Exception
220     {
221         long total = 0L;
222         do {
223             long n = channel.write(buffers, offset, len);
224             assertTrue(n > 0, "got:" + n);
225             total += n;
226         } while (total < bytes);
227         return total;
228     }
229 
230     static long readNBytes(SocketChannel channel,
231                            ByteBuffer[] buffers, int offset, int len,
232                            long bytes)
233         throws Exception
234     {
235         long total = 0L;
236         do {
237             long n = channel.read(buffers, offset, len);
238             assertTrue(n > 0, "got:" + n);
239             total += n;
240         } while (total < bytes);
241         return total;
242     }
243 }
244