1 /*
  2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @bug 8245245
 27  * @summary Test for Websocket URI encoding during HandShake
 28  * @library /test/lib
 29  * @build jdk.test.lib.net.SimpleSSLContext
 30  * @modules java.net.http
 31  *          jdk.httpserver
 32  * @run testng/othervm -Djdk.internal.httpclient.debug=true HandshakeUrlEncodingTest
 33  */
 34 
import com.sun.net.httpserver.HttpHandler;
import com.sun.net.httpserver.HttpServer;
import com.sun.net.httpserver.HttpsConfigurator;
import com.sun.net.httpserver.HttpsServer;
import com.sun.net.httpserver.HttpExchange;
import jdk.test.lib.net.SimpleSSLContext;
import jdk.test.lib.net.URIBuilder;
import org.testng.annotations.AfterTest;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpResponse;
import java.net.http.WebSocket;
import java.net.http.WebSocketHandshakeException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import static java.net.http.HttpClient.Builder.NO_PROXY;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.fail;
import static java.lang.System.out;
 67 
 68 public class HandshakeUrlEncodingTest {
 69 
 70     SSLContext sslContext;
 71     HttpServer httpTestServer;
 72     HttpsServer httpsTestServer;
 73     String httpURI;
 74     String httpsURI;
 75 
 76     static String queryPart;
 77 
 78     static final int ITERATION_COUNT = 10;
 79     // a shared executor helps reduce the amount of threads created by the test
 80     static final ExecutorService executor = Executors.newCachedThreadPool();
 81 
 82     @DataProvider(name = "variants")
 83     public Object[][] variants() {
 84         return new Object[][]{
 85             { httpURI,   false },
 86             { httpsURI,  false },
 87             { httpURI,   true  },
 88             { httpsURI,  true  }
 89         };
 90     }
 91 
 92     HttpClient newHttpClient() {
 93         return HttpClient.newBuilder()
 94                          .proxy(NO_PROXY)
 95                          .executor(executor)
 96                          .sslContext(sslContext)
 97                          .build();
 98     }
 99 
100     @Test(dataProvider = "variants")
101     public void test(String uri, boolean sameClient) {
102         HttpClient client = null;
103         out.println("The url is " + uri);
104         for (int i = 0; i < ITERATION_COUNT; i++) {
105             System.out.printf("iteration %s%n", i);
106             if (!sameClient || client == null)
107                 client = newHttpClient();
108 
109             try {
110                 client.newWebSocketBuilder()
111                     .buildAsync(URI.create(uri), new WebSocket.Listener() { })
112                     .join();
113                 fail("Expected to throw");
114             } catch (CompletionException ce) {
115                 final Throwable t = getCompletionCause(ce);
116                 if (!(t instanceof WebSocketHandshakeException)) {
117                     throw new AssertionError("Unexpected exception", t);
118                 }
119                 final WebSocketHandshakeException wse = (WebSocketHandshakeException) t;
120                 assertNotNull(wse.getResponse());
121                 assertNotNull(wse.getResponse().uri());
122                 assertNotNull(wse.getResponse().statusCode());
123                 final String rawQuery = wse.getResponse().uri().getRawQuery();
124                 final String expectedRawQuery = "&raw=abc+def/ghi=xyz&encoded=abc%2Bdef%2Fghi%3Dxyz";
125                 assertEquals(rawQuery, expectedRawQuery);
126                 final String body = (String) wse.getResponse().body();
127                 final String expectedBody = "/?" + expectedRawQuery;
128                 assertEquals(body, expectedBody);
129                 out.println("Status code is " + wse.getResponse().statusCode());
130                 out.println("Response is " + wse.getResponse());
131                 assertEquals(wse.getResponse().statusCode(), 400);
132             }
133         }
134     }
135 
136     @BeforeTest
137     public void setup() throws Exception {
138         sslContext = new SimpleSSLContext().get();
139         if (sslContext == null)
140             throw new AssertionError("Unexpected null sslContext");
141 
142 
143         InetSocketAddress sa = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0);
144         queryPart = "?&raw=abc+def/ghi=xyz&encoded=abc%2Bdef%2Fghi%3Dxyz";
145         httpTestServer = HttpServer.create(sa, 10);
146         httpURI = URIBuilder.newBuilder()
147                             .scheme("ws")
148                             .host("localhost")
149                             .port(httpTestServer.getAddress().getPort())
150                             .path("/")
151                             .build()
152                             .toString() + queryPart;
153 
154         httpTestServer.createContext("/", new UrlHandler());
155 
156         httpsTestServer = HttpsServer.create(sa, 10);
157         httpsTestServer.setHttpsConfigurator(new HttpsConfigurator(sslContext));
158         httpsURI = URIBuilder.newBuilder()
159                              .scheme("wss")
160                              .host("localhost")
161                              .port(httpsTestServer.getAddress().getPort())
162                              .path("/")
163                              .build()
164                              .toString() + queryPart;
165 
166         httpsTestServer.createContext("/", new UrlHandler());
167 
168         httpTestServer.start();
169         httpsTestServer.start();
170     }
171 
172     @AfterTest
173     public void teardown() {
174         httpTestServer.stop(0);
175         httpsTestServer.stop(0);
176         executor.shutdownNow();
177     }
178 
179     private static Throwable getCompletionCause(Throwable x) {
180         if (!(x instanceof CompletionException)
181             && !(x instanceof ExecutionException)) return x;
182         final Throwable cause = x.getCause();
183         if (cause == null) {
184             throw new InternalError("Unexpected null cause", x);
185         }
186         return cause;
187     }
188 
189     static class UrlHandler implements HttpHandler {
190 
191         @Override
192         public void handle(HttpExchange e) throws IOException {
193             try(InputStream is = e.getRequestBody();
194                 OutputStream os = e.getResponseBody()) {
195                 String testUri = "/?&raw=abc+def/ghi=xyz&encoded=abc%2Bdef%2Fghi%3Dxyz";
196                 URI uri = e.getRequestURI();
197                 byte[] bytes = is.readAllBytes();
198                 if (uri.toString().equals(testUri)) {
199                     bytes = testUri.getBytes();
200                 }
201                 e.sendResponseHeaders(400, bytes.length);
202                 os.write(bytes);
203 
204             }
205         }
206     }
207 }
208