From e1025b8195c41fc73bfe319230022a2fd061b14e Mon Sep 17 00:00:00 2001 From: lihan Date: Thu, 20 Apr 2023 15:54:45 +0800 Subject: [PATCH] Give a chance to allow the listener to finish when the container stops. --- ...DefaultStreamMessageListenerContainer.java | 24 ++++++++++++-- .../StreamMessageListenerContainer.java | 31 +++++++++++++++++-- ...sageListenerContainerIntegrationTests.java | 22 ++++++++++++- 3 files changed, 71 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/springframework/data/redis/stream/DefaultStreamMessageListenerContainer.java b/src/main/java/org/springframework/data/redis/stream/DefaultStreamMessageListenerContainer.java index 3520367261..b63e14156f 100644 --- a/src/main/java/org/springframework/data/redis/stream/DefaultStreamMessageListenerContainer.java +++ b/src/main/java/org/springframework/data/redis/stream/DefaultStreamMessageListenerContainer.java @@ -19,6 +19,10 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.Executor; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; import java.util.function.Function; import org.apache.commons.logging.Log; @@ -51,6 +55,7 @@ * * @author Mark Paluch * @author Christoph Strobl + * @author Han Li * @since 2.2 */ class DefaultStreamMessageListenerContainer> implements StreamMessageListenerContainer { @@ -160,9 +165,22 @@ public void stop() { synchronized (lifecycleMonitor) { if (this.running) { - - subscriptions.forEach(Cancelable::cancel); - + subscriptions.stream() + .map(subscription -> CompletableFuture.runAsync(() -> { + subscription.cancel(); + while (subscription.isActive()) { + // NO-OP + } + }, taskExecutor)) + .forEach(f -> { + try { + f.get(this.containerOptions.getShutdownTimeout().toNanos(), TimeUnit.NANOSECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (ExecutionException | TimeoutException e) { + // ignore + } + }); running = false; } } diff --git a/src/main/java/org/springframework/data/redis/stream/StreamMessageListenerContainer.java b/src/main/java/org/springframework/data/redis/stream/StreamMessageListenerContainer.java index 14a66cf5e3..a743da89c6 100644 --- a/src/main/java/org/springframework/data/redis/stream/StreamMessageListenerContainer.java +++ b/src/main/java/org/springframework/data/redis/stream/StreamMessageListenerContainer.java @@ -494,12 +494,13 @@ class StreamMessageListenerContainerOptions> { private final @Nullable HashMapper hashMapper; private final ErrorHandler errorHandler; private final Executor executor; + private final Duration shutdownTimeout; @SuppressWarnings("unchecked") private StreamMessageListenerContainerOptions(Duration pollTimeout, @Nullable Integer batchSize, RedisSerializer keySerializer, RedisSerializer hashKeySerializer, RedisSerializer hashValueSerializer, @Nullable Class targetType, - @Nullable HashMapper hashMapper, ErrorHandler errorHandler, Executor executor) { + @Nullable HashMapper hashMapper, ErrorHandler errorHandler, Executor executor, Duration shutdownTimeout) { this.pollTimeout = pollTimeout; this.batchSize = batchSize; this.keySerializer = keySerializer; @@ -509,6 +510,7 @@ private StreamMessageListenerContainerOptions(Duration pollTimeout, @Nullable In this.hashMapper = (HashMapper) hashMapper; this.errorHandler = errorHandler; this.executor = executor; + this.shutdownTimeout = shutdownTimeout; } /** @@ -589,6 +591,15 @@ public Executor getExecutor() { return executor; } + /** + * Timeout for shutdown container. + * + * @return the timeout. + */ + public Duration getShutdownTimeout() { + return shutdownTimeout; + } + } /** @@ -609,6 +620,7 @@ class StreamMessageListenerContainerOptionsBuilder> { private @Nullable Class targetType; private ErrorHandler errorHandler = LoggingErrorHandler.INSTANCE; private Executor executor = new SimpleAsyncTaskExecutor(); + private Duration shutdownTimeout = Duration.ofSeconds(1); private StreamMessageListenerContainerOptionsBuilder() {} @@ -627,6 +639,21 @@ public StreamMessageListenerContainerOptionsBuilder pollTimeout(Duration p return this; } + /** + * Configure a timeout for shutdown container. + * + * @param shutdownTimeout must not be {@literal null} or negative. + * @return {@code this} {@link StreamMessageListenerContainerOptionsBuilder}. + */ + public StreamMessageListenerContainerOptionsBuilder shutdownTimeout(Duration shutdownTimeout) { + + Assert.notNull(shutdownTimeout, "Shutdown timeout must not be null"); + Assert.isTrue(!shutdownTimeout.isNegative(), "Shutdown timeout must not be negative"); + + this.shutdownTimeout = shutdownTimeout; + return this; + } + /** * Configure a batch size for the {@code COUNT} option during reading. * @@ -777,7 +804,7 @@ public StreamMessageListenerContainerOptionsBuilder> */ public StreamMessageListenerContainerOptions build() { return new StreamMessageListenerContainerOptions<>(pollTimeout, batchSize, keySerializer, hashKeySerializer, - hashValueSerializer, targetType, hashMapper, errorHandler, executor); + hashValueSerializer, targetType, hashMapper, errorHandler, executor, shutdownTimeout); } } } diff --git a/src/test/java/org/springframework/data/redis/stream/AbstractStreamMessageListenerContainerIntegrationTests.java b/src/test/java/org/springframework/data/redis/stream/AbstractStreamMessageListenerContainerIntegrationTests.java index 6f8fcf8a6a..7ade357026 100644 --- a/src/test/java/org/springframework/data/redis/stream/AbstractStreamMessageListenerContainerIntegrationTests.java +++ b/src/test/java/org/springframework/data/redis/stream/AbstractStreamMessageListenerContainerIntegrationTests.java @@ -67,7 +67,7 @@ abstract class AbstractStreamMessageListenerContainerIntegrationTests { private final RedisConnectionFactory connectionFactory; private final StringRedisTemplate redisTemplate; private final StreamMessageListenerContainerOptions> containerOptions = StreamMessageListenerContainerOptions - .builder().pollTimeout(Duration.ofMillis(100)).build(); + .builder().pollTimeout(Duration.ofMillis(100)).shutdownTimeout(Duration.ofMillis(2000)).build(); AbstractStreamMessageListenerContainerIntegrationTests(RedisConnectionFactory connectionFactory) { this.connectionFactory = connectionFactory; @@ -383,6 +383,26 @@ void containerRestartShouldRestartSubscription() throws InterruptedException { cancelAwait(subscription); } + @Test // GH-2261 + void containerShouldStopGracefully() throws InterruptedException { + StreamMessageListenerContainer> container = StreamMessageListenerContainer + .create(connectionFactory, containerOptions); + + BlockingQueue> queue = new LinkedBlockingQueue<>(); + container.start(); + Subscription subscription = container.receive(StreamOffset.create("my-stream", ReadOffset.from("0-0")), r -> { + try { + Thread.sleep(1500); + } catch (InterruptedException e) { + // ignore + } + queue.add(r); + }); + redisTemplate.opsForStream().add("my-stream", Collections.singletonMap("key", "value1")); + subscription.await(DEFAULT_TIMEOUT); + container.stop(); + assertThat(queue.poll(500, TimeUnit.MILLISECONDS)).isNotNull(); + } private static void cancelAwait(Subscription subscription) {