diff --git a/server/src/main/java/io/deephaven/server/console/ConsoleServiceGrpcImpl.java b/server/src/main/java/io/deephaven/server/console/ConsoleServiceGrpcImpl.java index 25578fe057c..0c8709e0487 100644 --- a/server/src/main/java/io/deephaven/server/console/ConsoleServiceGrpcImpl.java +++ b/server/src/main/java/io/deephaven/server/console/ConsoleServiceGrpcImpl.java @@ -157,6 +157,8 @@ public void subscribeToLogs( GrpcUtil.safelyError(responseObserver, Code.FAILED_PRECONDITION, "Remote console disabled"); return; } + // Session close logic implicitly handled in + // io.deephaven.server.session.SessionServiceGrpcImpl.SessionServiceInterceptor final LogsClient client = new LogsClient(request, (ServerCallStreamObserver) responseObserver); client.start(); diff --git a/server/src/main/java/io/deephaven/server/object/ObjectServiceGrpcImpl.java b/server/src/main/java/io/deephaven/server/object/ObjectServiceGrpcImpl.java index 052de6b3475..f2901dee6ff 100644 --- a/server/src/main/java/io/deephaven/server/object/ObjectServiceGrpcImpl.java +++ b/server/src/main/java/io/deephaven/server/object/ObjectServiceGrpcImpl.java @@ -323,6 +323,8 @@ public void onCompleted() { @Override public StreamObserver messageStream(StreamObserver responseObserver) { SessionState session = sessionService.getCurrentSession(); + // Session close logic implicitly handled in + // io.deephaven.server.session.SessionServiceGrpcImpl.SessionServiceInterceptor return new SendMessageObserver(session, responseObserver); } diff --git a/server/src/main/java/io/deephaven/server/session/SessionServiceGrpcImpl.java b/server/src/main/java/io/deephaven/server/session/SessionServiceGrpcImpl.java index 2e64738457d..f13bb33e55d 100644 --- a/server/src/main/java/io/deephaven/server/session/SessionServiceGrpcImpl.java +++ b/server/src/main/java/io/deephaven/server/session/SessionServiceGrpcImpl.java @@ -15,9 +15,9 @@ import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; import io.deephaven.proto.backplane.grpc.*; +import io.deephaven.proto.backplane.script.grpc.ConsoleServiceGrpc; import io.deephaven.proto.util.Exceptions; import io.deephaven.util.SafeCloseable; -import io.deephaven.util.function.ThrowingRunnable; import io.grpc.Context; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.ForwardingServerCallListener; @@ -36,9 +36,11 @@ import javax.inject.Inject; import javax.inject.Singleton; +import java.io.Closeable; import java.lang.Object; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; import java.util.UUID; public class SessionServiceGrpcImpl extends SessionServiceGrpc.SessionServiceImplBase { @@ -310,10 +312,19 @@ private void addHeaders(final Metadata md) { @Singleton public static class SessionServiceInterceptor implements ServerInterceptor { + private static final Status AUTHENTICATION_DETAILS_INVALID = + Status.UNAUTHENTICATED.withDescription("Authentication details invalid"); + + // We can't use just io.grpc.MethodDescriptor (unless we chose provide and inject the named method descriptors), + // some of our methods are overridden from stock gRPC; for example, + // io.deephaven.server.object.ObjectServiceGrpcBinding.bindService. + // The goal should be to migrate all of the existing RPC Session close management logic to here if possible. + private static final Set CANCEL_RPC_ON_SESSION_CLOSE = Set.of( + ConsoleServiceGrpc.getSubscribeToLogsMethod().getFullMethodName(), + ObjectServiceGrpc.getMessageStreamMethod().getFullMethodName()); + private final SessionService service; private final SessionService.ErrorTransformer errorTransformer; - private static final Status authenticationDetailsInvalid = - Status.UNAUTHENTICATED.withDescription("Authentication details invalid"); @Inject public SessionServiceInterceptor( @@ -344,12 +355,8 @@ public ServerCall.Listener interceptCall(final ServerCall() {}; } } @@ -363,33 +370,61 @@ public ServerCall.Listener interceptCall(final ServerCall> listener = new MutableObject<>(); rpcWrapper(serverCall, context, finalSession, errorTransformer, () -> listener.setValue( - new SessionServiceCallListener<>(serverCallHandler.startCall(serverCall, metadata), serverCall, - context, finalSession, errorTransformer))); + listener(serverCall, metadata, serverCallHandler, context, finalSession))); if (listener.getValue() == null) { return new ServerCall.Listener<>() {}; } return listener.getValue(); } + + private @NotNull SessionServiceCallListener listener( + InterceptedCall serverCall, + Metadata metadata, + ServerCallHandler serverCallHandler, + Context context, + SessionState session) { + return new SessionServiceCallListener<>( + serverCallHandler.startCall(serverCall, metadata), + serverCall, + context, + session, + errorTransformer, + CANCEL_RPC_ON_SESSION_CLOSE.contains(serverCall.getMethodDescriptor().getFullMethodName())); + } } private static class SessionServiceCallListener extends - ForwardingServerCallListener.SimpleForwardingServerCallListener { + ForwardingServerCallListener.SimpleForwardingServerCallListener implements Closeable { + private static final Status SESSION_CLOSED = Status.CANCELLED.withDescription("Session closed"); + private final ServerCall call; private final Context context; private final SessionState session; private final SessionService.ErrorTransformer errorTransformer; + private final boolean autoCancelOnSessionClose; - public SessionServiceCallListener( + SessionServiceCallListener( ServerCall.Listener delegate, ServerCall call, Context context, SessionState session, - SessionService.ErrorTransformer errorTransformer) { + SessionService.ErrorTransformer errorTransformer, + boolean autoCancelOnSessionClose) { super(delegate); this.call = call; this.context = context; this.session = session; this.errorTransformer = errorTransformer; + this.autoCancelOnSessionClose = autoCancelOnSessionClose; + if (autoCancelOnSessionClose && session != null) { + session.addOnCloseCallback(this); + } + } + + @Override + public void close() { + // session.addOnCloseCallback + safeClose(call, SESSION_CLOSED, new Metadata(), false); } @Override @@ -405,11 +440,17 @@ public void onHalfClose() { @Override public void onCancel() { rpcWrapper(call, context, session, errorTransformer, super::onCancel); + if (autoCancelOnSessionClose && session != null) { + session.removeOnCloseCallback(this); + } } @Override public void onComplete() { rpcWrapper(call, context, session, errorTransformer, super::onComplete); + if (autoCancelOnSessionClose && session != null) { + session.removeOnCloseCallback(this); + } } @Override @@ -432,34 +473,44 @@ private static void rpcWrapper( @NotNull final Context context, @Nullable final SessionState session, @NotNull final SessionService.ErrorTransformer errorTransformer, - @NotNull final ThrowingRunnable lambda) { + @NotNull final Runnable lambda) { Context previous = context.attach(); // note: we'll open the execution context here so that it may be used by the error transformer try (final SafeCloseable ignored1 = session == null ? null : session.getExecutionContext().open()) { try (final SafeCloseable ignored2 = LivenessScopeStack.open()) { lambda.run(); - } catch (final InterruptedException err) { - Thread.currentThread().interrupt(); - closeWithError(call, errorTransformer.transform(err)); - } catch (final Throwable err) { - closeWithError(call, errorTransformer.transform(err)); + } catch (final RuntimeException err) { + safeClose(call, errorTransformer.transform(err)); + } catch (final Error error) { + // Indicates a very serious failure; debateable whether we should even try to send close. + safeClose(call, Status.INTERNAL, new Metadata(), false); + throw error; } finally { context.detach(previous); } } } - private static void closeWithError( - @NotNull final ServerCall call, + private static void safeClose( + @NotNull final ServerCall call, @NotNull final StatusRuntimeException err) { + Metadata metadata = Status.trailersFromThrowable(err); + if (metadata == null) { + metadata = new Metadata(); + } + safeClose(call, Status.fromThrowable(err), metadata, true); + } + + private static void safeClose(ServerCall call, Status status, Metadata trailers, boolean logOnError) { try { - Metadata metadata = Status.trailersFromThrowable(err); - if (metadata == null) { - metadata = new Metadata(); + call.close(status, trailers); + } catch (IllegalStateException e) { + // IllegalStateException is explicitly documented as thrown if the call is already closed. It might be nice + // if there was a more explicit exception type, but this should suffice. We _could_ try and check the text + // "call already closed", but that is an undocumented implementation detail we should probably not rely on. + if (logOnError && log.isDebugEnabled()) { + log.debug().append("call.close error: ").append(e).endl(); } - call.close(Status.fromThrowable(err), metadata); - } catch (final Exception unexpectedErr) { - log.debug().append("Unanticipated gRPC Error: ").append(unexpectedErr).endl(); } } } diff --git a/server/src/test/java/io/deephaven/server/session/SessionServiceCloseTest.java b/server/src/test/java/io/deephaven/server/session/SessionServiceCloseTest.java new file mode 100644 index 00000000000..a910b73a147 --- /dev/null +++ b/server/src/test/java/io/deephaven/server/session/SessionServiceCloseTest.java @@ -0,0 +1,219 @@ +// +// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +// +package io.deephaven.server.session; + +import io.deephaven.proto.backplane.grpc.ExportNotification; +import io.deephaven.proto.backplane.grpc.ExportNotificationRequest; +import io.deephaven.proto.backplane.grpc.ExportedTableUpdateMessage; +import io.deephaven.proto.backplane.grpc.ExportedTableUpdatesRequest; +import io.deephaven.proto.backplane.grpc.FieldsChangeUpdate; +import io.deephaven.proto.backplane.grpc.ListFieldsRequest; +import io.deephaven.proto.backplane.grpc.StreamRequest; +import io.deephaven.proto.backplane.grpc.StreamResponse; +import io.deephaven.proto.backplane.grpc.TerminationNotificationRequest; +import io.deephaven.proto.backplane.grpc.TerminationNotificationResponse; +import io.deephaven.proto.backplane.script.grpc.AutoCompleteRequest; +import io.deephaven.proto.backplane.script.grpc.AutoCompleteResponse; +import io.deephaven.proto.backplane.script.grpc.LogSubscriptionData; +import io.deephaven.proto.backplane.script.grpc.LogSubscriptionRequest; +import io.deephaven.server.runner.DeephavenApiServerSingleAuthenticatedBase; +import io.deephaven.server.runner.RpcServerStateInterceptor.RpcServerState; +import io.grpc.ClientInterceptor; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.ClientCallStreamObserver; +import io.grpc.stub.ClientResponseObserver; +import org.junit.Test; + +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.assertj.core.api.Assertions.assertThat; + +public class SessionServiceCloseTest extends DeephavenApiServerSingleAuthenticatedBase { + + @Test + public void listFields() throws InterruptedException, TimeoutException { + new ListFieldsObserver().doTest(StatusRuntimeException.class, "CANCELLED: subscription cancelled"); + } + + @Test + public void exportNotifications() throws InterruptedException, TimeoutException { + new ExportNotificationsObserver().doTestToCompleted(); + } + + @Test + public void exportedTableUpdates() throws InterruptedException, TimeoutException { + new ExportedTableUpdatesObserver().doTestToCompleted(); + } + + @Test + public void messageStream() throws InterruptedException, TimeoutException { + new MessageStreamObserver().doTest(StatusRuntimeException.class, "CANCELLED: Session closed"); + } + + @Test + public void subscribeToLogs() throws InterruptedException, TimeoutException { + new SubscribeToLogsObserver().doTest(StatusRuntimeException.class, "CANCELLED: Session closed"); + } + + @Test + public void autoCompleteStream() throws InterruptedException, TimeoutException { + new AutoCompleteStreamObserver().doTestToCompleted(); + } + + @Test + public void terminationNotification() throws InterruptedException, TimeoutException { + new TerminationObserver().doTest(StatusRuntimeException.class, "UNAUTHENTICATED: Session has ended"); + } + + final class ListFieldsObserver extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .application() + .withInterceptors(clientInterceptor) + .listFields(ListFieldsRequest.getDefaultInstance(), this); + } + } + + final class ExportNotificationsObserver + extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .session() + .withInterceptors(clientInterceptor) + .exportNotifications(ExportNotificationRequest.getDefaultInstance(), this); + } + } + + final class ExportedTableUpdatesObserver + extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .table() + .withInterceptors(clientInterceptor) + .exportedTableUpdates(ExportedTableUpdatesRequest.getDefaultInstance(), this); + } + } + + final class MessageStreamObserver extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .object() + .withInterceptors(clientInterceptor) + .messageStream(this); + } + } + + final class SubscribeToLogsObserver extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .console() + .withInterceptors(clientInterceptor) + .subscribeToLogs(LogSubscriptionRequest.getDefaultInstance(), this); + } + } + + final class AutoCompleteStreamObserver extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .console() + .withInterceptors(clientInterceptor) + .autoCompleteStream(this); + } + } + + final class TerminationObserver + extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .session() + .withInterceptors(clientInterceptor) + .terminationNotification(TerminationNotificationRequest.getDefaultInstance(), this); + } + } + + abstract class CloseSessionObserverBase implements ClientResponseObserver { + ClientCallStreamObserver observer; + Throwable t; + boolean onCompleted; + final CountDownLatch onDone = new CountDownLatch(1); + + public void doTest(Class exceptionType, String message) + throws InterruptedException, TimeoutException { + sendImplCloseSessionAndWait(); + assertError(exceptionType, message); + } + + public void doTestToCompleted() throws InterruptedException, TimeoutException { + sendImplCloseSessionAndWait(); + assertCompleted(); + } + + private void sendImplCloseSessionAndWait() throws InterruptedException, TimeoutException { + final RpcServerState serverState = serverStateInterceptor().newRpcServerState(); + sendImpl(serverState.clientInterceptor()); + serverState.awaitServerInvokeFinished(Duration.ofSeconds(3)); + closeSession(); + awaitOnDone(Duration.ofSeconds(3)); + } + + abstract void sendImpl(ClientInterceptor clientInterceptor) throws InterruptedException, TimeoutException; + + @Override + public final void beforeStart(ClientCallStreamObserver requestStream) { + this.observer = requestStream; + } + + @Override + public final void onNext(RespT value) { + // ignore + } + + @Override + public final void onError(Throwable t) { + this.t = t; + onDone.countDown(); + } + + @Override + public final void onCompleted() { + onCompleted = true; + onDone.countDown(); + } + + final void awaitOnDone(Duration duration) throws InterruptedException, TimeoutException { + if (!onDone.await(duration.toNanos(), TimeUnit.NANOSECONDS)) { + throw new TimeoutException(); + } + } + + final void assertCompleted() { + assertThat(onCompleted).isTrue(); + assertThat(t).isNull(); + } + + final void assertError(Class exceptionType, String message) { + assertThat(onCompleted).isFalse(); + assertThat(t).isNotNull(); + assertThat(t).isInstanceOf(exceptionType); + assertThat(t).hasMessage(message); + } + } +} diff --git a/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java b/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java index bca764b2d6d..35229d62889 100644 --- a/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java +++ b/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java @@ -6,6 +6,7 @@ import io.deephaven.UncheckedDeephavenException; import io.deephaven.auth.AuthenticationException; import io.deephaven.proto.DeephavenChannel; +import io.deephaven.proto.backplane.grpc.CloseSessionResponse; import io.deephaven.proto.backplane.grpc.HandshakeRequest; import io.deephaven.proto.backplane.grpc.HandshakeResponse; import io.deephaven.server.session.SessionState; @@ -49,4 +50,8 @@ public SessionState authenticatedSessionState() { public DeephavenChannel channel() { return channel; } + + public CloseSessionResponse closeSession() { + return channel.sessionBlocking().closeSession(HandshakeRequest.getDefaultInstance()); + } } diff --git a/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerTestBase.java b/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerTestBase.java index add05a92452..bc5af70f588 100644 --- a/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerTestBase.java +++ b/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerTestBase.java @@ -53,6 +53,7 @@ public abstract class DeephavenApiServerTestBase { LogModule.class, NoConsoleSessionModule.class, ServerBuilderInProcessModule.class, + RpcServerStateInterceptor.Module.class, ExecutionContextUnitTestModule.class, ClientDefaultsModule.class, ObfuscatingErrorTransformerModule.class, @@ -105,6 +106,9 @@ interface Builder { @Inject Provider> managedChannelBuilderProvider; + @Inject + RpcServerStateInterceptor serverStateInterceptor; + @Before public void setUp() throws Exception { logBuffer = new LogBuffer(128); @@ -179,6 +183,10 @@ public ExecutionContext getExecutionContext() { return executionContext; } + public RpcServerStateInterceptor serverStateInterceptor() { + return serverStateInterceptor; + } + /** * The session token expiration * diff --git a/server/test-utils/src/main/java/io/deephaven/server/runner/RpcServerStateInterceptor.java b/server/test-utils/src/main/java/io/deephaven/server/runner/RpcServerStateInterceptor.java new file mode 100644 index 00000000000..749cd2eb9c8 --- /dev/null +++ b/server/test-utils/src/main/java/io/deephaven/server/runner/RpcServerStateInterceptor.java @@ -0,0 +1,162 @@ +// +// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +// +package io.deephaven.server.runner; + +import dagger.Binds; +import dagger.multibindings.IntoSet; +import io.grpc.ClientInterceptor; +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.Metadata.Key; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.stub.AbstractStub; +import io.grpc.stub.MetadataUtils; + +import javax.inject.Inject; +import javax.inject.Singleton; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +/** + * This interceptor allows testing clients to hook into server-side RPC state. + */ +@Singleton +public final class RpcServerStateInterceptor implements ServerInterceptor { + + @dagger.Module + interface Module { + @Binds + @IntoSet + ServerInterceptor bindsInterceptor(RpcServerStateInterceptor interceptor); + } + + private static final Key KEY = + Key.of(RpcServerStateInterceptor.class.getSimpleName(), Metadata.ASCII_STRING_MARSHALLER); + + private final Map map; + + @Inject + public RpcServerStateInterceptor() { + map = new ConcurrentHashMap<>(); + } + + /** + * Creates a new {@link RpcServerState}. + */ + public RpcServerState newRpcServerState() { + final String id = UUID.randomUUID().toString(); + final RpcServerState rpcContext = new RpcServerState(id); + map.put(id, rpcContext); + return rpcContext; + } + + @Override + public Listener interceptCall(ServerCall call, Metadata headers, + ServerCallHandler next) { + final String id = headers.get(KEY); + if (id == null) { + // No RpcServerState requested, bypass. + return next.startCall(call, headers); + } + final RpcServerState state = map.remove(id); + if (state == null) { + throw new IllegalStateException(String.format( + "Re-use error for id='%s'. The test is probably re-using RpcServerState#clientInterceptor for multiple RPCs which is not allowed.", + id)); + } + return state.intercept(call, headers, next); + } + + public static final class RpcServerState { + private final CountDownLatch startCall; + private final CountDownLatch onHalfClosed; + private final AtomicReference clientInterceptor; + + private MethodDescriptor methodDescriptor; + + RpcServerState(String id) { + this.startCall = new CountDownLatch(1); + this.onHalfClosed = new CountDownLatch(1); + final Metadata metadata = new Metadata(); + metadata.put(KEY, id); + this.clientInterceptor = new AtomicReference<>(MetadataUtils.newAttachHeadersInterceptor(metadata)); + } + + /** + * The necessary, additional logic to pass along {@code this} state to the server. Callers should use this + * method exactly once in combination with a single RPC via + * {@link AbstractStub#withInterceptors(ClientInterceptor...)}. + */ + public ClientInterceptor clientInterceptor() { + final ClientInterceptor clientInterceptor = this.clientInterceptor.getAndSet(null); + if (clientInterceptor == null) { + throw new IllegalStateException("Tests should call clientInterceptor at most once"); + } + return clientInterceptor; + } + + /** + * Waits for the initial server-side invoke to finish. + */ + public void awaitServerInvokeFinished(Duration timeout) throws InterruptedException, TimeoutException { + if (clientInterceptor.get() != null) { + throw new IllegalStateException("Tests should call clientInterceptor() before waiting"); + } + if (!startCall.await(timeout.toNanos(), TimeUnit.NANOSECONDS)) { + throw new TimeoutException(); + } + // We could be more a bit more efficient here and have the testing client pass in the MethodDescriptor, but + // that would increase the complexity for the testing client. + if (methodDescriptor.getType().clientSendsOneMessage()) { + // In the case where we know the client only sends one message, we're going to wait for the server to + // finish the client half-close handling. This matches the GRPC implementation in + // io.grpc.stub.ServerCalls.UnaryServerCallHandler.UnaryServerCallListener. Even if the GRPC impl + // becomes more aggressive in the future (ie, actually invokes the server during the onMessage), we can + // still be safe here with this more conservative approach. + if (!onHalfClosed.await(timeout.toNanos(), TimeUnit.NANOSECONDS)) { + throw new TimeoutException(); + } + } else { + // In the case where the client is streaming (either one way client streaming or bidir), the server gets + // invoked during the startCall; in which case, we've already waited for the startCall and the server + // implementation has already been invoked. + } + } + + Listener intercept(ServerCall call, Metadata headers, + ServerCallHandler next) { + final Context context = Context.current(); + final Listener listener = Contexts.interceptCall(context, call, headers, next); + this.methodDescriptor = call.getMethodDescriptor(); + // We may find unit-testing use-cases where we'd like to make further information available to the testing + // client, in which case we might end up saving call, headers, context, or listener. + // this.call = call; + // this.headers = headers; + // this.context = context; + // this.listener = listener; + // startCall happens in interceptCall + startCall.countDown(); + return new SimpleForwardingServerCallListener<>(listener) { + + @Override + public void onHalfClose() { + super.onHalfClose(); + onHalfClosed.countDown(); + } + }; + } + } +}