1 /*
  2  * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @enablePreview
 27  * @library /test/lib
 28  * @modules java.base/sun.nio.ch
 29  * @key randomness
 30  * @run testng/othervm TestSocketChannels
 31  */
 32 
 33 import java.lang.foreign.Arena;
 34 import java.net.InetAddress;
 35 import java.net.InetSocketAddress;
 36 import java.nio.ByteBuffer;
 37 import java.nio.channels.ServerSocketChannel;
 38 import java.nio.channels.SocketChannel;
 39 import java.util.Arrays;
 40 import java.util.List;
 41 import java.util.concurrent.atomic.AtomicReference;
 42 import java.util.function.Supplier;
 43 import java.util.stream.Stream;
 44 
 45 import java.lang.foreign.MemorySegment;
 46 
 47 import org.testng.annotations.*;
 48 
 49 import static java.lang.foreign.ValueLayout.JAVA_BYTE;
 50 import static org.testng.Assert.*;
 51 
 52 /**
 53  * Tests consisting of buffer views with synchronous NIO network channels.
 54  */
 55 public class TestSocketChannels extends AbstractChannelsTest {
 56 
 57     static final Class<IllegalStateException> ISE = IllegalStateException.class;
 58     static final Class<WrongThreadException> WTE = WrongThreadException.class;
 59 
 60     @Test(dataProvider = "closeableArenas")
 61     public void testBasicIOWithClosedSegment(Supplier<Arena> arenaSupplier)
 62         throws Exception
 63     {
 64         try (var channel = SocketChannel.open();
 65              var server = ServerSocketChannel.open();
 66              var connectedChannel = connectChannels(server, channel)) {
 67             Arena drop = arenaSupplier.get();
 68             ByteBuffer bb = segmentBufferOfSize(drop, 16);
 69             drop.close();
 70             assertMessage(expectThrows(ISE, () -> channel.read(bb)),                           "Already closed");
 71             assertMessage(expectThrows(ISE, () -> channel.read(new ByteBuffer[] {bb})),        "Already closed");
 72             assertMessage(expectThrows(ISE, () -> channel.read(new ByteBuffer[] {bb}, 0, 1)),  "Already closed");
 73             assertMessage(expectThrows(ISE, () -> channel.write(bb)),                          "Already closed");
 74             assertMessage(expectThrows(ISE, () -> channel.write(new ByteBuffer[] {bb})),       "Already closed");
 75             assertMessage(expectThrows(ISE, () -> channel.write(new ByteBuffer[] {bb}, 0 ,1)), "Already closed");
 76         }
 77     }
 78 
 79     @Test(dataProvider = "closeableArenas")
 80     public void testScatterGatherWithClosedSegment(Supplier<Arena> arenaSupplier)
 81         throws Exception
 82     {
 83         try (var channel = SocketChannel.open();
 84              var server = ServerSocketChannel.open();
 85              var connectedChannel = connectChannels(server, channel)) {
 86             Arena drop = arenaSupplier.get();
 87             ByteBuffer[] buffers = segmentBuffersOfSize(8, drop, 16);
 88             drop.close();
 89             assertMessage(expectThrows(ISE, () -> channel.write(buffers)),       "Already closed");
 90             assertMessage(expectThrows(ISE, () -> channel.read(buffers)),        "Already closed");
 91             assertMessage(expectThrows(ISE, () -> channel.write(buffers, 0 ,8)), "Already closed");
 92             assertMessage(expectThrows(ISE, () -> channel.read(buffers, 0, 8)),  "Already closed");
 93         }
 94     }
 95 
 96     @Test(dataProvider = "closeableArenas")
 97     public void testBasicIO(Supplier<Arena> arenaSupplier)
 98         throws Exception
 99     {
100         Arena drop;
101         try (var sc1 = SocketChannel.open();
102              var ssc = ServerSocketChannel.open();
103              var sc2 = connectChannels(ssc, sc1);
104              var scp = drop = arenaSupplier.get()) {
105             Arena scope1 = drop;
106             MemorySegment segment1 = scope1.allocate(10, 1);
107             Arena scope = drop;
108             MemorySegment segment2 = scope.allocate(10, 1);
109             for (int i = 0; i < 10; i++) {
110                 segment1.set(JAVA_BYTE, i, (byte) i);
111             }
112             ByteBuffer bb1 = segment1.asByteBuffer();
113             ByteBuffer bb2 = segment2.asByteBuffer();
114             assertEquals(sc1.write(bb1), 10);
115             assertEquals(sc2.read(bb2), 10);
116             assertEquals(bb2.flip(), ByteBuffer.wrap(new byte[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
117         }
118     }
119 
120     @Test
121     public void testBasicHeapIOWithGlobalSession() throws Exception {
122         try (var sc1 = SocketChannel.open();
123              var ssc = ServerSocketChannel.open();
124              var sc2 = connectChannels(ssc, sc1)) {
125             var segment1 = MemorySegment.ofArray(new byte[10]);
126             var segment2 = MemorySegment.ofArray(new byte[10]);
127             for (int i = 0; i < 10; i++) {
128                 segment1.set(JAVA_BYTE, i, (byte) i);
129             }
130             ByteBuffer bb1 = segment1.asByteBuffer();
131             ByteBuffer bb2 = segment2.asByteBuffer();
132             assertEquals(sc1.write(bb1), 10);
133             assertEquals(sc2.read(bb2), 10);
134             assertEquals(bb2.flip(), ByteBuffer.wrap(new byte[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
135         }
136     }
137 
138     @Test(dataProvider = "confinedArenas")
139     public void testIOOnConfinedFromAnotherThread(Supplier<Arena> arenaSupplier)
140         throws Exception
141     {
142         try (var channel = SocketChannel.open();
143              var server = ServerSocketChannel.open();
144              var connected = connectChannels(server, channel);
145              var drop = arenaSupplier.get()) {
146             Arena scope = drop;
147             var segment = scope.allocate(10, 1);
148             ByteBuffer bb = segment.asByteBuffer();
149             List<ThrowingRunnable> ioOps = List.of(
150                     () -> channel.write(bb),
151                     () -> channel.read(bb),
152                     () -> channel.write(new ByteBuffer[] {bb}),
153                     () -> channel.read(new ByteBuffer[] {bb}),
154                     () -> channel.write(new ByteBuffer[] {bb}, 0, 1),
155                     () -> channel.read(new ByteBuffer[] {bb}, 0, 1)
156             );
157             for (var ioOp : ioOps) {
158                 AtomicReference<Exception> exception = new AtomicReference<>();
159                 Runnable task = () -> exception.set(expectThrows(WTE, ioOp));
160                 var t = new Thread(task);
161                 t.start();
162                 t.join();
163                 assertMessage(exception.get(), "Attempted access outside owning thread");
164             }
165         }
166     }
167 
168     @Test(dataProvider = "closeableArenas")
169     public void testScatterGatherIO(Supplier<Arena> arenaSupplier)
170         throws Exception
171     {
172         Arena drop;
173         try (var sc1 = SocketChannel.open();
174              var ssc = ServerSocketChannel.open();
175              var sc2 = connectChannels(ssc, sc1);
176              var scp = drop = arenaSupplier.get()) {
177             var writeBuffers = mixedBuffersOfSize(32, drop, 64);
178             var readBuffers = mixedBuffersOfSize(32, drop, 64);
179             long expectedCount = remaining(writeBuffers);
180             assertEquals(writeNBytes(sc1, writeBuffers, 0, 32, expectedCount), expectedCount);
181             assertEquals(readNBytes(sc2, readBuffers, 0, 32, expectedCount), expectedCount);
182             assertEquals(flip(readBuffers), clear(writeBuffers));
183         }
184     }
185 
186     @Test(dataProvider = "closeableArenas")
187     public void testBasicIOWithDifferentSessions(Supplier<Arena> arenaSupplier)
188          throws Exception
189     {
190         try (var sc1 = SocketChannel.open();
191              var ssc = ServerSocketChannel.open();
192              var sc2 = connectChannels(ssc, sc1);
193              var drop1 = arenaSupplier.get();
194              var drop2 = arenaSupplier.get()) {
195             var writeBuffers = Stream.of(mixedBuffersOfSize(16, drop1, 64), mixedBuffersOfSize(16, drop2, 64))
196                                      .flatMap(Arrays::stream)
197                                      .toArray(ByteBuffer[]::new);
198             var readBuffers = Stream.of(mixedBuffersOfSize(16, drop1, 64), mixedBuffersOfSize(16, drop2, 64))
199                                     .flatMap(Arrays::stream)
200                                     .toArray(ByteBuffer[]::new);
201 
202             long expectedCount = remaining(writeBuffers);
203             assertEquals(writeNBytes(sc1, writeBuffers, 0, 32, expectedCount), expectedCount);
204             assertEquals(readNBytes(sc2, readBuffers, 0, 32, expectedCount), expectedCount);
205             assertEquals(flip(readBuffers), clear(writeBuffers));
206         }
207     }
208 
209     static SocketChannel connectChannels(ServerSocketChannel ssc, SocketChannel sc)
210         throws Exception
211     {
212         ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
213         sc.connect(ssc.getLocalAddress());
214         return ssc.accept();
215     }
216 
217     static long writeNBytes(SocketChannel channel,
218                             ByteBuffer[] buffers, int offset, int len,
219                             long bytes)
220         throws Exception
221     {
222         long total = 0L;
223         do {
224             long n = channel.write(buffers, offset, len);
225             assertTrue(n > 0, "got:" + n);
226             total += n;
227         } while (total < bytes);
228         return total;
229     }
230 
231     static long readNBytes(SocketChannel channel,
232                            ByteBuffer[] buffers, int offset, int len,
233                            long bytes)
234         throws Exception
235     {
236         long total = 0L;
237         do {
238             long n = channel.read(buffers, offset, len);
239             assertTrue(n > 0, "got:" + n);
240             total += n;
241         } while (total < bytes);
242         return total;
243     }
244 }
245