Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Share code with the existing implementation
Browse files Browse the repository at this point in the history
niloc132 committed Jul 11, 2022
1 parent 1cc40d8 commit ba25c9b
Showing 7 changed files with 310 additions and 488 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package io.grpc.servlet.web.websocket;

import io.grpc.Attributes;
import io.grpc.InternalMetadata;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.internal.ServerTransportListener;
import jakarta.websocket.Endpoint;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.Session;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public abstract class AbstractWebSocketServerStream extends Endpoint {
protected final ServerTransportListener transportListener;
protected final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
protected final int maxInboundMessageSize;
protected final Attributes attributes;

// assigned on open, always available
protected Session websocketSession;

protected AbstractWebSocketServerStream(ServerTransportListener transportListener,
List<? extends ServerStreamTracer.Factory> streamTracerFactories, int maxInboundMessageSize,
Attributes attributes) {
this.transportListener = transportListener;
this.streamTracerFactories = streamTracerFactories;
this.maxInboundMessageSize = maxInboundMessageSize;
this.attributes = attributes;
}

protected static Metadata readHeaders(ByteBuffer headerPayload) {
// Headers are passed as ascii (browsers don't support binary), ":"-separated key/value pairs, separated on
// "\r\n". The client implementation shows that values might be comma-separated, but we'll pass that through
// directly as a plain string.
//
List<byte[]> byteArrays = new ArrayList<>();
while (headerPayload.hasRemaining()) {
int nameStart = headerPayload.position();
while (headerPayload.hasRemaining() && headerPayload.get() != ':');
int nameEnd = headerPayload.position() - 1;
int valueStart = headerPayload.position() + 1;// assumes that the colon is followed by a space

while (headerPayload.hasRemaining() && headerPayload.get() != '\n');
int valueEnd = headerPayload.position() - 2;// assumes that \n is preceded by a \r, this isnt generally
// safe?
if (valueEnd < valueStart) {
valueEnd = valueStart;
}
int endOfLinePosition = headerPayload.position();

byte[] headerBytes = new byte[nameEnd - nameStart];
headerPayload.position(nameStart);
headerPayload.get(headerBytes);

byteArrays.add(headerBytes);
if (Arrays.equals(headerBytes, "content-type".getBytes(StandardCharsets.US_ASCII))) {
// rewrite grpc-web content type to matching grpc content type
byteArrays.add("grpc+proto".getBytes(StandardCharsets.US_ASCII));
// TODO support other formats like text, non-proto
headerPayload.position(valueEnd);
continue;
}

// TODO check for binary header suffix
// if (headerBytes.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
//
// } else {
byte[] valueBytes = new byte[valueEnd - valueStart];
headerPayload.position(valueStart);
headerPayload.get(valueBytes);
byteArrays.add(valueBytes);
// }

headerPayload.position(endOfLinePosition);
}

// add a te:trailers, as gRPC will expect it
byteArrays.add("te".getBytes(StandardCharsets.US_ASCII));
byteArrays.add("trailers".getBytes(StandardCharsets.US_ASCII));

// TODO to support text encoding

return InternalMetadata.newMetadata(byteArrays.toArray(new byte[][] {}));
}

@Override
public void onOpen(Session websocketSession, EndpointConfig config) {
this.websocketSession = websocketSession;

websocketSession.addMessageHandler(String.class, this::onMessage);
websocketSession.addMessageHandler(ByteBuffer.class, message -> {
try {
onMessage(message);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});

// Configure defaults present in some servlet containers to avoid some confusing limits. Subclasses
// can override this method to control those defaults on their own.
websocketSession.setMaxIdleTimeout(0);
websocketSession.setMaxBinaryMessageBufferSize(Integer.MAX_VALUE);
}

public abstract void onMessage(String message);

public abstract void onMessage(ByteBuffer message) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package io.grpc.servlet.web.websocket;

import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Attributes;
import io.grpc.InternalLogId;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.AbstractServerStream;
import io.grpc.internal.ReadableBuffer;
import io.grpc.internal.SerializingExecutor;
import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportTracer;
import io.grpc.internal.WritableBufferAllocator;
import jakarta.websocket.Session;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

public abstract class AbstractWebsocketStreamImpl extends AbstractServerStream {
public final class WebsocketTransportState extends TransportState {

private final SerializingExecutor transportThreadExecutor =
new SerializingExecutor(MoreExecutors.directExecutor());
private final Logger logger;

private WebsocketTransportState(int maxMessageSize, StatsTraceContext statsTraceCtx,
TransportTracer transportTracer, Logger logger) {
super(maxMessageSize, statsTraceCtx, transportTracer);
this.logger = logger;
}

@Override
public void runOnTransportThread(Runnable r) {
transportThreadExecutor.execute(r);
}

@Override
public void bytesRead(int numBytes) {
// no-op, no flow-control yet
}

@Override
public void deframeFailed(Throwable cause) {
if (logger.isLoggable(Level.FINE)) {
logger.log(Level.FINE, String.format("[{%s}] Exception processing message", logId), cause);
}
cancel(Status.fromThrowable(cause));
}
}

protected final TransportState transportState;
protected final Session websocketSession;
protected final InternalLogId logId;
protected final Attributes attributes;

public AbstractWebsocketStreamImpl(WritableBufferAllocator bufferAllocator, StatsTraceContext statsTraceCtx,
int maxInboundMessageSize, Session websocketSession, InternalLogId logId, Attributes attributes,
Logger logger) {
super(bufferAllocator, statsTraceCtx);
transportState =
new WebsocketTransportState(maxInboundMessageSize, statsTraceCtx, new TransportTracer(), logger);
this.websocketSession = websocketSession;
this.logId = logId;
this.attributes = attributes;
}

protected static void writeAsciiHeadersToMessage(byte[][] serializedHeaders, ByteBuffer message) {
for (int i = 0; i < serializedHeaders.length; i += 2) {
message.put(serializedHeaders[i]);
message.put((byte) ':');
message.put((byte) ' ');
message.put(serializedHeaders[i + 1]);
message.put((byte) '\r');
message.put((byte) '\n');
}
}

@Override
public int streamId() {
return -1;
}

@Override
public Attributes getAttributes() {
return attributes;
}

public void createStream(ServerTransportListener transportListener, String methodName, Metadata headers) {
transportListener.streamCreated(this, methodName, headers);
transportState().onStreamAllocated();
}

public void inboundDataReceived(ReadableBuffer message, boolean endOfStream) {
transportState().inboundDataReceived(message, endOfStream);
}

public void transportReportStatus(Status status) {
transportState().transportReportStatus(status);
}

@Override
public TransportState transportState() {
return transportState;
}

protected void cancelSink(Status status) {
if (!websocketSession.isOpen() && Status.Code.DEADLINE_EXCEEDED == status.getCode()) {
return;
}
transportState.runOnTransportThread(() -> transportState.transportReportStatus(status));
// There is no way to RST_STREAM with CANCEL code, so write trailers instead
close(Status.CANCELLED.withCause(status.asRuntimeException()), new Metadata());
CountDownLatch countDownLatch = new CountDownLatch(1);
transportState.runOnTransportThread(() -> {
try {
websocketSession.close();
} catch (IOException ioException) {
// already closing, ignore
}
countDownLatch.countDown();
});
try {
countDownLatch.await(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package io.grpc.servlet.web.websocket;

import io.grpc.internal.WritableBuffer;

import static java.lang.Math.max;
import static java.lang.Math.min;

final class ByteArrayWritableBuffer implements WritableBuffer {

private final int capacity;
final byte[] bytes;
private int index;

ByteArrayWritableBuffer(int capacityHint) {
this.bytes = new byte[min(1024 * 1024, max(4096, capacityHint))];
this.capacity = bytes.length;
}

@Override
public void write(byte[] src, int srcIndex, int length) {
System.arraycopy(src, srcIndex, bytes, index, length);
index += length;
}

@Override
public void write(byte b) {
bytes[index++] = b;
}

@Override
public int writableBytes() {
return capacity - index;
}

@Override
public int readableBytes() {
return index;
}

@Override
public void release() {}
}
Loading

0 comments on commit ba25c9b

Please sign in to comment.