From 1401b0e3d2765d4210ac0113a26a3aa103aa59db Mon Sep 17 00:00:00 2001 From: peter-csala Date: Mon, 18 Nov 2024 12:20:04 +0100 Subject: [PATCH] Handle CancellationToken for Retry --- .../Retry/RetryResilienceStrategy.cs | 11 ++++- .../Retry/RetryResilienceStrategyTests.cs | 44 +++++++++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/src/Polly.Core/Retry/RetryResilienceStrategy.cs b/src/Polly.Core/Retry/RetryResilienceStrategy.cs index 3b378b3a5a3..2ffe83ff2fc 100644 --- a/src/Polly.Core/Retry/RetryResilienceStrategy.cs +++ b/src/Polly.Core/Retry/RetryResilienceStrategy.cs @@ -53,6 +53,15 @@ protected internal override async ValueTask> ExecuteCore(Func { var startTimestamp = _timeProvider.GetTimestamp(); var outcome = await StrategyHelper.ExecuteCallbackSafeAsync(callback, context, state).ConfigureAwait(context.ContinueOnCapturedContext); + try + { + context.CancellationToken.ThrowIfCancellationRequested(); + } + catch (OperationCanceledException e) + { + outcome = Outcome.FromException(e); + } + var shouldRetryArgs = new RetryPredicateArguments(context, outcome, attempt); var handle = await ShouldHandle(shouldRetryArgs).ConfigureAwait(context.ContinueOnCapturedContext); var executionTime = _timeProvider.GetElapsedTime(startTimestamp); @@ -67,7 +76,7 @@ protected internal override async ValueTask> ExecuteCore(Func TelemetryUtil.ReportExecutionAttempt(_telemetry, context, outcome, attempt, executionTime, handle); } - if (context.CancellationToken.IsCancellationRequested || isLastAttempt || !handle) + if (isLastAttempt || !handle) { return outcome; } diff --git a/test/Polly.Core.Tests/Retry/RetryResilienceStrategyTests.cs b/test/Polly.Core.Tests/Retry/RetryResilienceStrategyTests.cs index c389351329b..c403116de93 100644 --- a/test/Polly.Core.Tests/Retry/RetryResilienceStrategyTests.cs +++ b/test/Polly.Core.Tests/Retry/RetryResilienceStrategyTests.cs @@ -31,11 +31,13 @@ public void ExecuteAsync_EnsureResultNotDisposed() } [Fact] - public async Task ExecuteAsync_CancellationRequested_EnsureNotRetried() + public async Task ExecuteAsync_CancellationRequestedBeforeCallback_EnsureNoAttempt() { SetupNoDelay(); - var sut = CreateSut(); using var cancellationToken = new CancellationTokenSource(); + _options.ShouldHandle = _ => PredicateResult.True(); + var sut = CreateSut(); + cancellationToken.Cancel(); var context = ResilienceContextPool.Shared.Get(); context.CancellationToken = cancellationToken.Token; @@ -47,10 +49,36 @@ public async Task ExecuteAsync_CancellationRequested_EnsureNotRetried() } [Fact] - public async Task ExecuteAsync_CancellationRequestedAfterCallback_EnsureNotRetried() + public async Task ExecuteAsync_CancellationRequestedDuringCallback_EnsureNotRetried() { + SetupNoDelay(); using var cancellationToken = new CancellationTokenSource(); + _options.ShouldHandle = _ => PredicateResult.True(); + var sut = CreateSut(); + var context = ResilienceContextPool.Shared.Get(); + context.CancellationToken = cancellationToken.Token; + var executed = false; + var attemptCounter = 0; + + var result = await sut.ExecuteOutcomeAsync((_, _) => + { + executed = true; + ++attemptCounter; + cancellationToken.Cancel(); + return Outcome.FromResultAsValueTask("dummy"); + }, context, "state"); + + result.Exception.Should().BeOfType(); + executed.Should().BeTrue(); + attemptCounter.Should().Be(1); + } + + [Fact] + public async Task ExecuteAsync_CancellationRequestedAfterCallback_EnsureNotRetried() + { + SetupNoDelay(); + using var cancellationToken = new CancellationTokenSource(); _options.ShouldHandle = _ => PredicateResult.True(); _options.OnRetry = _ => { @@ -62,10 +90,18 @@ public async Task ExecuteAsync_CancellationRequestedAfterCallback_EnsureNotRetri var context = ResilienceContextPool.Shared.Get(); context.CancellationToken = cancellationToken.Token; var executed = false; + var attemptCounter = 0; + + var result = await sut.ExecuteOutcomeAsync((_, _) => + { + executed = true; + ++attemptCounter; + return Outcome.FromResultAsValueTask("dummy"); + }, context, "state"); - var result = await sut.ExecuteOutcomeAsync((_, _) => { executed = true; return Outcome.FromResultAsValueTask("dummy"); }, context, "state"); result.Exception.Should().BeOfType(); executed.Should().BeTrue(); + attemptCounter.Should().Be(1); } [Fact]