1 /*
  2  * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.test.lib.security.SecurityUtils;
 25 
 26 import javax.net.ssl.*;
 27 import java.io.IOException;
 28 import java.net.*;
 29 import java.nio.ByteBuffer;
 30 import java.nio.file.Path;
 31 import java.util.ArrayList;
 32 import java.util.List;
 33 import java.util.concurrent.atomic.AtomicInteger;
 34 
 35 /*
 36  * @test
 37  * @bug 8301381
 38  * @library /test/lib /javax/net/ssl/templates
 * @summary DTLSv1.0 is now disabled. This test verifies that a DTLSv1.0
 *     connection is not negotiated when either the client or the server asks for it.
 41  * @run main/othervm DTLSWontNegotiateV10 DTLS
 42  * @run main/othervm DTLSWontNegotiateV10 DTLSv1.0
 43  */
 44 public class DTLSWontNegotiateV10 {
 45 
 46     private static final int MTU = 1024;
 47     private static final String DTLSV_1_0 = "DTLSv1.0";
 48     private static final String DTLS = "DTLS";
 49     private static final String DTLSV_1_2 = "DTLSv1.2";
 50 
 51     public static void main(String[] args) throws Exception {
 52         if (args[0].equals(DTLSV_1_0)) {
 53             SecurityUtils.removeFromDisabledTlsAlgs(DTLSV_1_0);
 54         }
 55 
 56         if (args.length > 1) {
 57             // running in client child process
 58             // args: protocol server-port
 59             try (DTLSClient client = new DTLSClient(args[0], Integer.parseInt(args[1]))) {
 60                 client.run();
 61             }
 62 
 63         } else {
 64             // server process
 65             // args: protocol
 66             try (DTLSServer server = new DTLSServer(args[0])) {
 67                 List<String> command = List.of(
 68                         Path.of(System.getProperty("java.home"), "bin", "java").toString(),
 69                         "DTLSWontNegotiateV10",
 70                         // if server is "DTLS" then the client should be v1.0 and vice versa
 71                         args[0].equals(DTLS) ? DTLSV_1_0 : DTLS,
 72                         Integer.toString(server.getListeningPortNumber())
 73                 );
 74 
 75                 ProcessBuilder builder = new ProcessBuilder(command);
 76                 Process p = builder.inheritIO().start();
 77                 server.run();
 78                 p.destroy();
 79                 System.out.println("Success: DTLSv1.0 connection was not established.");
 80             }
 81         }
 82     }
 83 
 84     private static class DTLSClient extends DTLSEndpoint {
 85         private final int remotePort;
 86 
 87         private final DatagramSocket socket = new DatagramSocket();
 88 
 89         public DTLSClient(String protocol, int portNumber) throws Exception {
 90             super(true, protocol);
 91             remotePort = portNumber;
 92             log("Enabled protocols: " + String.join(" ", engine.getEnabledProtocols()));
 93         }
 94 
 95         @Override
 96         public void run() throws Exception {
 97             doHandshake(socket);
 98             log("Client done handshaking. Protocol: " + engine.getSession().getProtocol());
 99         }
100 
101         @Override
102         void setRemotePortNumber(int portNumber) {
103             // don't do anything; we're using the one we already know
104         }
105 
106         @Override
107         int getRemotePortNumber() {
108             return remotePort;
109         }
110 
111         @Override
112         public void close () {
113             socket.close();
114         }
115     }
116 
117     private abstract static class DTLSEndpoint extends SSLContextTemplate implements AutoCloseable {
118         protected final SSLEngine engine;
119         protected final SSLContext context;
120         private final String protocol;
121         protected final InetAddress LOCALHOST;
122 
123         private final String tag;
124 
125         public DTLSEndpoint(boolean useClientMode, String protocol) throws Exception {
126             this.protocol = protocol;
127             if (useClientMode) {
128                 tag = "client";
129                 context = createClientSSLContext();
130             } else {
131                 tag = "server";
132                 context = createServerSSLContext();
133             }
134             engine = context.createSSLEngine();
135             engine.setUseClientMode(useClientMode);
136             SSLParameters params = engine.getSSLParameters();
137             params.setMaximumPacketSize(MTU);
138             engine.setSSLParameters(params);
139             if (protocol.equals(DTLS)) {
140                 // make sure both versions are "enabled"; 1.0 should be
141                 // disabled by policy now and won't be negotiated.
142                 engine.setEnabledProtocols(new String[]{DTLSV_1_0, DTLSV_1_2});
143             } else {
144                 engine.setEnabledProtocols(new String[]{DTLSV_1_0});
145             }
146 
147             LOCALHOST = InetAddress.getByName("localhost");
148         }
149 
150         @Override
151         protected ContextParameters getServerContextParameters() {
152             return new ContextParameters(protocol, "PKIX", "NewSunX509");
153         }
154 
155         @Override
156         protected ContextParameters getClientContextParameters() {
157             return new ContextParameters(protocol, "PKIX", "NewSunX509");
158         }
159 
160 
161         abstract void setRemotePortNumber(int portNumber);
162 
163         abstract int getRemotePortNumber();
164 
165         abstract void run() throws Exception;
166 
167         private boolean runDelegatedTasks() {
168             log("Running delegated tasks.");
169             Runnable runnable;
170             while ((runnable = engine.getDelegatedTask()) != null) {
171                 runnable.run();
172             }
173 
174             SSLEngineResult.HandshakeStatus hs = engine.getHandshakeStatus();
175             if (hs == SSLEngineResult.HandshakeStatus.NEED_TASK) {
176                 throw new RuntimeException(
177                         "Handshake shouldn't need additional tasks");
178             }
179 
180             return true;
181         }
182 
183         protected void doHandshake(DatagramSocket socket) throws Exception {
184             boolean handshaking = true;
185             engine.beginHandshake();
186             while (handshaking) {
187                 log("Handshake status = " + engine.getHandshakeStatus());
188                 handshaking = switch (engine.getHandshakeStatus()) {
189                     case NEED_UNWRAP, NEED_UNWRAP_AGAIN -> readFromServer(socket);
190                     case NEED_WRAP -> sendHandshakePackets(socket);
191                     case NEED_TASK -> runDelegatedTasks();
192                     case NOT_HANDSHAKING, FINISHED -> false;
193                 };
194             }
195         }
196 
197         private boolean readFromServer(DatagramSocket socket) throws IOException {
198             log("Reading data from remote endpoint.");
199             ByteBuffer iNet, iApp;
200             if (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
201                 byte[] buffer = new byte[MTU];
202                 DatagramPacket packet = new DatagramPacket(buffer, buffer.length);
203                 socket.receive(packet);
204                 setRemotePortNumber(packet.getPort());
205                 iNet = ByteBuffer.wrap(buffer, 0, packet.getLength());
206                 iApp = ByteBuffer.allocate(MTU);
207             } else {
208                 iNet = ByteBuffer.allocate(0);
209                 iApp = ByteBuffer.allocate(MTU);
210             }
211 
212             SSLEngineResult engineResult;
213             do {
214                 engineResult = engine.unwrap(iNet, iApp);
215             } while (iNet.hasRemaining());
216 
217             return switch (engineResult.getStatus()) {
218                 case CLOSED -> false;
219                 case OK -> true;
220                 case BUFFER_OVERFLOW -> throw new RuntimeException("Buffer overflow: "
221                         + "incorrect server maximum fragment size");
222                 case BUFFER_UNDERFLOW -> throw new RuntimeException("Buffer underflow: "
223                         + "incorrect server maximum fragment size");
224             };
225         }
226 
227         private boolean sendHandshakePackets(DatagramSocket socket) throws Exception {
228             List<DatagramPacket> packets = generateHandshakePackets();
229             log("Sending handshake packets.");
230             packets.forEach((p) -> {
231                 try {
232                     socket.send(p);
233                 } catch (IOException e) {
234                     throw new RuntimeException(e);
235                 }
236             });
237 
238             return true;
239         }
240 
241         private List<DatagramPacket> generateHandshakePackets() throws SSLException {
242             log("Generating handshake packets.");
243             List<DatagramPacket> packets = new ArrayList<>();
244             ByteBuffer oNet = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
245             ByteBuffer oApp = ByteBuffer.allocate(0);
246 
247             while (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
248                 SSLEngineResult result = engine.wrap(oApp, oNet);
249                 oNet.flip();
250 
251                 switch (result.getStatus()) {
252                     case BUFFER_UNDERFLOW -> {
253                         if (engine.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
254                             throw new RuntimeException("Buffer underflow: "
255                                     + "incorrect server maximum fragment size");
256                         }
257                     }
258                     case BUFFER_OVERFLOW -> throw new RuntimeException("Buffer overflow: "
259                             + "incorrect server maximum fragment size");
260                     case CLOSED -> throw new RuntimeException("SSLEngine has closed");
261                 }
262 
263                 if (oNet.hasRemaining()) {
264                     byte[] packetBuffer = new byte[oNet.remaining()];
265                     oNet.get(packetBuffer);
266                     packets.add(new DatagramPacket(packetBuffer, packetBuffer.length,
267                             LOCALHOST, getRemotePortNumber()));
268                 }
269 
270                 runDelegatedTasks();
271                 oNet.clear();
272             }
273 
274             log("Generated " + packets.size() + " packets.");
275             return packets;
276         }
277 
278         protected void log(String msg) {
279             System.out.println(tag + ": " + msg);
280         }
281     }
282 
283     private static class DTLSServer extends DTLSEndpoint implements AutoCloseable {
284 
285         private final AtomicInteger portNumber = new AtomicInteger(0);
286         private final DatagramSocket socket = new DatagramSocket(0);
287 
288         public DTLSServer(String protocol) throws Exception {
289             super(false, protocol);
290             log("Enabled protocols: " + String.join(" ", engine.getEnabledProtocols()));
291         }
292 
293         @Override
294         public void run() throws Exception {
295             doHandshake(socket);
296             if (!engine.getSession().getProtocol().equals("NONE")) {
297                 throw new RuntimeException("Negotiated protocol: "
298                         + engine.getSession().getProtocol() +
299                         ". No protocol should be negotated.");
300             }
301         }
302 
303         public int getListeningPortNumber() {
304             return socket.getLocalPort();
305         }
306 
307         void setRemotePortNumber(int portNumber) {
308             this.portNumber.compareAndSet(0, portNumber);
309         }
310 
311         int getRemotePortNumber() {
312             return portNumber.get();
313         }
314 
315         @Override
316         public void close() throws Exception {
317             socket.close();
318         }
319     }
320 }