Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for longer WebSocket frames #8025

Merged
merged 2 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.time.Duration;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -50,6 +51,8 @@ class WebSocketTest {
private final int port;
private final HttpClient client;

private volatile boolean isNormalClose = true;

WebSocketTest(WebServer server) {
port = server.port();
client = HttpClient.newBuilder()
Expand All @@ -69,11 +72,13 @@ void resetClosed() {
}

@AfterEach
void checkClosed() {
EchoService.CloseInfo closeInfo = service.closeInfo();
assertThat(closeInfo, notNullValue());
assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE));
assertThat(closeInfo.reason(), is("normal"));
void checkNormalClose() {
if (isNormalClose) {
EchoService.CloseInfo closeInfo = service.closeInfo();
assertThat(closeInfo, notNullValue());
assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE));
assertThat(closeInfo.reason(), is("normal"));
}
}

@Test
Expand All @@ -91,7 +96,7 @@ void testOnce() throws Exception {
ws.sendText("Hello", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

List<String> results = listener.getResults();
List<String> results = listener.results().received;
assertThat(results, contains("Hello"));
}

Expand All @@ -107,7 +112,7 @@ void testMulti() throws Exception {
ws.sendText("First", true).get(5, TimeUnit.SECONDS);
ws.sendText("Second", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);
assertThat(listener.getResults(), contains("First", "Second"));
assertThat(listener.results().received, contains("First", "Second"));
}

@Test
Expand All @@ -124,13 +129,63 @@ void testFragmentedAndMulti() throws Exception {
ws.sendText("Third", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

assertThat(listener.getResults(), contains("FirstSecond", "Third"));
assertThat(listener.results().received, contains("FirstSecond", "Third"));
}

/**
* Tests sending long text messages. Note that any message longer than 16K
* will be chunked into 16K pieces by the JDK client.
*
* @throws Exception if an error occurs
*/
@Test
void testLongTextMessages() throws Exception {
TestListener listener = new TestListener();

java.net.http.WebSocket ws = client.newWebSocketBuilder()
.buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener)
.get(5, TimeUnit.SECONDS);
ws.request(10);

String s100 = randomString(100); // less than one byte
ws.sendText(s100, true).get(5, TimeUnit.SECONDS);
String s10000 = randomString(10000); // less than two bytes
ws.sendText(s10000, true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

assertThat(listener.results().received, contains(s100, s10000));
}

/**
* Test sending a single text message that will fit into a single JDK client frame
* of 16K but exceeds max-frame-length set in application.yaml for the server.
*
* @throws Exception if an error occurs
*/
@Test
void testTooLongTextMessage() throws Exception {
TestListener listener = new TestListener();

java.net.http.WebSocket ws = client.newWebSocketBuilder()
.buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener)
.get(5, TimeUnit.SECONDS);
ws.request(10);

String s10001 = randomString(10001); // over the limit of 10000
ws.sendText(s10001, true).get(5, TimeUnit.SECONDS);
assertThat(listener.results().statusCode, is(1009));
assertThat(listener.results().reason, is("Payload too large"));
isNormalClose = false;
}

private static class TestListener implements java.net.http.WebSocket.Listener {

record Results(int statusCode, String reason, List<String> received) {
}

final List<String> received = new LinkedList<>();
final List<String> buffered = new LinkedList<>();
private final CompletableFuture<List<String>> response = new CompletableFuture<>();
private final CompletableFuture<Results> response = new CompletableFuture<>();

@Override
public void onOpen(java.net.http.WebSocket webSocket) {
Expand All @@ -151,12 +206,21 @@ public CompletionStage<?> onText(java.net.http.WebSocket webSocket, CharSequence

@Override
public CompletionStage<?> onClose(java.net.http.WebSocket webSocket, int statusCode, String reason) {
response.complete(received);
response.complete(new Results(statusCode, reason, received));
return null;
}

List<String> getResults() throws ExecutionException, InterruptedException, TimeoutException {
Results results() throws ExecutionException, InterruptedException, TimeoutException {
return response.get(10, TimeUnit.SECONDS);
}
}

private static String randomString(int length) {
int leftLimit = 97; // letter 'a'
int rightLimit = 122; // letter 'z'
return new Random().ints(leftLimit, rightLimit + 1)
.limit(length)
.collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append)
.toString();
}
}
20 changes: 20 additions & 0 deletions webserver/tests/websocket/src/test/resources/application.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# Copyright (c) 2023 Oracle and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

server:
protocols:
websocket:
max-frame-length: 10000
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,13 @@ default String type() {
@ConfiguredOption(WsUpgradeProvider.CONFIG_NAME)
@Override
String name();

/**
* Max WebSocket frame size supported by the server on a read operation.
* Default is 1 MB.
*
* @return max frame size to read
*/
@ConfiguredOption(WsConnection.MAX_FRAME_LENGTH)
int maxFrameLength();
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@
public class WsConnection implements ServerConnection, WsSession {
private static final System.Logger LOGGER = System.getLogger(WsConnection.class.getName());

static final String MAX_FRAME_LENGTH = "1048576";

private final ConnectionContext ctx;
private final HttpPrologue prologue;
private final Headers upgradeHeaders;
private final String wsKey;
private final WsListener listener;
private final WsConfig wsConfig;

private final BufferData sendBuffer = BufferData.growing(1024);
private final DataReader dataReader;
Expand All @@ -75,6 +78,13 @@ private WsConnection(ConnectionContext ctx,
this.listener = wsRoute.listener();
this.dataReader = ctx.dataReader();
this.lastRequestTimestamp = DateTime.timestamp();
this.wsConfig = (WsConfig) ctx.listenerContext()
.config()
.protocols()
.stream()
.filter(p -> p instanceof WsConfig)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could work around the ugly cast by doing (after filter)
.map(WsConfig.class::cast)
which will change the stream to correctly typed one.

.findFirst()
.orElseThrow(() -> new InternalError("Unable to find WebSocket config"));
}

/**
Expand Down Expand Up @@ -243,8 +253,7 @@ private boolean processFrame(ClientWsFrame frame) {

private ClientWsFrame readFrame() {
try {
// TODO check may payload size, danger of oom
return ClientWsFrame.read(ctx, dataReader, Integer.MAX_VALUE);
return ClientWsFrame.read(ctx, dataReader, wsConfig.maxFrameLength());
} catch (DataReader.InsufficientDataAvailableException e) {
throw new CloseConnectionException("Socket closed by the other side", e);
} catch (WsCloseException e) {
Expand Down Expand Up @@ -276,9 +285,18 @@ private WsSession send(ServerWsFrame frame) {
opCodeFull |= usedCode.code();
sendBuffer.write(opCodeFull);

if (frame.payloadLength() < 126) {
sendBuffer.write((int) frame.payloadLength());
// TODO finish other options (payload longer than 126 bytes)
long length = frame.payloadLength();
if (length < 126) {
sendBuffer.write((int) length);
} else if (length < 1 << 16) {
sendBuffer.write(126);
sendBuffer.write((int) (length >>> 8));
sendBuffer.write((int) (length & 0xFF));
} else {
sendBuffer.write(127);
for (int i = 56; i >= 0; i -= 8){
sendBuffer.write((int) (length >>> i) & 0xFF);
}
}
sendBuffer.write(frame.payloadData());
ctx.dataWriter().writeNow(sendBuffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
*/
public class WsUpgradeProvider implements Http1UpgradeProvider<WsConfig> {
/**
* HTTP/2 server connection provider configuration node name.
* WebSocket server connection provider configuration node name.
*/
protected static final String CONFIG_NAME = "websocket";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ protected static FrameHeader readFrameHeader(DataReader reader, int maxFrameLeng
throw new WsCloseException("Payload too large", WsCloseCodes.TOO_BIG);
}

return new FrameHeader(opCode, fin, masked, length);
return new FrameHeader(opCode, fin, masked, (int) frameLength);
}

protected static BufferData readPayload(DataReader reader, FrameHeader header) {
Expand Down