diff --git a/pkg/retry/retry.go b/pkg/retry/retry.go index 171fcc87b..fd17eefc4 100644 --- a/pkg/retry/retry.go +++ b/pkg/retry/retry.go @@ -23,8 +23,10 @@ type IsRetryable func(error) bool // Settings aggregates optional settings for WithBackoff. type Settings struct { - // If >0, Timeout lets WithBackoff stop retrying once elapsed, - // but allows the RetryableFunc its execution time and **doesn't abort** it if it exceeds Timeout. + // If >0, Timeout lets WithBackoff stop retrying gracefully once elapsed based on the following criteria: + // * If the execution of RetryableFunc has taken longer than Timeout, no further attempts are made. + // * If Timeout elapses during the sleep phase between retries, one final retry is attempted. + // * RetryableFunc is always granted its full execution time and is not canceled if it exceeds Timeout. // This means that WithBackoff may not stop exactly after Timeout expires, // or may not retry at all if the first execution of RetryableFunc already takes longer than Timeout. Timeout time.Duration @@ -50,6 +52,7 @@ func WithBackoff( } start := time.Now() + timedOut := false for attempt := uint64(1); ; /* true */ attempt++ { prevErr := err @@ -80,6 +83,19 @@ func WithBackoff( return } + select { + case <-timeout: + // Stop retrying immediately if executing the retryable function took longer than the timeout. + timedOut = true + default: + } + + if timedOut { + err = errors.Wrap(err, "retry deadline exceeded") + + return + } + if settings.OnRetryableError != nil { settings.OnRetryableError(time.Since(start), attempt, err, prevErr) } @@ -87,9 +103,10 @@ func WithBackoff( select { case <-time.After(b(attempt)): case <-timeout: - err = errors.Wrap(err, "retry deadline exceeded") - - return + // Do not stop retrying immediately, but start one last attempt to mitigate timing issues where + // the timeout expires while waiting for the next attempt and + // therefore no retries have happened during this possibly long period. + timedOut = true case <-ctx.Done(): err = errors.Wrap(ctx.Err(), err.Error())