From efef09545f7960f27cd0df5f3b190afae73e51f1 Mon Sep 17 00:00:00 2001 From: Santiago Pericasgeertsen Date: Fri, 17 Nov 2023 11:45:00 -0500 Subject: [PATCH] Fixes in this commit: 1. Support for writing WS frames of more than 125 bytes, i.e. with lengths that fit in two bytes and eight bytes. 2. Configuration of max-frame-length to limit the size of a frame read by the server. The JDK WS client splits a long payload into 16K frames. 3. Some new tests for (1) and (2). --- .../junit5/websocket/DirectWsConnection.java | 3 +- .../tests/websocket/WebSocketTest.java | 89 ++++++++++++++++--- .../src/test/resources/application.yaml | 20 +++++ .../websocket/WsConfigBlueprint.java | 9 ++ .../webserver/websocket/WsConnection.java | 30 +++++-- .../websocket/WsUpgradeProvider.java | 2 +- .../webserver/websocket/WsUpgrader.java | 10 +-- .../io/helidon/websocket/AbstractWsFrame.java | 2 +- 8 files changed, 137 insertions(+), 28 deletions(-) create mode 100644 webserver/tests/websocket/src/test/resources/application.yaml diff --git a/webserver/testing/junit5/websocket/src/main/java/io/helidon/webserver/testing/junit5/websocket/DirectWsConnection.java b/webserver/testing/junit5/websocket/src/main/java/io/helidon/webserver/testing/junit5/websocket/DirectWsConnection.java index bc20d39d81f..198eb3f2916 100644 --- a/webserver/testing/junit5/websocket/src/main/java/io/helidon/webserver/testing/junit5/websocket/DirectWsConnection.java +++ b/webserver/testing/junit5/websocket/src/main/java/io/helidon/webserver/testing/junit5/websocket/DirectWsConnection.java @@ -106,7 +106,8 @@ private static DataReader reader(ArrayBlockingQueue queue) { void start() { if (serverStarted.compareAndSet(false, true)) { - WsConnection serverConnection = WsConnection.create(ctx, prologue, WritableHeaders.create(), "", serverRoute); + WsConnection serverConnection = WsConnection.create(ctx, prologue, WritableHeaders.create(), "", + serverRoute, null); ClientWsConnection clientConnection = ClientWsConnection.create(new DirectConnect(clientReader, clientWriter), clientListener); diff --git a/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketTest.java b/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketTest.java index bd6d964ce6a..8836058ffec 100644 --- a/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketTest.java +++ b/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketTest.java @@ -21,6 +21,7 @@ import java.time.Duration; import java.util.LinkedList; import java.util.List; +import java.util.Random; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; @@ -50,6 +51,8 @@ class WebSocketTest { private final int port; private final HttpClient client; + private volatile boolean isNormalClose = true; + WebSocketTest(WebServer server) { port = server.port(); client = HttpClient.newBuilder() @@ -60,7 +63,8 @@ class WebSocketTest { @SetUpRoute static void router(Router.RouterBuilder router) { service = new EchoService(); - router.addRouting(WsRouting.builder().endpoint("/echo", service)); + router.addRouting(WsRouting.builder() + .endpoint("/echo", service)); } @BeforeEach @@ -69,11 +73,13 @@ void resetClosed() { } @AfterEach - void checkClosed() { - EchoService.CloseInfo closeInfo = service.closeInfo(); - assertThat(closeInfo, notNullValue()); - assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE)); - assertThat(closeInfo.reason(), is("normal")); + void checkNormalClose() { + if (isNormalClose) { + EchoService.CloseInfo closeInfo = service.closeInfo(); + assertThat(closeInfo, notNullValue()); + assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE)); + assertThat(closeInfo.reason(), is("normal")); + } } @Test @@ -91,7 +97,7 @@ void testOnce() throws Exception { ws.sendText("Hello", true).get(5, TimeUnit.SECONDS); ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS); - List results = listener.getResults(); + List results = listener.results().received; assertThat(results, contains("Hello")); } @@ -107,7 +113,7 @@ void testMulti() throws Exception { ws.sendText("First", true).get(5, TimeUnit.SECONDS); ws.sendText("Second", true).get(5, TimeUnit.SECONDS); ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS); - assertThat(listener.getResults(), contains("First", "Second")); + assertThat(listener.results().received, contains("First", "Second")); } @Test @@ -124,13 +130,63 @@ void testFragmentedAndMulti() throws Exception { ws.sendText("Third", true).get(5, TimeUnit.SECONDS); ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS); - assertThat(listener.getResults(), contains("FirstSecond", "Third")); + assertThat(listener.results().received, contains("FirstSecond", "Third")); + } + + /** + * Tests sending long text messages. Note that any message longer than 16K + * will be chunked into 16K pieces by the JDK client. + * + * @throws Exception if an error occurs + */ + @Test + void testLongTextMessages() throws Exception { + TestListener listener = new TestListener(); + + java.net.http.WebSocket ws = client.newWebSocketBuilder() + .buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener) + .get(5, TimeUnit.SECONDS); + ws.request(10); + + String s100 = randomString(100); // less than one byte + ws.sendText(s100, true).get(5, TimeUnit.SECONDS); + String s10000 = randomString(10000); // less than two bytes + ws.sendText(s10000, true).get(5, TimeUnit.SECONDS); + ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS); + + assertThat(listener.results().received, contains(s100, s10000)); + } + + /** + * Test sending a single text message that will fit into a single JDK client frame + * of 16K but exceeds max-frame-length set in application.yaml for the server. + * + * @throws Exception if an error occurs + */ + @Test + void testTooLongTextMessage() throws Exception { + TestListener listener = new TestListener(); + + java.net.http.WebSocket ws = client.newWebSocketBuilder() + .buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener) + .get(5, TimeUnit.SECONDS); + ws.request(10); + + String s10001 = randomString(10001); // over the limit of 10000 + ws.sendText(s10001, true).get(5, TimeUnit.SECONDS); + assertThat(listener.results().statusCode, is(1009)); + assertThat(listener.results().reason, is("Payload too large")); + isNormalClose = false; } private static class TestListener implements java.net.http.WebSocket.Listener { + + record Results(int statusCode, String reason, List received) { + } + final List received = new LinkedList<>(); final List buffered = new LinkedList<>(); - private final CompletableFuture> response = new CompletableFuture<>(); + private final CompletableFuture response = new CompletableFuture<>(); @Override public void onOpen(java.net.http.WebSocket webSocket) { @@ -151,12 +207,21 @@ public CompletionStage onText(java.net.http.WebSocket webSocket, CharSequence @Override public CompletionStage onClose(java.net.http.WebSocket webSocket, int statusCode, String reason) { - response.complete(received); + response.complete(new Results(statusCode, reason, received)); return null; } - List getResults() throws ExecutionException, InterruptedException, TimeoutException { + Results results() throws ExecutionException, InterruptedException, TimeoutException { return response.get(10, TimeUnit.SECONDS); } } + + private static String randomString(int length) { + int leftLimit = 97; // letter 'a' + int rightLimit = 122; // letter 'z' + return new Random().ints(leftLimit, rightLimit + 1) + .limit(length) + .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append) + .toString(); + } } diff --git a/webserver/tests/websocket/src/test/resources/application.yaml b/webserver/tests/websocket/src/test/resources/application.yaml new file mode 100644 index 00000000000..72eb09e48f8 --- /dev/null +++ b/webserver/tests/websocket/src/test/resources/application.yaml @@ -0,0 +1,20 @@ +# +# Copyright (c) 2023 Oracle and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +server: + protocols: + websocket: + max-frame-length: 10000 \ No newline at end of file diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConfigBlueprint.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConfigBlueprint.java index 02632a4da12..c0e3a38090e 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConfigBlueprint.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConfigBlueprint.java @@ -56,4 +56,13 @@ default String type() { @ConfiguredOption(WsUpgradeProvider.CONFIG_NAME) @Override String name(); + + /** + * Max WebSocket frame size supported by the server on a read operation. + * Default is 1 MB. + * + * @return max frame size to read + */ + @ConfiguredOption("1048576") + int maxFrameLength(); } diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java index e2b69fa8c77..1005d00dc21 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java @@ -50,6 +50,7 @@ public class WsConnection implements ServerConnection, WsSession { private final Headers upgradeHeaders; private final String wsKey; private final WsListener listener; + private final WsConfig wsConfig; private final BufferData sendBuffer = BufferData.growing(1024); private final DataReader dataReader; @@ -67,7 +68,8 @@ private WsConnection(ConnectionContext ctx, HttpPrologue prologue, Headers upgradeHeaders, String wsKey, - WsRoute wsRoute) { + WsRoute wsRoute, + WsConfig wsConfig) { this.ctx = ctx; this.prologue = prologue; this.upgradeHeaders = upgradeHeaders; @@ -75,6 +77,7 @@ private WsConnection(ConnectionContext ctx, this.listener = wsRoute.listener(); this.dataReader = ctx.dataReader(); this.lastRequestTimestamp = DateTime.timestamp(); + this.wsConfig = wsConfig; } /** @@ -85,14 +88,16 @@ private WsConnection(ConnectionContext ctx, * @param upgradeHeaders headers for * @param wsKey ws key * @param wsRoute route to use + * @param wsConfig websocket config * @return a new connection */ public static WsConnection create(ConnectionContext ctx, HttpPrologue prologue, Headers upgradeHeaders, String wsKey, - WsRoute wsRoute) { - return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsRoute); + WsRoute wsRoute, + WsConfig wsConfig) { + return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsRoute, wsConfig); } @Override @@ -243,8 +248,8 @@ private boolean processFrame(ClientWsFrame frame) { private ClientWsFrame readFrame() { try { - // TODO check may payload size, danger of oom - return ClientWsFrame.read(ctx, dataReader, Integer.MAX_VALUE); + int maxFrameLength = wsConfig != null ? wsConfig.maxFrameLength() : Integer.MAX_VALUE; + return ClientWsFrame.read(ctx, dataReader, maxFrameLength); } catch (DataReader.InsufficientDataAvailableException e) { throw new CloseConnectionException("Socket closed by the other side", e); } catch (WsCloseException e) { @@ -276,9 +281,18 @@ private WsSession send(ServerWsFrame frame) { opCodeFull |= usedCode.code(); sendBuffer.write(opCodeFull); - if (frame.payloadLength() < 126) { - sendBuffer.write((int) frame.payloadLength()); - // TODO finish other options (payload longer than 126 bytes) + long length = frame.payloadLength(); + if (length < 126) { + sendBuffer.write((int) length); + } else if (length < 1 << 16) { + sendBuffer.write(126); + sendBuffer.write((int) (length >>> 8)); + sendBuffer.write((int) (length & 0xFF)); + } else { + sendBuffer.write(127); + for (int i = 56; i >= 0; i -= 8){ + sendBuffer.write((int) (length >>> i) & 0xFF); + } } sendBuffer.write(frame.payloadData()); ctx.dataWriter().writeNow(sendBuffer); diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgradeProvider.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgradeProvider.java index 4f9573ca07c..ffc7aea7698 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgradeProvider.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgradeProvider.java @@ -25,7 +25,7 @@ */ public class WsUpgradeProvider implements Http1UpgradeProvider { /** - * HTTP/2 server connection provider configuration node name. + * WebSocket server connection provider configuration node name. */ protected static final String CONFIG_NAME = "websocket"; diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgrader.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgrader.java index 476e351ecfc..59b3009d4e2 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgrader.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgrader.java @@ -96,12 +96,12 @@ public class WsUpgrader implements Http1Upgrader { private static final Base64.Decoder B64_DECODER = Base64.getDecoder(); private static final Base64.Encoder B64_ENCODER = Base64.getEncoder(); private static final byte[] HEADERS_SEPARATOR = "\r\n".getBytes(US_ASCII); - private final Set origins; private final boolean anyOrigin; + private final WsConfig wsConfig; protected WsUpgrader(WsConfig wsConfig) { - this.origins = wsConfig.origins(); - this.anyOrigin = this.origins.isEmpty(); + this.wsConfig = wsConfig; + this.anyOrigin = wsConfig.origins().isEmpty(); } /** @@ -191,7 +191,7 @@ public ServerConnection upgrade(ConnectionContext ctx, HttpPrologue prologue, Wr LOGGER.log(Level.TRACE, "Upgraded to websocket version " + version); } - return WsConnection.create(ctx, prologue, upgradeHeaders.orElse(EMPTY_HEADERS), wsKey, route); + return WsConnection.create(ctx, prologue, upgradeHeaders.orElse(EMPTY_HEADERS), wsKey, route, wsConfig); } protected boolean anyOrigin() { @@ -199,7 +199,7 @@ protected boolean anyOrigin() { } protected Set origins() { - return origins; + return wsConfig.origins(); } protected String hash(ConnectionContext ctx, String wsKey) { diff --git a/websocket/src/main/java/io/helidon/websocket/AbstractWsFrame.java b/websocket/src/main/java/io/helidon/websocket/AbstractWsFrame.java index 6bd28fb53df..35671337f65 100644 --- a/websocket/src/main/java/io/helidon/websocket/AbstractWsFrame.java +++ b/websocket/src/main/java/io/helidon/websocket/AbstractWsFrame.java @@ -109,7 +109,7 @@ protected static FrameHeader readFrameHeader(DataReader reader, int maxFrameLeng throw new WsCloseException("Payload too large", WsCloseCodes.TOO_BIG); } - return new FrameHeader(opCode, fin, masked, length); + return new FrameHeader(opCode, fin, masked, (int) frameLength); } protected static BufferData readPayload(DataReader reader, FrameHeader header) {