Skip to content

Commit

Permalink
Fixes in this commit:
Browse files Browse the repository at this point in the history
1. Support for writing WS frames of more than 125 bytes, i.e. with lengths that fit in two bytes and eight bytes.
2. Configuration of max-frame-length to limit the size of a frame read by the server. The JDK WS client splits a long payload into 16K frames.
3. Some new tests for (1) and (2).
  • Loading branch information
spericas committed Nov 17, 2023
1 parent ea20930 commit a6b1a3d
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.time.Duration;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -50,6 +51,8 @@ class WebSocketTest {
private final int port;
private final HttpClient client;

private volatile boolean isNormalClose = true;

WebSocketTest(WebServer server) {
port = server.port();
client = HttpClient.newBuilder()
Expand All @@ -69,11 +72,13 @@ void resetClosed() {
}

@AfterEach
void checkClosed() {
EchoService.CloseInfo closeInfo = service.closeInfo();
assertThat(closeInfo, notNullValue());
assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE));
assertThat(closeInfo.reason(), is("normal"));
void checkNormalClose() {
if (isNormalClose) {
EchoService.CloseInfo closeInfo = service.closeInfo();
assertThat(closeInfo, notNullValue());
assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE));
assertThat(closeInfo.reason(), is("normal"));
}
}

@Test
Expand All @@ -91,7 +96,7 @@ void testOnce() throws Exception {
ws.sendText("Hello", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

List<String> results = listener.getResults();
List<String> results = listener.results().received;
assertThat(results, contains("Hello"));
}

Expand All @@ -107,7 +112,7 @@ void testMulti() throws Exception {
ws.sendText("First", true).get(5, TimeUnit.SECONDS);
ws.sendText("Second", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);
assertThat(listener.getResults(), contains("First", "Second"));
assertThat(listener.results().received, contains("First", "Second"));
}

@Test
Expand All @@ -124,13 +129,63 @@ void testFragmentedAndMulti() throws Exception {
ws.sendText("Third", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

assertThat(listener.getResults(), contains("FirstSecond", "Third"));
assertThat(listener.results().received, contains("FirstSecond", "Third"));
}

/**
* Tests sending long text messages. Note that any message longer than 16K
* will be chunked into 16K pieces by the JDK client.
*
* @throws Exception if an error occurs
*/
@Test
void testLongTextMessages() throws Exception {
TestListener listener = new TestListener();

java.net.http.WebSocket ws = client.newWebSocketBuilder()
.buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener)
.get(5, TimeUnit.SECONDS);
ws.request(10);

String s100 = randomString(100); // less than one byte
ws.sendText(s100, true).get(5, TimeUnit.SECONDS);
String s10000 = randomString(10000); // less than two bytes
ws.sendText(s10000, true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

assertThat(listener.results().received, contains(s100, s10000));
}

/**
* Test sending a single text message that will fit into a single JDK client frame
* of 16K but exceeds max-frame-length set in application.yaml for the server.
*
* @throws Exception if an error occurs
*/
@Test
void testTooLongTextMessage() throws Exception {
TestListener listener = new TestListener();

java.net.http.WebSocket ws = client.newWebSocketBuilder()
.buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener)
.get(5, TimeUnit.SECONDS);
ws.request(10);

String s10001 = randomString(10001); // over the limit of 10000
ws.sendText(s10001, true).get(5, TimeUnit.SECONDS);
assertThat(listener.results().statusCode, is(1009));
assertThat(listener.results().reason, is("Payload too large"));
isNormalClose = false;
}

private static class TestListener implements java.net.http.WebSocket.Listener {

record Results(int statusCode, String reason, List<String> received) {
}

final List<String> received = new LinkedList<>();
final List<String> buffered = new LinkedList<>();
private final CompletableFuture<List<String>> response = new CompletableFuture<>();
private final CompletableFuture<Results> response = new CompletableFuture<>();

@Override
public void onOpen(java.net.http.WebSocket webSocket) {
Expand All @@ -151,12 +206,21 @@ public CompletionStage<?> onText(java.net.http.WebSocket webSocket, CharSequence

@Override
public CompletionStage<?> onClose(java.net.http.WebSocket webSocket, int statusCode, String reason) {
response.complete(received);
response.complete(new Results(statusCode, reason, received));
return null;
}

List<String> getResults() throws ExecutionException, InterruptedException, TimeoutException {
Results results() throws ExecutionException, InterruptedException, TimeoutException {
return response.get(10, TimeUnit.SECONDS);
}
}

private static String randomString(int length) {
int leftLimit = 97; // letter 'a'
int rightLimit = 122; // letter 'z'
return new Random().ints(leftLimit, rightLimit + 1)
.limit(length)
.collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append)
.toString();
}
}
20 changes: 20 additions & 0 deletions webserver/tests/websocket/src/test/resources/application.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# Copyright (c) 2023 Oracle and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

server:
protocols:
websocket:
max-frame-length: 10000
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,13 @@ default String type() {
@ConfiguredOption(WsUpgradeProvider.CONFIG_NAME)
@Override
String name();

/**
* Max WebSocket frame size supported by the server on a read operation.
* Default is 1 MB.
*
* @return max frame size to read
*/
@ConfiguredOption(WsConnection.MAX_FRAME_LENGTH)
int maxFrameLength();
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@
public class WsConnection implements ServerConnection, WsSession {
private static final System.Logger LOGGER = System.getLogger(WsConnection.class.getName());

static final String MAX_FRAME_LENGTH = "1048576";

private final ConnectionContext ctx;
private final HttpPrologue prologue;
private final Headers upgradeHeaders;
private final String wsKey;
private final WsListener listener;
private final WsConfig wsConfig;

private final BufferData sendBuffer = BufferData.growing(1024);
private final DataReader dataReader;
Expand All @@ -75,6 +78,13 @@ private WsConnection(ConnectionContext ctx,
this.listener = wsRoute.listener();
this.dataReader = ctx.dataReader();
this.lastRequestTimestamp = DateTime.timestamp();
this.wsConfig = (WsConfig) ctx.listenerContext()
.config()
.protocols()
.stream()
.filter(p -> p instanceof WsConfig)
.findFirst()
.orElse(null);
}

/**
Expand Down Expand Up @@ -243,8 +253,9 @@ private boolean processFrame(ClientWsFrame frame) {

private ClientWsFrame readFrame() {
try {
// TODO check may payload size, danger of oom
return ClientWsFrame.read(ctx, dataReader, Integer.MAX_VALUE);
int maxFrameLength = wsConfig != null ? wsConfig.maxFrameLength()
: Integer.parseInt(MAX_FRAME_LENGTH);
return ClientWsFrame.read(ctx, dataReader, maxFrameLength);
} catch (DataReader.InsufficientDataAvailableException e) {
throw new CloseConnectionException("Socket closed by the other side", e);
} catch (WsCloseException e) {
Expand Down Expand Up @@ -276,9 +287,18 @@ private WsSession send(ServerWsFrame frame) {
opCodeFull |= usedCode.code();
sendBuffer.write(opCodeFull);

if (frame.payloadLength() < 126) {
sendBuffer.write((int) frame.payloadLength());
// TODO finish other options (payload longer than 126 bytes)
long length = frame.payloadLength();
if (length < 126) {
sendBuffer.write((int) length);
} else if (length < 1 << 16) {
sendBuffer.write(126);
sendBuffer.write((int) (length >>> 8));
sendBuffer.write((int) (length & 0xFF));
} else {
sendBuffer.write(127);
for (int i = 56; i >= 0; i -= 8){
sendBuffer.write((int) (length >>> i) & 0xFF);
}
}
sendBuffer.write(frame.payloadData());
ctx.dataWriter().writeNow(sendBuffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
*/
public class WsUpgradeProvider implements Http1UpgradeProvider<WsConfig> {
/**
* HTTP/2 server connection provider configuration node name.
* WebSocket server connection provider configuration node name.
*/
protected static final String CONFIG_NAME = "websocket";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ public class WsUpgrader implements Http1Upgrader {
private static final Base64.Decoder B64_DECODER = Base64.getDecoder();
private static final Base64.Encoder B64_ENCODER = Base64.getEncoder();
private static final byte[] HEADERS_SEPARATOR = "\r\n".getBytes(US_ASCII);
private final Set<String> origins;
private final boolean anyOrigin;
private final WsConfig wsConfig;

protected WsUpgrader(WsConfig wsConfig) {
this.origins = wsConfig.origins();
this.anyOrigin = this.origins.isEmpty();
this.wsConfig = wsConfig;
this.anyOrigin = wsConfig.origins().isEmpty();
}

/**
Expand Down Expand Up @@ -199,7 +199,7 @@ protected boolean anyOrigin() {
}

protected Set<String> origins() {
return origins;
return wsConfig.origins();
}

protected String hash(ConnectionContext ctx, String wsKey) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ protected static FrameHeader readFrameHeader(DataReader reader, int maxFrameLeng
throw new WsCloseException("Payload too large", WsCloseCodes.TOO_BIG);
}

return new FrameHeader(opCode, fin, masked, length);
return new FrameHeader(opCode, fin, masked, (int) frameLength);
}

protected static BufferData readPayload(DataReader reader, FrameHeader header) {
Expand Down

0 comments on commit a6b1a3d

Please sign in to comment.