diff --git a/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/GrpcWebsocket.java b/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/GrpcWebsocket.java new file mode 100644 index 00000000000..ae747dc85ba --- /dev/null +++ b/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/GrpcWebsocket.java @@ -0,0 +1,50 @@ +package io.grpc.servlet.web.websocket; + +import jakarta.websocket.CloseReason; +import jakarta.websocket.Endpoint; +import jakarta.websocket.EndpointConfig; +import jakarta.websocket.Session; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Supplier; + +/** + * Supports both grpc-websockets and grpc-websockets-multiplex subprotocols, and delegates to the correct implementation + * after protocol negotiation. + */ +public class GrpcWebsocket extends Endpoint { + private final Map> endpointFactories = new HashMap<>(); + private Endpoint endpoint; + + public GrpcWebsocket(Map> endpoints) { + endpointFactories.putAll(endpoints); + } + + public void onOpen(Session session, EndpointConfig endpointConfig) { + Supplier supplier = endpointFactories.get(session.getNegotiatedSubprotocol()); + if (supplier == null) { + try { + session.close(new CloseReason(CloseReason.CloseCodes.PROTOCOL_ERROR, "Unsupported subprotocol")); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return; + } + + endpoint = supplier.get(); + endpoint.onOpen(session, endpointConfig); + } + + @Override + public void onClose(Session session, CloseReason closeReason) { + endpoint.onClose(session, closeReason); + } + + @Override + public void onError(Session session, Throwable thr) { + endpoint.onError(session, thr); + } +} diff --git a/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/MultiplexedWebSocketServerStream.java b/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/MultiplexedWebSocketServerStream.java index acced6c0c49..6cf71eeb8c0 100644 --- a/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/MultiplexedWebSocketServerStream.java +++ b/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/MultiplexedWebSocketServerStream.java @@ -13,14 +13,12 @@ import io.grpc.internal.ServerTransportListener; import io.grpc.internal.StatsTraceContext; import jakarta.websocket.CloseReason; +import jakarta.websocket.Endpoint; import jakarta.websocket.EndpointConfig; -import jakarta.websocket.OnError; -import jakarta.websocket.OnMessage; -import jakarta.websocket.OnOpen; import jakarta.websocket.Session; -import jakarta.websocket.server.ServerEndpoint; import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.charset.StandardCharsets; @@ -29,7 +27,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Level; import java.util.logging.Logger; @@ -46,8 +43,7 @@ * JSR356 websockets always handle their incoming messages in a serial manner, so we don't need to worry here about * runOnTransportThread while in onMessage, as we're already in the transport thread. */ -@ServerEndpoint(value = "/grpc-websocket", subprotocols = "grpc-websockets-multiplex") -public class MultiplexedWebSocketServerStream { +public class MultiplexedWebSocketServerStream extends Endpoint { private static final Logger logger = Logger.getLogger(MultiplexedWebSocketServerStream.class.getName()); public static final Metadata.Key PATH = Metadata.Key.of("grpc-websockets-path", Metadata.ASCII_STRING_MARSHALLER); @@ -76,17 +72,25 @@ public MultiplexedWebSocketServerStream(ServerTransportListener transportListene this.attributes = attributes; } - @OnOpen + @Override public void onOpen(Session websocketSession, EndpointConfig config) { this.websocketSession = websocketSession; + websocketSession.addMessageHandler(String.class, this::onMessage); + websocketSession.addMessageHandler(ByteBuffer.class, message -> { + try { + onMessage(message); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + // Configure defaults present in some servlet containers to avoid some confusing limits. Subclasses // can override this method to control those defaults on their own. websocketSession.setMaxIdleTimeout(0); websocketSession.setMaxBinaryMessageBufferSize(Integer.MAX_VALUE); } - @OnMessage public void onMessage(String message) { for (MultiplexedWebsocketStreamImpl stream : streams.values()) { // This means the stream opened correctly, then sent a text payload, which doesn't make sense. @@ -102,7 +106,6 @@ public void onMessage(String message) { } } - @OnMessage public void onMessage(ByteBuffer message) throws IOException { // Each message starts with an int, to indicate stream id. If that int is negative, the other end has performed // a half close (and this is the final message). @@ -160,8 +163,8 @@ public void onMessage(ByteBuffer message) throws IOException { stream.inboundDataReceived(ReadableBuffers.wrap(message), false); } - @OnError - public void onError(Throwable error) { + @Override + public void onError(Session session, Throwable error) { for (MultiplexedWebsocketStreamImpl stream : streams.values()) { stream.transportReportStatus(Status.UNKNOWN);// transport failure of some kind } diff --git a/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/WebSocketServerStream.java b/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/WebSocketServerStream.java index c8adc2b53f7..c5c6c09e641 100644 --- a/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/WebSocketServerStream.java +++ b/grpc-java/grpc-servlet-websocket-jakarta/src/main/java/io/grpc/servlet/web/websocket/WebSocketServerStream.java @@ -10,14 +10,12 @@ import io.grpc.internal.ServerTransportListener; import io.grpc.internal.StatsTraceContext; import jakarta.websocket.CloseReason; +import jakarta.websocket.Endpoint; import jakarta.websocket.EndpointConfig; -import jakarta.websocket.OnError; -import jakarta.websocket.OnMessage; -import jakarta.websocket.OnOpen; import jakarta.websocket.Session; -import jakarta.websocket.server.ServerEndpoint; import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.charset.StandardCharsets; @@ -35,8 +33,7 @@ * JSR356 websockets always handle their incoming messages in a serial manner, so we don't need to worry here about * runOnTransportThread while in onMessage, as we're already in the transport thread. */ -@ServerEndpoint(value = "/{service}/{method}", subprotocols = "grpc-websockets") -public class WebSocketServerStream { +public class WebSocketServerStream extends Endpoint { private static final Logger logger = Logger.getLogger(WebSocketServerStream.class.getName()); private final ServerTransportListener transportListener; @@ -63,17 +60,25 @@ public WebSocketServerStream(ServerTransportListener transportListener, this.attributes = attributes; } - @OnOpen + @Override public void onOpen(Session websocketSession, EndpointConfig config) { this.websocketSession = websocketSession; + websocketSession.addMessageHandler(String.class, this::onMessage); + websocketSession.addMessageHandler(ByteBuffer.class, message -> { + try { + onMessage(message); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + // Configure defaults present in some servlet containers to avoid some confusing limits. Subclasses // can override this method to control those defaults on their own. websocketSession.setMaxIdleTimeout(0); websocketSession.setMaxBinaryMessageBufferSize(Integer.MAX_VALUE); } - @OnMessage public void onMessage(String message) { if (stream != null) { // This means the stream opened correctly, then sent a text payload, which doesn't make sense. @@ -88,7 +93,6 @@ public void onMessage(String message) { } } - @OnMessage public void onMessage(ByteBuffer message) throws IOException { if (message.remaining() == 0) { // message is empty (no control flow, no data), error @@ -128,8 +132,8 @@ public void onMessage(ByteBuffer message) throws IOException { stream.inboundDataReceived(ReadableBuffers.wrap(message), false); } - @OnError - public void onError(Throwable error) { + @Override + public void onError(Session session, Throwable error) { stream.transportReportStatus(Status.UNKNOWN);// transport failure of some kind // onClose will be called automatically if (error instanceof ClosedChannelException) { diff --git a/server/jetty/src/main/java/io/deephaven/server/jetty/JettyBackedGrpcServer.java b/server/jetty/src/main/java/io/deephaven/server/jetty/JettyBackedGrpcServer.java index bac45ec55be..7a0c3a74b46 100644 --- a/server/jetty/src/main/java/io/deephaven/server/jetty/JettyBackedGrpcServer.java +++ b/server/jetty/src/main/java/io/deephaven/server/jetty/JettyBackedGrpcServer.java @@ -10,9 +10,11 @@ import io.deephaven.ssl.config.SSLConfig; import io.deephaven.ssl.config.TrustJdk; import io.deephaven.ssl.config.impl.KickstartUtils; +import io.grpc.servlet.web.websocket.GrpcWebsocket; import io.grpc.servlet.web.websocket.MultiplexedWebSocketServerStream; import io.grpc.servlet.web.websocket.WebSocketServerStream; import jakarta.servlet.DispatcherType; +import jakarta.websocket.Endpoint; import jakarta.websocket.server.ServerEndpointConfig; import nl.altindag.ssl.SSLFactory; import nl.altindag.ssl.util.JettySslUtils; @@ -38,8 +40,13 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.net.URL; +import java.util.Arrays; +import java.util.Collections; import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; import static org.eclipse.jetty.servlet.ServletContextHandler.SESSIONS; @@ -82,29 +89,26 @@ public JettyBackedGrpcServer( // Wire up the provided grpc filter context.addFilter(new FilterHolder(filter), "/*", EnumSet.noneOf(DispatcherType.class)); - // Set up websocket for grpc-web + // Set up websockets for grpc-web - we register both in case we encounter a client using "vanilla" + // grpc-websocket, + // that can't multiplex all streams on a single socket if (config.websockets()) { JakartaWebSocketServletContainerInitializer.configure(context, (servletContext, container) -> { - // container.addEndpoint( - // ServerEndpointConfig.Builder.create(WebSocketServerStream.class, "/{service}/{method}") - // .configurator(new ServerEndpointConfig.Configurator() { - // @Override - // public T getEndpointInstance(Class endpointClass) - // throws InstantiationException { - // return (T) filter.create(WebSocketServerStream::new); - // } - // }) - // .build()); - container.addEndpoint( - ServerEndpointConfig.Builder.create(MultiplexedWebSocketServerStream.class, "/grpc-websocket") - .configurator(new ServerEndpointConfig.Configurator() { - @Override - public T getEndpointInstance(Class endpointClass) - throws InstantiationException { - return (T) filter.create(MultiplexedWebSocketServerStream::new); - } - }) - .build()); + Map> endpoints = Map.of( + "grpc-websockets", () -> filter.create(WebSocketServerStream::new), + "grpc-websockets-multiplex", () -> filter.create(MultiplexedWebSocketServerStream::new)); + container.addEndpoint(ServerEndpointConfig.Builder.create(GrpcWebsocket.class, "/{service}/{method}") + .configurator(new ServerEndpointConfig.Configurator() { + @Override + public T getEndpointInstance(Class endpointClass) throws InstantiationException { + // noinspection unchecked + return (T) new GrpcWebsocket(endpoints); + } + }) + .subprotocols(Arrays.asList("grpc-websockets", "grpc-websockets-multiplex")) + .build() + + ); }); } jetty.setHandler(context); diff --git a/web/client-api/src/main/java/io/deephaven/web/client/api/WorkerConnection.java b/web/client-api/src/main/java/io/deephaven/web/client/api/WorkerConnection.java index 009b46a29e0..17128f0d22e 100644 --- a/web/client-api/src/main/java/io/deephaven/web/client/api/WorkerConnection.java +++ b/web/client-api/src/main/java/io/deephaven/web/client/api/WorkerConnection.java @@ -132,7 +132,9 @@ public class WorkerConnection { // TODO configurable, let us support this even when ssl? if (DomGlobal.window.location.protocol.equals("http:")) { useWebsockets = true; - Grpc.setDefaultTransport.onInvoke(MultiplexedWebsocketTransport::new); + Grpc.setDefaultTransport.onInvoke(options -> new MultiplexedWebsocketTransport(options, () -> { + Grpc.setDefaultTransport.onInvoke(Grpc.WebsocketTransport.onInvoke()); + })); } else { useWebsockets = false; } diff --git a/web/client-api/src/main/java/io/deephaven/web/client/api/grpc/MultiplexedWebsocketTransport.java b/web/client-api/src/main/java/io/deephaven/web/client/api/grpc/MultiplexedWebsocketTransport.java index 36a4b2961e4..cff5f72826d 100644 --- a/web/client-api/src/main/java/io/deephaven/web/client/api/grpc/MultiplexedWebsocketTransport.java +++ b/web/client-api/src/main/java/io/deephaven/web/client/api/grpc/MultiplexedWebsocketTransport.java @@ -5,24 +5,37 @@ import elemental2.core.ArrayBuffer; import elemental2.core.DataView; -import elemental2.core.Int32Array; import elemental2.core.Int8Array; +import elemental2.core.JsError; import elemental2.core.Uint8Array; +import elemental2.dom.CloseEvent; import elemental2.dom.Event; import elemental2.dom.MessageEvent; import elemental2.dom.URL; import elemental2.dom.WebSocket; import io.deephaven.javascript.proto.dhinternal.browserheaders.BrowserHeaders; +import io.deephaven.javascript.proto.dhinternal.grpcweb.Grpc; import io.deephaven.javascript.proto.dhinternal.grpcweb.transports.transport.Transport; import io.deephaven.javascript.proto.dhinternal.grpcweb.transports.transport.TransportOptions; +import io.deephaven.web.client.api.JsLazy; +import io.deephaven.web.shared.fu.JsRunnable; +import jsinterop.base.Js; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +/** + * Custom replacement for grpc-websockets transport that handles multiple grpc streams in a single websocket. All else + * equal, this transport should be preferred to the default grpc-websockets transport, and in turn the fetch based + * transport is usually superior to this. + */ public class MultiplexedWebsocketTransport implements Transport { + public static final String MULTIPLEX_PROTOCOL = "grpc-websockets-multiplex"; + public static final String SOCKET_PER_STREAM_PROTOCOL = "grpc-websockets"; + private static Uint8Array encodeASCII(String str) { Uint8Array encoded = new Uint8Array(str.length()); for (int i = 0; i < str.length(); i++) { @@ -35,27 +48,40 @@ private static Uint8Array encodeASCII(String str) { private interface QueuedEntry { void send(WebSocket webSocket, int streamId); + + void sendFallback(Transport transport); } + public static class HeaderFrame implements QueuedEntry { - private final Uint8Array headerBytes; + private final String path; + private final BrowserHeaders metadata; public HeaderFrame(String path, BrowserHeaders metadata) { - StringBuilder str = new StringBuilder(); + this.path = path; + this.metadata = metadata; + } + + @Override + public void send(WebSocket webSocket, int streamId) { + final Uint8Array headerBytes; + final StringBuilder str = new StringBuilder(); metadata.append("grpc-websockets-path", path); metadata.forEach((key, value) -> { str.append(key).append(": ").append(value.join(", ")).append("\r\n"); }); headerBytes = encodeASCII(str.toString()); - } - - @Override - public void send(WebSocket webSocket, int streamId) { Int8Array payload = new Int8Array(headerBytes.byteLength + 4); new DataView(payload.buffer).setInt32(0, streamId); payload.set(headerBytes, 4); webSocket.send(payload); } + + @Override + public void sendFallback(Transport transport) { + transport.start(metadata); + } } + private static class GrpcMessageFrame implements QueuedEntry { private final Uint8Array msgBytes; @@ -72,9 +98,13 @@ public void send(WebSocket webSocket, int streamId) { webSocket.send(payload); } + @Override + public void sendFallback(Transport transport) { + transport.sendMessage(msgBytes); + } } - private static class WebsocketFinishSignal implements QueuedEntry { + private static class WebsocketFinishSignal implements QueuedEntry { @Override public void send(WebSocket webSocket, int streamId) { Uint8Array data = new Uint8Array(new double[] {0, 0, 0, 0, 1}); @@ -82,6 +112,11 @@ public void send(WebSocket webSocket, int streamId) { new DataView(data.buffer).setInt32(0, streamId); webSocket.send(data); } + + @Override + public void sendFallback(Transport transport) { + transport.finishSend(); + } } private static int nextStreamId = 0; @@ -93,7 +128,9 @@ public void send(WebSocket webSocket, int streamId) { private final TransportOptions options; private final String path; - public MultiplexedWebsocketTransport(TransportOptions options) { + private final JsLazy alternativeTransport; + + public MultiplexedWebsocketTransport(TransportOptions options, JsRunnable avoidMultiplexCallback) { this.options = options; String url = options.getUrl(); URL urlWrapper = new URL(url); @@ -102,29 +139,34 @@ public MultiplexedWebsocketTransport(TransportOptions options) { } else { urlWrapper.protocol = "wss:"; } + // preserve the path to send as metadata, but still talk to the server with that path path = urlWrapper.pathname.substring(1); - urlWrapper.pathname = "/grpc-websocket"; - url = urlWrapper.toString(); + String actualUrl = urlWrapper.toString(); + urlWrapper.pathname = "/"; + String key = urlWrapper.toString(); - webSocket = activeSockets.computeIfAbsent(url, u -> { - WebSocket ws = new WebSocket(u, "grpc-websockets-multiplex"); + // note that we connect to the actual url so the server can inform us via subprotocols that it isn't supported, + // but the global map removes the path as the key for each websocket + webSocket = activeSockets.computeIfAbsent(key, ignore -> { + WebSocket ws = new WebSocket(actualUrl, new String[] {MULTIPLEX_PROTOCOL, SOCKET_PER_STREAM_PROTOCOL}); ws.binaryType = "arraybuffer"; return ws; }); - } - @Override - public void sendMessage(Uint8Array msgBytes) { - sendOrEnqueue(new GrpcMessageFrame(msgBytes)); - } - - @Override - public void finishSend() { - sendOrEnqueue(new WebsocketFinishSignal()); + // prepare a fallback + alternativeTransport = new JsLazy<>(() -> { + avoidMultiplexCallback.run(); + return Grpc.WebsocketTransport.onInvoke().onInvoke(options); + }); } @Override public void start(BrowserHeaders metadata) { + if (alternativeTransport.isAvailable()) { + alternativeTransport.get().start(metadata); + return; + } + if (webSocket.readyState == WebSocket.CONNECTING) { // if the socket isn't open already, wait until the socket is // open, then flush the queue, otherwise everything will be @@ -138,9 +180,69 @@ public void start(BrowserHeaders metadata) { webSocket.addEventListener("message", this::onMessage); } + private void onOpen(Event event) { + Object protocol = Js.asPropertyMap(webSocket).get("protocol"); + if (protocol.equals(SOCKET_PER_STREAM_PROTOCOL)) { + // delegate to plain websocket impl, try to dissuade future users of this server + Transport transport = alternativeTransport.get(); + + // close our own websocket + webSocket.close(); + + // flush the queued items, which are now the new transport's problems - we'll forward all future work there + // as well automatically + for (int i = 0; i < sendQueue.size(); i++) { + sendQueue.get(i).sendFallback(transport); + } + sendQueue.clear(); + return; + } else if (!protocol.equals(MULTIPLEX_PROTOCOL)) { + // give up, no way to handle this + // TODO throw so the user can see this + return; + } + for (int i = 0; i < sendQueue.size(); i++) { + sendQueue.get(i).send(webSocket, streamId); + } + sendQueue.clear(); + } + + @Override + public void sendMessage(Uint8Array msgBytes) { + if (alternativeTransport.isAvailable()) { + alternativeTransport.get().sendMessage(msgBytes); + return; + } + + sendOrEnqueue(new GrpcMessageFrame(msgBytes)); + } + + @Override + public void finishSend() { + if (alternativeTransport.isAvailable()) { + alternativeTransport.get().finishSend(); + return; + } + + sendOrEnqueue(new WebsocketFinishSignal()); + } + + @Override + public void cancel() { + if (alternativeTransport.isAvailable()) { + alternativeTransport.get().cancel(); + return; + } + // TODO remove handlers, and close if we're the last one out + } private void onClose(Event event) { - options.getOnEnd().onInvoke(null); + if (alternativeTransport.isAvailable()) { + // must be downgrading to fallback + return; + } + // each grpc transport will handle this as an error + options.getOnEnd().onInvoke(new JsError("Unexpectedly closed " + ((CloseEvent) event).reason)); } private void onError(Event event) { @@ -173,17 +275,4 @@ private void sendOrEnqueue(QueuedEntry e) { e.send(webSocket, streamId); } } - - private void onOpen(Event event) { - for (int i = 0; i < sendQueue.size(); i++) { - QueuedEntry queuedEntry = sendQueue.get(i); - queuedEntry.send(webSocket, streamId); - } - sendQueue.clear(); - } - - @Override - public void cancel() { - // TODO remove handlers, and close if we're the last one out - } }