diff --git a/Base/src/main/java/io/deephaven/base/FileUtils.java b/Base/src/main/java/io/deephaven/base/FileUtils.java index 818b652862a..721ae77b21b 100644 --- a/Base/src/main/java/io/deephaven/base/FileUtils.java +++ b/Base/src/main/java/io/deephaven/base/FileUtils.java @@ -39,6 +39,8 @@ public boolean accept(File dir, String name) { public static final Pattern REPEATED_URI_SEPARATOR_PATTERN = Pattern.compile("//+"); + public static final String FILE_URI_SCHEME = "file"; + /** * Cleans the specified path. All files and subdirectories in the path will be deleted. (ie you'll be left with an * empty directory). @@ -282,21 +284,36 @@ public static URI convertToURI(final String source, final boolean isDirectory) { URI uri; try { uri = new URI(source); + if (uri.getScheme() == null) { + // Convert to a "file" URI + return convertToURI(new File(source), isDirectory); + } + if (uri.getScheme().equals(FILE_URI_SCHEME)) { + return convertToURI(new File(uri), isDirectory); + } + String path = uri.getPath(); + final boolean endsWithSlash = path.charAt(path.length() - 1) == URI_SEPARATOR_CHAR; + if (!isDirectory && endsWithSlash) { + throw new IllegalArgumentException("Non-directory URI should not end with a slash: " + uri); + } + boolean isUpdated = false; + if (isDirectory && !endsWithSlash) { + path = path + URI_SEPARATOR_CHAR; + isUpdated = true; + } // Replace two or more consecutive slashes in the path with a single slash - final String path = uri.getPath(); if (path.contains(REPEATED_URI_SEPARATOR)) { - final String canonicalizedPath = REPEATED_URI_SEPARATOR_PATTERN.matcher(path).replaceAll(URI_SEPARATOR); - uri = new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(), canonicalizedPath, - uri.getQuery(), uri.getFragment()); + path = REPEATED_URI_SEPARATOR_PATTERN.matcher(path).replaceAll(URI_SEPARATOR); + isUpdated = true; + } + if (isUpdated) { + uri = new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(), path, uri.getQuery(), + uri.getFragment()); } } catch (final URISyntaxException e) { // If the URI is invalid, assume it's a file path return convertToURI(new File(source), isDirectory); } - if (uri.getScheme() == null) { - // Convert to a "file" URI - return convertToURI(new File(source), isDirectory); - } return uri; } @@ -314,17 +331,11 @@ public static URI convertToURI(final File file, final boolean isDirectory) { if (File.separatorChar != URI_SEPARATOR_CHAR) { absPath = absPath.replace(File.separatorChar, URI_SEPARATOR_CHAR); } - if (absPath.charAt(0) != URI_SEPARATOR_CHAR) { - absPath = URI_SEPARATOR_CHAR + absPath; - } if (isDirectory && absPath.charAt(absPath.length() - 1) != URI_SEPARATOR_CHAR) { absPath = absPath + URI_SEPARATOR_CHAR; } - if (absPath.startsWith(REPEATED_URI_SEPARATOR)) { - absPath = REPEATED_URI_SEPARATOR + absPath; - } try { - return new URI("file", null, absPath, null); + return new URI(FILE_URI_SCHEME, null, absPath, null); } catch (final URISyntaxException e) { throw new IllegalStateException("Failed to convert file to URI: " + file, e); } diff --git a/Base/src/test/java/io/deephaven/base/FileUtilsTest.java b/Base/src/test/java/io/deephaven/base/FileUtilsTest.java new file mode 100644 index 00000000000..fc00e5802cb --- /dev/null +++ b/Base/src/test/java/io/deephaven/base/FileUtilsTest.java @@ -0,0 +1,68 @@ +// +// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +// +package io.deephaven.base; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; + +import junit.framework.TestCase; +import org.junit.Assert; + +public class FileUtilsTest extends TestCase { + + public void testConvertToFileURI() throws IOException { + final File currentDir = new File("").getAbsoluteFile(); + fileUriTestHelper(currentDir.toString(), true, currentDir.toURI().toString()); + + final File someFile = new File(currentDir, "tempFile"); + fileUriTestHelper(someFile.getPath(), false, someFile.toURI().toString()); + + // Check if trailing slash gets added for a directory + final String expectedDirURI = "file:" + currentDir.getPath() + "/path/to/directory/"; + fileUriTestHelper(currentDir.getPath() + "/path/to/directory", true, expectedDirURI); + + // Check if multiple slashes get normalized + fileUriTestHelper(currentDir.getPath() + "////path///to////directory////", true, expectedDirURI); + + // Check if multiple slashes in the beginning get normalized + fileUriTestHelper("////" + currentDir.getPath() + "/path/to/directory", true, expectedDirURI); + + // Check for bad inputs for files with trailing slashes + final String expectedFileURI = someFile.toURI().toString(); + fileUriTestHelper(someFile.getPath() + "/", false, expectedFileURI); + Assert.assertEquals(expectedFileURI, + FileUtils.convertToURI("file:" + someFile.getPath() + "/", false).toString()); + } + + private static void fileUriTestHelper(final String filePath, final boolean isDirectory, + final String expectedURIString) { + Assert.assertEquals(expectedURIString, FileUtils.convertToURI(filePath, isDirectory).toString()); + Assert.assertEquals(expectedURIString, FileUtils.convertToURI(new File(filePath), isDirectory).toString()); + Assert.assertEquals(expectedURIString, FileUtils.convertToURI(Path.of(filePath), isDirectory).toString()); + } + + public void testConvertToS3URI() throws URISyntaxException { + Assert.assertEquals("s3://bucket/key", FileUtils.convertToURI("s3://bucket/key", false).toString()); + + // Check if trailing slash gets added for a directory + Assert.assertEquals("s3://bucket/key/".toString(), FileUtils.convertToURI("s3://bucket/key", true).toString()); + + // Check if multiple slashes get normalized + Assert.assertEquals("s3://bucket/key/", FileUtils.convertToURI("s3://bucket///key///", true).toString()); + + try { + FileUtils.convertToURI("", false); + Assert.fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + } + + try { + FileUtils.convertToURI("s3://bucket/key/", false); + Assert.fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + } + } +} diff --git a/ClientSupport/src/main/java/io/deephaven/clientsupport/plotdownsampling/BucketState.java b/ClientSupport/src/main/java/io/deephaven/clientsupport/plotdownsampling/BucketState.java index f29276dca14..420e859215b 100644 --- a/ClientSupport/src/main/java/io/deephaven/clientsupport/plotdownsampling/BucketState.java +++ b/ClientSupport/src/main/java/io/deephaven/clientsupport/plotdownsampling/BucketState.java @@ -9,6 +9,8 @@ import io.deephaven.engine.rowset.RowSetFactory; import io.deephaven.engine.rowset.impl.RowSetUtils; import io.deephaven.engine.rowset.chunkattributes.OrderedRowKeys; +import io.deephaven.internal.log.LoggerFactory; +import io.deephaven.io.logger.Logger; import io.deephaven.util.QueryConstants; import io.deephaven.chunk.Chunk; import io.deephaven.chunk.LongChunk; @@ -26,6 +28,8 @@ * its own offset in those arrays. */ public class BucketState { + private static final Logger log = LoggerFactory.getLogger(BucketState.class); + private final WritableRowSet rowSet = RowSetFactory.empty(); private RowSet cachedRowSet; @@ -310,10 +314,11 @@ public void validate(final boolean usePrev, final DownsampleChunkContext context values[columnIndex].validate(offset, keyChunk.get(indexInChunk), valueChunks[columnIndex], indexInChunk, trackNulls ? nulls[columnIndex] : null); } catch (final RuntimeException e) { - System.out.println(rowSet); final String msg = "Bad data! indexInChunk=" + indexInChunk + ", col=" + columnIndex + ", usePrev=" - + usePrev + ", offset=" + offset + ", rowSet=" + keyChunk.get(indexInChunk); + + usePrev + ", offset=" + offset + ", indexInChunk=" + + keyChunk.get(indexInChunk); + log.error().append(msg).append(", rowSet=").append(rowSet).endl(); throw new IllegalStateException(msg, e); } } @@ -321,11 +326,4 @@ public void validate(final boolean usePrev, final DownsampleChunkContext context } Assert.eqTrue(makeRowSet().subsetOf(rowSet), "makeRowSet().subsetOf(rowSet)"); } - - public void close() { - if (cachedRowSet != null) { - cachedRowSet.close(); - } - rowSet.close(); - } } diff --git a/ClientSupport/src/main/java/io/deephaven/clientsupport/plotdownsampling/RunChartDownsample.java b/ClientSupport/src/main/java/io/deephaven/clientsupport/plotdownsampling/RunChartDownsample.java index c0fadd7f2c4..3636333d8e7 100644 --- a/ClientSupport/src/main/java/io/deephaven/clientsupport/plotdownsampling/RunChartDownsample.java +++ b/ClientSupport/src/main/java/io/deephaven/clientsupport/plotdownsampling/RunChartDownsample.java @@ -317,12 +317,6 @@ private DownsamplerListener( allYColumnIndexes = IntStream.range(0, key.yColumnNames.length).toArray(); } - @Override - protected void destroy() { - super.destroy(); - states.values().forEach(BucketState::close); - } - @Override public void onUpdate(final TableUpdate upstream) { try (final DownsampleChunkContext context = @@ -684,7 +678,6 @@ private void performRescans(final DownsampleChunkContext context) { // if it has no keys at all, remove it so we quit checking it iterator.remove(); releasePosition(bucket.getOffset()); - bucket.close(); } else { bucket.rescanIfNeeded(context); } diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/preview/DisplayWrapper.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/preview/DisplayWrapper.java index 913471b2f13..00b91d700ad 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/preview/DisplayWrapper.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/preview/DisplayWrapper.java @@ -50,4 +50,9 @@ public boolean equals(Object obj) { return false; } + + @Override + public int hashCode() { + return displayString.hashCode(); + } } diff --git a/extensions/parquet/base/src/main/java/io/deephaven/parquet/base/ParquetFileWriter.java b/extensions/parquet/base/src/main/java/io/deephaven/parquet/base/ParquetFileWriter.java index f050c119dce..81dc13a4430 100644 --- a/extensions/parquet/base/src/main/java/io/deephaven/parquet/base/ParquetFileWriter.java +++ b/extensions/parquet/base/src/main/java/io/deephaven/parquet/base/ParquetFileWriter.java @@ -18,7 +18,6 @@ import org.apache.parquet.schema.MessageType; import org.jetbrains.annotations.NotNull; -import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; diff --git a/extensions/trackedfile/src/main/java/io/deephaven/extensions/trackedfile/TrackedSeekableChannelsProvider.java b/extensions/trackedfile/src/main/java/io/deephaven/extensions/trackedfile/TrackedSeekableChannelsProvider.java index 38c17439e6b..4aec474721d 100644 --- a/extensions/trackedfile/src/main/java/io/deephaven/extensions/trackedfile/TrackedSeekableChannelsProvider.java +++ b/extensions/trackedfile/src/main/java/io/deephaven/extensions/trackedfile/TrackedSeekableChannelsProvider.java @@ -27,7 +27,7 @@ import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.stream.Stream; -import static io.deephaven.extensions.trackedfile.TrackedSeekableChannelsProviderPlugin.FILE_URI_SCHEME; +import static io.deephaven.base.FileUtils.FILE_URI_SCHEME; /** * {@link SeekableChannelsProvider} implementation that is constrained by a Deephaven {@link TrackedFileHandleFactory}. diff --git a/extensions/trackedfile/src/main/java/io/deephaven/extensions/trackedfile/TrackedSeekableChannelsProviderPlugin.java b/extensions/trackedfile/src/main/java/io/deephaven/extensions/trackedfile/TrackedSeekableChannelsProviderPlugin.java index f8098f4c717..32ec7a3564a 100644 --- a/extensions/trackedfile/src/main/java/io/deephaven/extensions/trackedfile/TrackedSeekableChannelsProviderPlugin.java +++ b/extensions/trackedfile/src/main/java/io/deephaven/extensions/trackedfile/TrackedSeekableChannelsProviderPlugin.java @@ -12,14 +12,14 @@ import java.net.URI; +import static io.deephaven.base.FileUtils.FILE_URI_SCHEME; + /** * {@link SeekableChannelsProviderPlugin} implementation used for reading files from local disk. */ @AutoService(SeekableChannelsProviderPlugin.class) public final class TrackedSeekableChannelsProviderPlugin implements SeekableChannelsProviderPlugin { - static final String FILE_URI_SCHEME = "file"; - @Override public boolean isCompatible(@NotNull final URI uri, @Nullable final Object object) { return FILE_URI_SCHEME.equals(uri.getScheme()); diff --git a/server/src/main/java/io/deephaven/server/console/ConsoleServiceGrpcImpl.java b/server/src/main/java/io/deephaven/server/console/ConsoleServiceGrpcImpl.java index 25578fe057c..0c8709e0487 100644 --- a/server/src/main/java/io/deephaven/server/console/ConsoleServiceGrpcImpl.java +++ b/server/src/main/java/io/deephaven/server/console/ConsoleServiceGrpcImpl.java @@ -157,6 +157,8 @@ public void subscribeToLogs( GrpcUtil.safelyError(responseObserver, Code.FAILED_PRECONDITION, "Remote console disabled"); return; } + // Session close logic implicitly handled in + // io.deephaven.server.session.SessionServiceGrpcImpl.SessionServiceInterceptor final LogsClient client = new LogsClient(request, (ServerCallStreamObserver) responseObserver); client.start(); diff --git a/server/src/main/java/io/deephaven/server/object/ObjectServiceGrpcImpl.java b/server/src/main/java/io/deephaven/server/object/ObjectServiceGrpcImpl.java index 052de6b3475..f2901dee6ff 100644 --- a/server/src/main/java/io/deephaven/server/object/ObjectServiceGrpcImpl.java +++ b/server/src/main/java/io/deephaven/server/object/ObjectServiceGrpcImpl.java @@ -323,6 +323,8 @@ public void onCompleted() { @Override public StreamObserver messageStream(StreamObserver responseObserver) { SessionState session = sessionService.getCurrentSession(); + // Session close logic implicitly handled in + // io.deephaven.server.session.SessionServiceGrpcImpl.SessionServiceInterceptor return new SendMessageObserver(session, responseObserver); } diff --git a/server/src/main/java/io/deephaven/server/session/SessionServiceGrpcImpl.java b/server/src/main/java/io/deephaven/server/session/SessionServiceGrpcImpl.java index 2e64738457d..f13bb33e55d 100644 --- a/server/src/main/java/io/deephaven/server/session/SessionServiceGrpcImpl.java +++ b/server/src/main/java/io/deephaven/server/session/SessionServiceGrpcImpl.java @@ -15,9 +15,9 @@ import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; import io.deephaven.proto.backplane.grpc.*; +import io.deephaven.proto.backplane.script.grpc.ConsoleServiceGrpc; import io.deephaven.proto.util.Exceptions; import io.deephaven.util.SafeCloseable; -import io.deephaven.util.function.ThrowingRunnable; import io.grpc.Context; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.ForwardingServerCallListener; @@ -36,9 +36,11 @@ import javax.inject.Inject; import javax.inject.Singleton; +import java.io.Closeable; import java.lang.Object; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; import java.util.UUID; public class SessionServiceGrpcImpl extends SessionServiceGrpc.SessionServiceImplBase { @@ -310,10 +312,19 @@ private void addHeaders(final Metadata md) { @Singleton public static class SessionServiceInterceptor implements ServerInterceptor { + private static final Status AUTHENTICATION_DETAILS_INVALID = + Status.UNAUTHENTICATED.withDescription("Authentication details invalid"); + + // We can't use just io.grpc.MethodDescriptor (unless we chose provide and inject the named method descriptors), + // some of our methods are overridden from stock gRPC; for example, + // io.deephaven.server.object.ObjectServiceGrpcBinding.bindService. + // The goal should be to migrate all of the existing RPC Session close management logic to here if possible. + private static final Set CANCEL_RPC_ON_SESSION_CLOSE = Set.of( + ConsoleServiceGrpc.getSubscribeToLogsMethod().getFullMethodName(), + ObjectServiceGrpc.getMessageStreamMethod().getFullMethodName()); + private final SessionService service; private final SessionService.ErrorTransformer errorTransformer; - private static final Status authenticationDetailsInvalid = - Status.UNAUTHENTICATED.withDescription("Authentication details invalid"); @Inject public SessionServiceInterceptor( @@ -344,12 +355,8 @@ public ServerCall.Listener interceptCall(final ServerCall() {}; } } @@ -363,33 +370,61 @@ public ServerCall.Listener interceptCall(final ServerCall> listener = new MutableObject<>(); rpcWrapper(serverCall, context, finalSession, errorTransformer, () -> listener.setValue( - new SessionServiceCallListener<>(serverCallHandler.startCall(serverCall, metadata), serverCall, - context, finalSession, errorTransformer))); + listener(serverCall, metadata, serverCallHandler, context, finalSession))); if (listener.getValue() == null) { return new ServerCall.Listener<>() {}; } return listener.getValue(); } + + private @NotNull SessionServiceCallListener listener( + InterceptedCall serverCall, + Metadata metadata, + ServerCallHandler serverCallHandler, + Context context, + SessionState session) { + return new SessionServiceCallListener<>( + serverCallHandler.startCall(serverCall, metadata), + serverCall, + context, + session, + errorTransformer, + CANCEL_RPC_ON_SESSION_CLOSE.contains(serverCall.getMethodDescriptor().getFullMethodName())); + } } private static class SessionServiceCallListener extends - ForwardingServerCallListener.SimpleForwardingServerCallListener { + ForwardingServerCallListener.SimpleForwardingServerCallListener implements Closeable { + private static final Status SESSION_CLOSED = Status.CANCELLED.withDescription("Session closed"); + private final ServerCall call; private final Context context; private final SessionState session; private final SessionService.ErrorTransformer errorTransformer; + private final boolean autoCancelOnSessionClose; - public SessionServiceCallListener( + SessionServiceCallListener( ServerCall.Listener delegate, ServerCall call, Context context, SessionState session, - SessionService.ErrorTransformer errorTransformer) { + SessionService.ErrorTransformer errorTransformer, + boolean autoCancelOnSessionClose) { super(delegate); this.call = call; this.context = context; this.session = session; this.errorTransformer = errorTransformer; + this.autoCancelOnSessionClose = autoCancelOnSessionClose; + if (autoCancelOnSessionClose && session != null) { + session.addOnCloseCallback(this); + } + } + + @Override + public void close() { + // session.addOnCloseCallback + safeClose(call, SESSION_CLOSED, new Metadata(), false); } @Override @@ -405,11 +440,17 @@ public void onHalfClose() { @Override public void onCancel() { rpcWrapper(call, context, session, errorTransformer, super::onCancel); + if (autoCancelOnSessionClose && session != null) { + session.removeOnCloseCallback(this); + } } @Override public void onComplete() { rpcWrapper(call, context, session, errorTransformer, super::onComplete); + if (autoCancelOnSessionClose && session != null) { + session.removeOnCloseCallback(this); + } } @Override @@ -432,34 +473,44 @@ private static void rpcWrapper( @NotNull final Context context, @Nullable final SessionState session, @NotNull final SessionService.ErrorTransformer errorTransformer, - @NotNull final ThrowingRunnable lambda) { + @NotNull final Runnable lambda) { Context previous = context.attach(); // note: we'll open the execution context here so that it may be used by the error transformer try (final SafeCloseable ignored1 = session == null ? null : session.getExecutionContext().open()) { try (final SafeCloseable ignored2 = LivenessScopeStack.open()) { lambda.run(); - } catch (final InterruptedException err) { - Thread.currentThread().interrupt(); - closeWithError(call, errorTransformer.transform(err)); - } catch (final Throwable err) { - closeWithError(call, errorTransformer.transform(err)); + } catch (final RuntimeException err) { + safeClose(call, errorTransformer.transform(err)); + } catch (final Error error) { + // Indicates a very serious failure; debateable whether we should even try to send close. + safeClose(call, Status.INTERNAL, new Metadata(), false); + throw error; } finally { context.detach(previous); } } } - private static void closeWithError( - @NotNull final ServerCall call, + private static void safeClose( + @NotNull final ServerCall call, @NotNull final StatusRuntimeException err) { + Metadata metadata = Status.trailersFromThrowable(err); + if (metadata == null) { + metadata = new Metadata(); + } + safeClose(call, Status.fromThrowable(err), metadata, true); + } + + private static void safeClose(ServerCall call, Status status, Metadata trailers, boolean logOnError) { try { - Metadata metadata = Status.trailersFromThrowable(err); - if (metadata == null) { - metadata = new Metadata(); + call.close(status, trailers); + } catch (IllegalStateException e) { + // IllegalStateException is explicitly documented as thrown if the call is already closed. It might be nice + // if there was a more explicit exception type, but this should suffice. We _could_ try and check the text + // "call already closed", but that is an undocumented implementation detail we should probably not rely on. + if (logOnError && log.isDebugEnabled()) { + log.debug().append("call.close error: ").append(e).endl(); } - call.close(Status.fromThrowable(err), metadata); - } catch (final Exception unexpectedErr) { - log.debug().append("Unanticipated gRPC Error: ").append(unexpectedErr).endl(); } } } diff --git a/server/src/main/java/io/deephaven/server/session/SessionState.java b/server/src/main/java/io/deephaven/server/session/SessionState.java index dbc06a60435..74b5aafc304 100644 --- a/server/src/main/java/io/deephaven/server/session/SessionState.java +++ b/server/src/main/java/io/deephaven/server/session/SessionState.java @@ -1141,6 +1141,7 @@ protected synchronized void destroy() { if (!(caughtException instanceof StatusRuntimeException)) { caughtException = null; } + queryPerformanceRecorder = null; } /** diff --git a/server/src/test/java/io/deephaven/server/session/SessionServiceCloseTest.java b/server/src/test/java/io/deephaven/server/session/SessionServiceCloseTest.java new file mode 100644 index 00000000000..a910b73a147 --- /dev/null +++ b/server/src/test/java/io/deephaven/server/session/SessionServiceCloseTest.java @@ -0,0 +1,219 @@ +// +// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +// +package io.deephaven.server.session; + +import io.deephaven.proto.backplane.grpc.ExportNotification; +import io.deephaven.proto.backplane.grpc.ExportNotificationRequest; +import io.deephaven.proto.backplane.grpc.ExportedTableUpdateMessage; +import io.deephaven.proto.backplane.grpc.ExportedTableUpdatesRequest; +import io.deephaven.proto.backplane.grpc.FieldsChangeUpdate; +import io.deephaven.proto.backplane.grpc.ListFieldsRequest; +import io.deephaven.proto.backplane.grpc.StreamRequest; +import io.deephaven.proto.backplane.grpc.StreamResponse; +import io.deephaven.proto.backplane.grpc.TerminationNotificationRequest; +import io.deephaven.proto.backplane.grpc.TerminationNotificationResponse; +import io.deephaven.proto.backplane.script.grpc.AutoCompleteRequest; +import io.deephaven.proto.backplane.script.grpc.AutoCompleteResponse; +import io.deephaven.proto.backplane.script.grpc.LogSubscriptionData; +import io.deephaven.proto.backplane.script.grpc.LogSubscriptionRequest; +import io.deephaven.server.runner.DeephavenApiServerSingleAuthenticatedBase; +import io.deephaven.server.runner.RpcServerStateInterceptor.RpcServerState; +import io.grpc.ClientInterceptor; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.ClientCallStreamObserver; +import io.grpc.stub.ClientResponseObserver; +import org.junit.Test; + +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.assertj.core.api.Assertions.assertThat; + +public class SessionServiceCloseTest extends DeephavenApiServerSingleAuthenticatedBase { + + @Test + public void listFields() throws InterruptedException, TimeoutException { + new ListFieldsObserver().doTest(StatusRuntimeException.class, "CANCELLED: subscription cancelled"); + } + + @Test + public void exportNotifications() throws InterruptedException, TimeoutException { + new ExportNotificationsObserver().doTestToCompleted(); + } + + @Test + public void exportedTableUpdates() throws InterruptedException, TimeoutException { + new ExportedTableUpdatesObserver().doTestToCompleted(); + } + + @Test + public void messageStream() throws InterruptedException, TimeoutException { + new MessageStreamObserver().doTest(StatusRuntimeException.class, "CANCELLED: Session closed"); + } + + @Test + public void subscribeToLogs() throws InterruptedException, TimeoutException { + new SubscribeToLogsObserver().doTest(StatusRuntimeException.class, "CANCELLED: Session closed"); + } + + @Test + public void autoCompleteStream() throws InterruptedException, TimeoutException { + new AutoCompleteStreamObserver().doTestToCompleted(); + } + + @Test + public void terminationNotification() throws InterruptedException, TimeoutException { + new TerminationObserver().doTest(StatusRuntimeException.class, "UNAUTHENTICATED: Session has ended"); + } + + final class ListFieldsObserver extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .application() + .withInterceptors(clientInterceptor) + .listFields(ListFieldsRequest.getDefaultInstance(), this); + } + } + + final class ExportNotificationsObserver + extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .session() + .withInterceptors(clientInterceptor) + .exportNotifications(ExportNotificationRequest.getDefaultInstance(), this); + } + } + + final class ExportedTableUpdatesObserver + extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .table() + .withInterceptors(clientInterceptor) + .exportedTableUpdates(ExportedTableUpdatesRequest.getDefaultInstance(), this); + } + } + + final class MessageStreamObserver extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .object() + .withInterceptors(clientInterceptor) + .messageStream(this); + } + } + + final class SubscribeToLogsObserver extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .console() + .withInterceptors(clientInterceptor) + .subscribeToLogs(LogSubscriptionRequest.getDefaultInstance(), this); + } + } + + final class AutoCompleteStreamObserver extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .console() + .withInterceptors(clientInterceptor) + .autoCompleteStream(this); + } + } + + final class TerminationObserver + extends CloseSessionObserverBase { + + @Override + void sendImpl(ClientInterceptor clientInterceptor) { + channel() + .session() + .withInterceptors(clientInterceptor) + .terminationNotification(TerminationNotificationRequest.getDefaultInstance(), this); + } + } + + abstract class CloseSessionObserverBase implements ClientResponseObserver { + ClientCallStreamObserver observer; + Throwable t; + boolean onCompleted; + final CountDownLatch onDone = new CountDownLatch(1); + + public void doTest(Class exceptionType, String message) + throws InterruptedException, TimeoutException { + sendImplCloseSessionAndWait(); + assertError(exceptionType, message); + } + + public void doTestToCompleted() throws InterruptedException, TimeoutException { + sendImplCloseSessionAndWait(); + assertCompleted(); + } + + private void sendImplCloseSessionAndWait() throws InterruptedException, TimeoutException { + final RpcServerState serverState = serverStateInterceptor().newRpcServerState(); + sendImpl(serverState.clientInterceptor()); + serverState.awaitServerInvokeFinished(Duration.ofSeconds(3)); + closeSession(); + awaitOnDone(Duration.ofSeconds(3)); + } + + abstract void sendImpl(ClientInterceptor clientInterceptor) throws InterruptedException, TimeoutException; + + @Override + public final void beforeStart(ClientCallStreamObserver requestStream) { + this.observer = requestStream; + } + + @Override + public final void onNext(RespT value) { + // ignore + } + + @Override + public final void onError(Throwable t) { + this.t = t; + onDone.countDown(); + } + + @Override + public final void onCompleted() { + onCompleted = true; + onDone.countDown(); + } + + final void awaitOnDone(Duration duration) throws InterruptedException, TimeoutException { + if (!onDone.await(duration.toNanos(), TimeUnit.NANOSECONDS)) { + throw new TimeoutException(); + } + } + + final void assertCompleted() { + assertThat(onCompleted).isTrue(); + assertThat(t).isNull(); + } + + final void assertError(Class exceptionType, String message) { + assertThat(onCompleted).isFalse(); + assertThat(t).isNotNull(); + assertThat(t).isInstanceOf(exceptionType); + assertThat(t).hasMessage(message); + } + } +} diff --git a/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java b/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java index bca764b2d6d..35229d62889 100644 --- a/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java +++ b/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java @@ -6,6 +6,7 @@ import io.deephaven.UncheckedDeephavenException; import io.deephaven.auth.AuthenticationException; import io.deephaven.proto.DeephavenChannel; +import io.deephaven.proto.backplane.grpc.CloseSessionResponse; import io.deephaven.proto.backplane.grpc.HandshakeRequest; import io.deephaven.proto.backplane.grpc.HandshakeResponse; import io.deephaven.server.session.SessionState; @@ -49,4 +50,8 @@ public SessionState authenticatedSessionState() { public DeephavenChannel channel() { return channel; } + + public CloseSessionResponse closeSession() { + return channel.sessionBlocking().closeSession(HandshakeRequest.getDefaultInstance()); + } } diff --git a/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerTestBase.java b/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerTestBase.java index add05a92452..bc5af70f588 100644 --- a/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerTestBase.java +++ b/server/test-utils/src/main/java/io/deephaven/server/runner/DeephavenApiServerTestBase.java @@ -53,6 +53,7 @@ public abstract class DeephavenApiServerTestBase { LogModule.class, NoConsoleSessionModule.class, ServerBuilderInProcessModule.class, + RpcServerStateInterceptor.Module.class, ExecutionContextUnitTestModule.class, ClientDefaultsModule.class, ObfuscatingErrorTransformerModule.class, @@ -105,6 +106,9 @@ interface Builder { @Inject Provider> managedChannelBuilderProvider; + @Inject + RpcServerStateInterceptor serverStateInterceptor; + @Before public void setUp() throws Exception { logBuffer = new LogBuffer(128); @@ -179,6 +183,10 @@ public ExecutionContext getExecutionContext() { return executionContext; } + public RpcServerStateInterceptor serverStateInterceptor() { + return serverStateInterceptor; + } + /** * The session token expiration * diff --git a/server/test-utils/src/main/java/io/deephaven/server/runner/RpcServerStateInterceptor.java b/server/test-utils/src/main/java/io/deephaven/server/runner/RpcServerStateInterceptor.java new file mode 100644 index 00000000000..749cd2eb9c8 --- /dev/null +++ b/server/test-utils/src/main/java/io/deephaven/server/runner/RpcServerStateInterceptor.java @@ -0,0 +1,162 @@ +// +// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +// +package io.deephaven.server.runner; + +import dagger.Binds; +import dagger.multibindings.IntoSet; +import io.grpc.ClientInterceptor; +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.Metadata.Key; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.stub.AbstractStub; +import io.grpc.stub.MetadataUtils; + +import javax.inject.Inject; +import javax.inject.Singleton; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +/** + * This interceptor allows testing clients to hook into server-side RPC state. + */ +@Singleton +public final class RpcServerStateInterceptor implements ServerInterceptor { + + @dagger.Module + interface Module { + @Binds + @IntoSet + ServerInterceptor bindsInterceptor(RpcServerStateInterceptor interceptor); + } + + private static final Key KEY = + Key.of(RpcServerStateInterceptor.class.getSimpleName(), Metadata.ASCII_STRING_MARSHALLER); + + private final Map map; + + @Inject + public RpcServerStateInterceptor() { + map = new ConcurrentHashMap<>(); + } + + /** + * Creates a new {@link RpcServerState}. + */ + public RpcServerState newRpcServerState() { + final String id = UUID.randomUUID().toString(); + final RpcServerState rpcContext = new RpcServerState(id); + map.put(id, rpcContext); + return rpcContext; + } + + @Override + public Listener interceptCall(ServerCall call, Metadata headers, + ServerCallHandler next) { + final String id = headers.get(KEY); + if (id == null) { + // No RpcServerState requested, bypass. + return next.startCall(call, headers); + } + final RpcServerState state = map.remove(id); + if (state == null) { + throw new IllegalStateException(String.format( + "Re-use error for id='%s'. The test is probably re-using RpcServerState#clientInterceptor for multiple RPCs which is not allowed.", + id)); + } + return state.intercept(call, headers, next); + } + + public static final class RpcServerState { + private final CountDownLatch startCall; + private final CountDownLatch onHalfClosed; + private final AtomicReference clientInterceptor; + + private MethodDescriptor methodDescriptor; + + RpcServerState(String id) { + this.startCall = new CountDownLatch(1); + this.onHalfClosed = new CountDownLatch(1); + final Metadata metadata = new Metadata(); + metadata.put(KEY, id); + this.clientInterceptor = new AtomicReference<>(MetadataUtils.newAttachHeadersInterceptor(metadata)); + } + + /** + * The necessary, additional logic to pass along {@code this} state to the server. Callers should use this + * method exactly once in combination with a single RPC via + * {@link AbstractStub#withInterceptors(ClientInterceptor...)}. + */ + public ClientInterceptor clientInterceptor() { + final ClientInterceptor clientInterceptor = this.clientInterceptor.getAndSet(null); + if (clientInterceptor == null) { + throw new IllegalStateException("Tests should call clientInterceptor at most once"); + } + return clientInterceptor; + } + + /** + * Waits for the initial server-side invoke to finish. + */ + public void awaitServerInvokeFinished(Duration timeout) throws InterruptedException, TimeoutException { + if (clientInterceptor.get() != null) { + throw new IllegalStateException("Tests should call clientInterceptor() before waiting"); + } + if (!startCall.await(timeout.toNanos(), TimeUnit.NANOSECONDS)) { + throw new TimeoutException(); + } + // We could be more a bit more efficient here and have the testing client pass in the MethodDescriptor, but + // that would increase the complexity for the testing client. + if (methodDescriptor.getType().clientSendsOneMessage()) { + // In the case where we know the client only sends one message, we're going to wait for the server to + // finish the client half-close handling. This matches the GRPC implementation in + // io.grpc.stub.ServerCalls.UnaryServerCallHandler.UnaryServerCallListener. Even if the GRPC impl + // becomes more aggressive in the future (ie, actually invokes the server during the onMessage), we can + // still be safe here with this more conservative approach. + if (!onHalfClosed.await(timeout.toNanos(), TimeUnit.NANOSECONDS)) { + throw new TimeoutException(); + } + } else { + // In the case where the client is streaming (either one way client streaming or bidir), the server gets + // invoked during the startCall; in which case, we've already waited for the startCall and the server + // implementation has already been invoked. + } + } + + Listener intercept(ServerCall call, Metadata headers, + ServerCallHandler next) { + final Context context = Context.current(); + final Listener listener = Contexts.interceptCall(context, call, headers, next); + this.methodDescriptor = call.getMethodDescriptor(); + // We may find unit-testing use-cases where we'd like to make further information available to the testing + // client, in which case we might end up saving call, headers, context, or listener. + // this.call = call; + // this.headers = headers; + // this.context = context; + // this.listener = listener; + // startCall happens in interceptCall + startCall.countDown(); + return new SimpleForwardingServerCallListener<>(listener) { + + @Override + public void onHalfClose() { + super.onHalfClose(); + onHalfClosed.countDown(); + } + }; + } + } +}