Skip to content

Commit

Permalink
Fixes in this commit:
Browse files Browse the repository at this point in the history
1. Support for writing WS frames of more than 125 bytes, i.e. with lengths that fit in two bytes and eight bytes.
2. Configuration of max-frame-length to limit the size of a frame read by the server. The JDK WS client splits a long payload into 16K frames.
3. Some new tests for (1) and (2).
  • Loading branch information
spericas committed Nov 17, 2023
1 parent ea20930 commit efef095
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ private static DataReader reader(ArrayBlockingQueue<byte[]> queue) {

void start() {
if (serverStarted.compareAndSet(false, true)) {
WsConnection serverConnection = WsConnection.create(ctx, prologue, WritableHeaders.create(), "", serverRoute);
WsConnection serverConnection = WsConnection.create(ctx, prologue, WritableHeaders.create(), "",
serverRoute, null);

ClientWsConnection clientConnection = ClientWsConnection.create(new DirectConnect(clientReader, clientWriter),
clientListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.time.Duration;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -50,6 +51,8 @@ class WebSocketTest {
private final int port;
private final HttpClient client;

private volatile boolean isNormalClose = true;

WebSocketTest(WebServer server) {
port = server.port();
client = HttpClient.newBuilder()
Expand All @@ -60,7 +63,8 @@ class WebSocketTest {
@SetUpRoute
static void router(Router.RouterBuilder<?> router) {
service = new EchoService();
router.addRouting(WsRouting.builder().endpoint("/echo", service));
router.addRouting(WsRouting.builder()
.endpoint("/echo", service));
}

@BeforeEach
Expand All @@ -69,11 +73,13 @@ void resetClosed() {
}

@AfterEach
void checkClosed() {
EchoService.CloseInfo closeInfo = service.closeInfo();
assertThat(closeInfo, notNullValue());
assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE));
assertThat(closeInfo.reason(), is("normal"));
void checkNormalClose() {
if (isNormalClose) {
EchoService.CloseInfo closeInfo = service.closeInfo();
assertThat(closeInfo, notNullValue());
assertThat(closeInfo.status(), is(WsCloseCodes.NORMAL_CLOSE));
assertThat(closeInfo.reason(), is("normal"));
}
}

@Test
Expand All @@ -91,7 +97,7 @@ void testOnce() throws Exception {
ws.sendText("Hello", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

List<String> results = listener.getResults();
List<String> results = listener.results().received;
assertThat(results, contains("Hello"));
}

Expand All @@ -107,7 +113,7 @@ void testMulti() throws Exception {
ws.sendText("First", true).get(5, TimeUnit.SECONDS);
ws.sendText("Second", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);
assertThat(listener.getResults(), contains("First", "Second"));
assertThat(listener.results().received, contains("First", "Second"));
}

@Test
Expand All @@ -124,13 +130,63 @@ void testFragmentedAndMulti() throws Exception {
ws.sendText("Third", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

assertThat(listener.getResults(), contains("FirstSecond", "Third"));
assertThat(listener.results().received, contains("FirstSecond", "Third"));
}

/**
* Tests sending long text messages. Note that any message longer than 16K
* will be chunked into 16K pieces by the JDK client.
*
* @throws Exception if an error occurs
*/
@Test
void testLongTextMessages() throws Exception {
TestListener listener = new TestListener();

java.net.http.WebSocket ws = client.newWebSocketBuilder()
.buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener)
.get(5, TimeUnit.SECONDS);
ws.request(10);

String s100 = randomString(100); // less than one byte
ws.sendText(s100, true).get(5, TimeUnit.SECONDS);
String s10000 = randomString(10000); // less than two bytes
ws.sendText(s10000, true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

assertThat(listener.results().received, contains(s100, s10000));
}

/**
* Test sending a single text message that will fit into a single JDK client frame
* of 16K but exceeds max-frame-length set in application.yaml for the server.
*
* @throws Exception if an error occurs
*/
@Test
void testTooLongTextMessage() throws Exception {
TestListener listener = new TestListener();

java.net.http.WebSocket ws = client.newWebSocketBuilder()
.buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener)
.get(5, TimeUnit.SECONDS);
ws.request(10);

String s10001 = randomString(10001); // over the limit of 10000
ws.sendText(s10001, true).get(5, TimeUnit.SECONDS);
assertThat(listener.results().statusCode, is(1009));
assertThat(listener.results().reason, is("Payload too large"));
isNormalClose = false;
}

private static class TestListener implements java.net.http.WebSocket.Listener {

record Results(int statusCode, String reason, List<String> received) {
}

final List<String> received = new LinkedList<>();
final List<String> buffered = new LinkedList<>();
private final CompletableFuture<List<String>> response = new CompletableFuture<>();
private final CompletableFuture<Results> response = new CompletableFuture<>();

@Override
public void onOpen(java.net.http.WebSocket webSocket) {
Expand All @@ -151,12 +207,21 @@ public CompletionStage<?> onText(java.net.http.WebSocket webSocket, CharSequence

@Override
public CompletionStage<?> onClose(java.net.http.WebSocket webSocket, int statusCode, String reason) {
response.complete(received);
response.complete(new Results(statusCode, reason, received));
return null;
}

List<String> getResults() throws ExecutionException, InterruptedException, TimeoutException {
Results results() throws ExecutionException, InterruptedException, TimeoutException {
return response.get(10, TimeUnit.SECONDS);
}
}

private static String randomString(int length) {
int leftLimit = 97; // letter 'a'
int rightLimit = 122; // letter 'z'
return new Random().ints(leftLimit, rightLimit + 1)
.limit(length)
.collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append)
.toString();
}
}
20 changes: 20 additions & 0 deletions webserver/tests/websocket/src/test/resources/application.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# Copyright (c) 2023 Oracle and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

server:
protocols:
websocket:
max-frame-length: 10000
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,13 @@ default String type() {
@ConfiguredOption(WsUpgradeProvider.CONFIG_NAME)
@Override
String name();

/**
* Max WebSocket frame size supported by the server on a read operation.
* Default is 1 MB.
*
* @return max frame size to read
*/
@ConfiguredOption("1048576")
int maxFrameLength();
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class WsConnection implements ServerConnection, WsSession {
private final Headers upgradeHeaders;
private final String wsKey;
private final WsListener listener;
private final WsConfig wsConfig;

private final BufferData sendBuffer = BufferData.growing(1024);
private final DataReader dataReader;
Expand All @@ -67,14 +68,16 @@ private WsConnection(ConnectionContext ctx,
HttpPrologue prologue,
Headers upgradeHeaders,
String wsKey,
WsRoute wsRoute) {
WsRoute wsRoute,
WsConfig wsConfig) {
this.ctx = ctx;
this.prologue = prologue;
this.upgradeHeaders = upgradeHeaders;
this.wsKey = wsKey;
this.listener = wsRoute.listener();
this.dataReader = ctx.dataReader();
this.lastRequestTimestamp = DateTime.timestamp();
this.wsConfig = wsConfig;
}

/**
Expand All @@ -85,14 +88,16 @@ private WsConnection(ConnectionContext ctx,
* @param upgradeHeaders headers for
* @param wsKey ws key
* @param wsRoute route to use
* @param wsConfig websocket config
* @return a new connection
*/
public static WsConnection create(ConnectionContext ctx,
HttpPrologue prologue,
Headers upgradeHeaders,
String wsKey,
WsRoute wsRoute) {
return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsRoute);
WsRoute wsRoute,
WsConfig wsConfig) {
return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsRoute, wsConfig);
}

@Override
Expand Down Expand Up @@ -243,8 +248,8 @@ private boolean processFrame(ClientWsFrame frame) {

private ClientWsFrame readFrame() {
try {
// TODO check may payload size, danger of oom
return ClientWsFrame.read(ctx, dataReader, Integer.MAX_VALUE);
int maxFrameLength = wsConfig != null ? wsConfig.maxFrameLength() : Integer.MAX_VALUE;
return ClientWsFrame.read(ctx, dataReader, maxFrameLength);
} catch (DataReader.InsufficientDataAvailableException e) {
throw new CloseConnectionException("Socket closed by the other side", e);
} catch (WsCloseException e) {
Expand Down Expand Up @@ -276,9 +281,18 @@ private WsSession send(ServerWsFrame frame) {
opCodeFull |= usedCode.code();
sendBuffer.write(opCodeFull);

if (frame.payloadLength() < 126) {
sendBuffer.write((int) frame.payloadLength());
// TODO finish other options (payload longer than 126 bytes)
long length = frame.payloadLength();
if (length < 126) {
sendBuffer.write((int) length);
} else if (length < 1 << 16) {
sendBuffer.write(126);
sendBuffer.write((int) (length >>> 8));
sendBuffer.write((int) (length & 0xFF));
} else {
sendBuffer.write(127);
for (int i = 56; i >= 0; i -= 8){
sendBuffer.write((int) (length >>> i) & 0xFF);
}
}
sendBuffer.write(frame.payloadData());
ctx.dataWriter().writeNow(sendBuffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
*/
public class WsUpgradeProvider implements Http1UpgradeProvider<WsConfig> {
/**
* HTTP/2 server connection provider configuration node name.
* WebSocket server connection provider configuration node name.
*/
protected static final String CONFIG_NAME = "websocket";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ public class WsUpgrader implements Http1Upgrader {
private static final Base64.Decoder B64_DECODER = Base64.getDecoder();
private static final Base64.Encoder B64_ENCODER = Base64.getEncoder();
private static final byte[] HEADERS_SEPARATOR = "\r\n".getBytes(US_ASCII);
private final Set<String> origins;
private final boolean anyOrigin;
private final WsConfig wsConfig;

protected WsUpgrader(WsConfig wsConfig) {
this.origins = wsConfig.origins();
this.anyOrigin = this.origins.isEmpty();
this.wsConfig = wsConfig;
this.anyOrigin = wsConfig.origins().isEmpty();
}

/**
Expand Down Expand Up @@ -191,15 +191,15 @@ public ServerConnection upgrade(ConnectionContext ctx, HttpPrologue prologue, Wr
LOGGER.log(Level.TRACE, "Upgraded to websocket version " + version);
}

return WsConnection.create(ctx, prologue, upgradeHeaders.orElse(EMPTY_HEADERS), wsKey, route);
return WsConnection.create(ctx, prologue, upgradeHeaders.orElse(EMPTY_HEADERS), wsKey, route, wsConfig);
}

protected boolean anyOrigin() {
return anyOrigin;
}

protected Set<String> origins() {
return origins;
return wsConfig.origins();
}

protected String hash(ConnectionContext ctx, String wsKey) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ protected static FrameHeader readFrameHeader(DataReader reader, int maxFrameLeng
throw new WsCloseException("Payload too large", WsCloseCodes.TOO_BIG);
}

return new FrameHeader(opCode, fin, masked, length);
return new FrameHeader(opCode, fin, masked, (int) frameLength);
}

protected static BufferData readPayload(DataReader reader, FrameHeader header) {
Expand Down

0 comments on commit efef095

Please sign in to comment.