Skip to content

Commit

Permalink
Client and server should each support both ws mechanisms
Browse files Browse the repository at this point in the history
  • Loading branch information
niloc132 committed Jul 11, 2022
1 parent b7ddeac commit 1cc40d8
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 81 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package io.grpc.servlet.web.websocket;

import jakarta.websocket.CloseReason;
import jakarta.websocket.Endpoint;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.Session;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

/**
* Supports both grpc-websockets and grpc-websockets-multiplex subprotocols, and delegates to the correct implementation
* after protocol negotiation.
*/
public class GrpcWebsocket extends Endpoint {
private final Map<String, Supplier<Endpoint>> endpointFactories = new HashMap<>();
private Endpoint endpoint;

public GrpcWebsocket(Map<String, Supplier<Endpoint>> endpoints) {
endpointFactories.putAll(endpoints);
}

public void onOpen(Session session, EndpointConfig endpointConfig) {
Supplier<Endpoint> supplier = endpointFactories.get(session.getNegotiatedSubprotocol());
if (supplier == null) {
try {
session.close(new CloseReason(CloseReason.CloseCodes.PROTOCOL_ERROR, "Unsupported subprotocol"));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
return;
}

endpoint = supplier.get();
endpoint.onOpen(session, endpointConfig);
}

@Override
public void onClose(Session session, CloseReason closeReason) {
endpoint.onClose(session, closeReason);
}

@Override
public void onError(Session session, Throwable thr) {
endpoint.onError(session, thr);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import jakarta.websocket.CloseReason;
import jakarta.websocket.Endpoint;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.OnError;
import jakarta.websocket.OnMessage;
import jakarta.websocket.OnOpen;
import jakarta.websocket.Session;
import jakarta.websocket.server.ServerEndpoint;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
Expand All @@ -29,7 +27,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand All @@ -46,8 +43,7 @@
* JSR356 websockets always handle their incoming messages in a serial manner, so we don't need to worry here about
* runOnTransportThread while in onMessage, as we're already in the transport thread.
*/
@ServerEndpoint(value = "/grpc-websocket", subprotocols = "grpc-websockets-multiplex")
public class MultiplexedWebSocketServerStream {
public class MultiplexedWebSocketServerStream extends Endpoint {
private static final Logger logger = Logger.getLogger(MultiplexedWebSocketServerStream.class.getName());
public static final Metadata.Key<String> PATH =
Metadata.Key.of("grpc-websockets-path", Metadata.ASCII_STRING_MARSHALLER);
Expand Down Expand Up @@ -76,17 +72,25 @@ public MultiplexedWebSocketServerStream(ServerTransportListener transportListene
this.attributes = attributes;
}

@OnOpen
@Override
public void onOpen(Session websocketSession, EndpointConfig config) {
this.websocketSession = websocketSession;

websocketSession.addMessageHandler(String.class, this::onMessage);
websocketSession.addMessageHandler(ByteBuffer.class, message -> {
try {
onMessage(message);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});

// Configure defaults present in some servlet containers to avoid some confusing limits. Subclasses
// can override this method to control those defaults on their own.
websocketSession.setMaxIdleTimeout(0);
websocketSession.setMaxBinaryMessageBufferSize(Integer.MAX_VALUE);
}

@OnMessage
public void onMessage(String message) {
for (MultiplexedWebsocketStreamImpl stream : streams.values()) {
// This means the stream opened correctly, then sent a text payload, which doesn't make sense.
Expand All @@ -102,7 +106,6 @@ public void onMessage(String message) {
}
}

@OnMessage
public void onMessage(ByteBuffer message) throws IOException {
// Each message starts with an int, to indicate stream id. If that int is negative, the other end has performed
// a half close (and this is the final message).
Expand Down Expand Up @@ -160,8 +163,8 @@ public void onMessage(ByteBuffer message) throws IOException {
stream.inboundDataReceived(ReadableBuffers.wrap(message), false);
}

@OnError
public void onError(Throwable error) {
@Override
public void onError(Session session, Throwable error) {
for (MultiplexedWebsocketStreamImpl stream : streams.values()) {
stream.transportReportStatus(Status.UNKNOWN);// transport failure of some kind
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import jakarta.websocket.CloseReason;
import jakarta.websocket.Endpoint;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.OnError;
import jakarta.websocket.OnMessage;
import jakarta.websocket.OnOpen;
import jakarta.websocket.Session;
import jakarta.websocket.server.ServerEndpoint;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
Expand All @@ -35,8 +33,7 @@
* JSR356 websockets always handle their incoming messages in a serial manner, so we don't need to worry here about
* runOnTransportThread while in onMessage, as we're already in the transport thread.
*/
@ServerEndpoint(value = "/{service}/{method}", subprotocols = "grpc-websockets")
public class WebSocketServerStream {
public class WebSocketServerStream extends Endpoint {
private static final Logger logger = Logger.getLogger(WebSocketServerStream.class.getName());

private final ServerTransportListener transportListener;
Expand All @@ -63,17 +60,25 @@ public WebSocketServerStream(ServerTransportListener transportListener,
this.attributes = attributes;
}

@OnOpen
@Override
public void onOpen(Session websocketSession, EndpointConfig config) {
this.websocketSession = websocketSession;

websocketSession.addMessageHandler(String.class, this::onMessage);
websocketSession.addMessageHandler(ByteBuffer.class, message -> {
try {
onMessage(message);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});

// Configure defaults present in some servlet containers to avoid some confusing limits. Subclasses
// can override this method to control those defaults on their own.
websocketSession.setMaxIdleTimeout(0);
websocketSession.setMaxBinaryMessageBufferSize(Integer.MAX_VALUE);
}

@OnMessage
public void onMessage(String message) {
if (stream != null) {
// This means the stream opened correctly, then sent a text payload, which doesn't make sense.
Expand All @@ -88,7 +93,6 @@ public void onMessage(String message) {
}
}

@OnMessage
public void onMessage(ByteBuffer message) throws IOException {
if (message.remaining() == 0) {
// message is empty (no control flow, no data), error
Expand Down Expand Up @@ -128,8 +132,8 @@ public void onMessage(ByteBuffer message) throws IOException {
stream.inboundDataReceived(ReadableBuffers.wrap(message), false);
}

@OnError
public void onError(Throwable error) {
@Override
public void onError(Session session, Throwable error) {
stream.transportReportStatus(Status.UNKNOWN);// transport failure of some kind
// onClose will be called automatically
if (error instanceof ClosedChannelException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import io.deephaven.ssl.config.SSLConfig;
import io.deephaven.ssl.config.TrustJdk;
import io.deephaven.ssl.config.impl.KickstartUtils;
import io.grpc.servlet.web.websocket.GrpcWebsocket;
import io.grpc.servlet.web.websocket.MultiplexedWebSocketServerStream;
import io.grpc.servlet.web.websocket.WebSocketServerStream;
import jakarta.servlet.DispatcherType;
import jakarta.websocket.Endpoint;
import jakarta.websocket.server.ServerEndpointConfig;
import nl.altindag.ssl.SSLFactory;
import nl.altindag.ssl.util.JettySslUtils;
Expand All @@ -38,8 +40,13 @@
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URL;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import static org.eclipse.jetty.servlet.ServletContextHandler.SESSIONS;

Expand Down Expand Up @@ -82,29 +89,26 @@ public JettyBackedGrpcServer(
// Wire up the provided grpc filter
context.addFilter(new FilterHolder(filter), "/*", EnumSet.noneOf(DispatcherType.class));

// Set up websocket for grpc-web
// Set up websockets for grpc-web - we register both in case we encounter a client using "vanilla"
// grpc-websocket,
// that can't multiplex all streams on a single socket
if (config.websockets()) {
JakartaWebSocketServletContainerInitializer.configure(context, (servletContext, container) -> {
// container.addEndpoint(
// ServerEndpointConfig.Builder.create(WebSocketServerStream.class, "/{service}/{method}")
// .configurator(new ServerEndpointConfig.Configurator() {
// @Override
// public <T> T getEndpointInstance(Class<T> endpointClass)
// throws InstantiationException {
// return (T) filter.create(WebSocketServerStream::new);
// }
// })
// .build());
container.addEndpoint(
ServerEndpointConfig.Builder.create(MultiplexedWebSocketServerStream.class, "/grpc-websocket")
.configurator(new ServerEndpointConfig.Configurator() {
@Override
public <T> T getEndpointInstance(Class<T> endpointClass)
throws InstantiationException {
return (T) filter.create(MultiplexedWebSocketServerStream::new);
}
})
.build());
Map<String, Supplier<Endpoint>> endpoints = Map.of(
"grpc-websockets", () -> filter.create(WebSocketServerStream::new),
"grpc-websockets-multiplex", () -> filter.create(MultiplexedWebSocketServerStream::new));
container.addEndpoint(ServerEndpointConfig.Builder.create(GrpcWebsocket.class, "/{service}/{method}")
.configurator(new ServerEndpointConfig.Configurator() {
@Override
public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationException {
// noinspection unchecked
return (T) new GrpcWebsocket(endpoints);
}
})
.subprotocols(Arrays.asList("grpc-websockets", "grpc-websockets-multiplex"))
.build()

);
});
}
jetty.setHandler(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ public class WorkerConnection {
// TODO configurable, let us support this even when ssl?
if (DomGlobal.window.location.protocol.equals("http:")) {
useWebsockets = true;
Grpc.setDefaultTransport.onInvoke(MultiplexedWebsocketTransport::new);
Grpc.setDefaultTransport.onInvoke(options -> new MultiplexedWebsocketTransport(options, () -> {
Grpc.setDefaultTransport.onInvoke(Grpc.WebsocketTransport.onInvoke());
}));
} else {
useWebsockets = false;
}
Expand Down
Loading

0 comments on commit 1cc40d8

Please sign in to comment.