1 /*
  2  * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.test.lib.process.ProcessTools;
 25 import jdk.test.lib.security.SecurityUtils;
 26 
 27 import javax.net.ssl.*;
 28 import java.io.IOException;
 29 import java.net.*;
 30 import java.nio.ByteBuffer;
 31 import java.util.ArrayList;
 32 import java.util.List;
 33 import java.util.concurrent.atomic.AtomicInteger;
 34 
 35 /*
 36  * @test
 37  * @bug 8301381
 38  * @library /test/lib /javax/net/ssl/templates
 * @summary DTLSv1.0 is now disabled. This test verifies that a DTLSv1.0
 *     connection is not negotiated when either endpoint (client or server)
 *     asks for it.
 41  * @run main/othervm DTLSWontNegotiateV10 DTLS
 42  * @run main/othervm DTLSWontNegotiateV10 DTLSv1.0
 43  */
 44 public class DTLSWontNegotiateV10 {
 45 
 46     private static final int MTU = 1024;
 47     private static final String DTLSV_1_0 = "DTLSv1.0";
 48     private static final String DTLS = "DTLS";
 49     private static final String DTLSV_1_2 = "DTLSv1.2";
 50 
 51     private static final int READ_TIMEOUT_SECS = Integer.getInteger("readtimeout", 30);
 52 
 53     public static void main(String[] args) throws Exception {
 54         if (args[0].equals(DTLSV_1_0)) {
 55             SecurityUtils.removeFromDisabledTlsAlgs(DTLSV_1_0);
 56         }
 57 
 58         if (args.length > 1) {
 59             // running in client child process
 60             // args: protocol server-port
 61             try (DTLSClient client = new DTLSClient(args[0], Integer.parseInt(args[1]))) {
 62                 client.run();
 63             }
 64 
 65         } else {
 66             // server process
 67             // args: protocol
 68             final int totalAttempts = 5;
 69             int tries;
 70             for (tries = 0 ; tries < totalAttempts ; ++tries) {
 71                 try {
 72                     System.out.printf("Starting server %d/%d attempts%n", tries+1, totalAttempts);
 73                     runServer(args[0]);
 74                     break;
 75                 } catch (SocketTimeoutException exc) {
 76                     System.out.println("The server timed-out waiting for packets from the client.");
 77                 }
 78             }
 79             if (tries == totalAttempts) {
 80                 throw new RuntimeException("The server/client communications timed-out after " + totalAttempts + " tries.");
 81             }
 82         }
 83     }
 84 
 85     private static void runServer(String protocol) throws Exception {
 86         // args: protocol
 87         Process clientProcess = null;
 88         try (DTLSServer server = new DTLSServer(protocol)) {
 89             List<String> command = List.of(
 90                     "DTLSWontNegotiateV10",
 91                     // if server is "DTLS" then the client should be v1.0 and vice versa
 92                     protocol.equals(DTLS) ? DTLSV_1_0 : DTLS,
 93                     Integer.toString(server.getListeningPortNumber())
 94             );
 95 
 96             ProcessBuilder builder = ProcessTools.createTestJavaProcessBuilder(command);
 97             clientProcess = builder.inheritIO().start();
 98             server.run();
 99             System.out.println("Success: DTLSv1.0 connection was not established.");
100 
101         } finally {
102             if (clientProcess != null) {
103                 clientProcess.destroy();
104             }
105         }
106     }
107 
108     private static class DTLSClient extends DTLSEndpoint {
109         private final int remotePort;
110 
111         private final DatagramSocket socket = new DatagramSocket();
112 
113         public DTLSClient(String protocol, int portNumber) throws Exception {
114             super(true, protocol);
115             remotePort = portNumber;
116             socket.setSoTimeout(READ_TIMEOUT_SECS * 1000);
117             log("Client listening on port " + socket.getLocalPort()
118                     + ". Sending data to server port " + remotePort);
119             log("Enabled protocols: " + String.join(" ", engine.getEnabledProtocols()));
120         }
121 
122         @Override
123         public void run() throws Exception {
124             doHandshake(socket);
125             log("Client done handshaking. Protocol: " + engine.getSession().getProtocol());
126         }
127 
128         @Override
129         void setRemotePortNumber(int portNumber) {
130             // don't do anything; we're using the one we already know
131         }
132 
133         @Override
134         int getRemotePortNumber() {
135             return remotePort;
136         }
137 
138         @Override
139         public void close () {
140             socket.close();
141         }
142     }
143 
144     private abstract static class DTLSEndpoint extends SSLContextTemplate implements AutoCloseable {
145         protected final SSLEngine engine;
146         protected final SSLContext context;
147         private final String protocol;
148         protected final InetAddress LOCALHOST;
149 
150         private final String tag;
151 
152         public DTLSEndpoint(boolean useClientMode, String protocol) throws Exception {
153             this.protocol = protocol;
154             if (useClientMode) {
155                 tag = "client";
156                 context = createClientSSLContext();
157             } else {
158                 tag = "server";
159                 context = createServerSSLContext();
160             }
161             engine = context.createSSLEngine();
162             engine.setUseClientMode(useClientMode);
163             SSLParameters params = engine.getSSLParameters();
164             params.setMaximumPacketSize(MTU);
165             engine.setSSLParameters(params);
166             if (protocol.equals(DTLS)) {
167                 // make sure both versions are "enabled"; 1.0 should be
168                 // disabled by policy now and won't be negotiated.
169                 engine.setEnabledProtocols(new String[]{DTLSV_1_0, DTLSV_1_2});
170             } else {
171                 engine.setEnabledProtocols(new String[]{DTLSV_1_0});
172             }
173 
174             LOCALHOST = InetAddress.getByName("localhost");
175         }
176 
177         @Override
178         protected ContextParameters getServerContextParameters() {
179             return new ContextParameters(protocol, "PKIX", "NewSunX509");
180         }
181 
182         @Override
183         protected ContextParameters getClientContextParameters() {
184             return new ContextParameters(protocol, "PKIX", "NewSunX509");
185         }
186 
187 
188         abstract void setRemotePortNumber(int portNumber);
189 
190         abstract int getRemotePortNumber();
191 
192         abstract void run() throws Exception;
193 
194         private boolean runDelegatedTasks() {
195             log("Running delegated tasks.");
196             Runnable runnable;
197             while ((runnable = engine.getDelegatedTask()) != null) {
198                 runnable.run();
199             }
200 
201             SSLEngineResult.HandshakeStatus hs = engine.getHandshakeStatus();
202             if (hs == SSLEngineResult.HandshakeStatus.NEED_TASK) {
203                 throw new RuntimeException(
204                         "Handshake shouldn't need additional tasks");
205             }
206 
207             return true;
208         }
209 
210         protected void doHandshake(DatagramSocket socket) throws Exception {
211             boolean handshaking = true;
212             engine.beginHandshake();
213             while (handshaking) {
214                 log("Handshake status = " + engine.getHandshakeStatus());
215                 handshaking = switch (engine.getHandshakeStatus()) {
216                     case NEED_UNWRAP, NEED_UNWRAP_AGAIN -> readFromServer(socket);
217                     case NEED_WRAP -> sendHandshakePackets(socket);
218                     case NEED_TASK -> runDelegatedTasks();
219                     case NOT_HANDSHAKING, FINISHED -> false;
220                 };
221             }
222         }
223 
224         private boolean readFromServer(DatagramSocket socket) throws IOException {
225             log("Reading data from remote endpoint.");
226             ByteBuffer iNet, iApp;
227             if (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
228                 byte[] buffer = new byte[MTU];
229                 DatagramPacket packet = new DatagramPacket(buffer, buffer.length);
230                 socket.receive(packet);
231                 setRemotePortNumber(packet.getPort());
232                 iNet = ByteBuffer.wrap(buffer, 0, packet.getLength());
233                 iApp = ByteBuffer.allocate(MTU);
234             } else {
235                 iNet = ByteBuffer.allocate(0);
236                 iApp = ByteBuffer.allocate(MTU);
237             }
238 
239             SSLEngineResult engineResult;
240             do {
241                 engineResult = engine.unwrap(iNet, iApp);
242             } while (iNet.hasRemaining());
243 
244             return switch (engineResult.getStatus()) {
245                 case CLOSED -> false;
246                 case OK -> true;
247                 case BUFFER_OVERFLOW -> throw new RuntimeException("Buffer overflow: "
248                         + "incorrect server maximum fragment size");
249                 case BUFFER_UNDERFLOW -> throw new RuntimeException("Buffer underflow: "
250                         + "incorrect server maximum fragment size");
251             };
252         }
253 
254         private boolean sendHandshakePackets(DatagramSocket socket) throws Exception {
255             List<DatagramPacket> packets = generateHandshakePackets();
256             log("Sending handshake packets.");
257             packets.forEach((p) -> {
258                 try {
259                     socket.send(p);
260                 } catch (IOException e) {
261                     throw new RuntimeException(e);
262                 }
263             });
264 
265             return true;
266         }
267 
268         private List<DatagramPacket> generateHandshakePackets() throws SSLException {
269             log("Generating handshake packets.");
270             List<DatagramPacket> packets = new ArrayList<>();
271             ByteBuffer oNet = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
272             ByteBuffer oApp = ByteBuffer.allocate(0);
273 
274             while (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
275                 SSLEngineResult result = engine.wrap(oApp, oNet);
276                 oNet.flip();
277 
278                 switch (result.getStatus()) {
279                     case BUFFER_UNDERFLOW -> {
280                         if (engine.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
281                             throw new RuntimeException("Buffer underflow: "
282                                     + "incorrect server maximum fragment size");
283                         }
284                     }
285                     case BUFFER_OVERFLOW -> throw new RuntimeException("Buffer overflow: "
286                             + "incorrect server maximum fragment size");
287                     case CLOSED -> throw new RuntimeException("SSLEngine has closed");
288                 }
289 
290                 if (oNet.hasRemaining()) {
291                     byte[] packetBuffer = new byte[oNet.remaining()];
292                     oNet.get(packetBuffer);
293                     packets.add(new DatagramPacket(packetBuffer, packetBuffer.length,
294                             LOCALHOST, getRemotePortNumber()));
295                 }
296 
297                 runDelegatedTasks();
298                 oNet.clear();
299             }
300 
301             log("Generated " + packets.size() + " packets.");
302             return packets;
303         }
304 
305         protected void log(String msg) {
306             System.out.println(tag + ": " + msg);
307         }
308     }
309 
310     private static class DTLSServer extends DTLSEndpoint implements AutoCloseable {
311 
312         private final AtomicInteger portNumber = new AtomicInteger(0);
313         private final DatagramSocket socket = new DatagramSocket(0);
314 
315         public DTLSServer(String protocol) throws Exception {
316             super(false, protocol);
317             socket.setSoTimeout(READ_TIMEOUT_SECS * 1000);
318             log("Server listening on port: " + socket.getLocalPort());
319             log("Enabled protocols: " + String.join(" ", engine.getEnabledProtocols()));
320         }
321 
322         @Override
323         public void run() throws Exception {
324             doHandshake(socket);
325             if (!engine.getSession().getProtocol().equals("NONE")) {
326                 throw new RuntimeException("Negotiated protocol: "
327                         + engine.getSession().getProtocol() +
328                         ". No protocol should be negotated.");
329             }
330         }
331 
332         public int getListeningPortNumber() {
333             return socket.getLocalPort();
334         }
335 
336         void setRemotePortNumber(int portNumber) {
337             this.portNumber.compareAndSet(0, portNumber);
338         }
339 
340         int getRemotePortNumber() {
341             return portNumber.get();
342         }
343 
344         @Override
345         public void close() throws Exception {
346             socket.close();
347         }
348     }
349 }