From 76e7cd8744f918963bfae993ecf661ebfd6199c9 Mon Sep 17 00:00:00 2001 From: Nicholas Wiersma Date: Sun, 10 Dec 2023 06:47:42 +0200 Subject: [PATCH] feat: add poll functions --- wait/wait.go | 25 ++++++++++++++++++------- wait/wait_test.go | 47 +++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/wait/wait.go b/wait/wait.go index 8007bdb..507d31c 100644 --- a/wait/wait.go +++ b/wait/wait.go @@ -12,12 +12,23 @@ type ConditionFunc func(context.Context) (done bool, err error) // PollUntil tries a condition until stopped by the context. func PollUntil(ctx context.Context, fn ConditionFunc, interval time.Duration) error { - done, err := fn(ctx) - if err != nil { - return err - } - if done { - return nil + return poll(ctx, false, fn, interval) +} + +// PollImmediateUntil tries a condition until stopped by the context. +func PollImmediateUntil(ctx context.Context, fn ConditionFunc, interval time.Duration) error { + return poll(ctx, true, fn, interval) +} + +func poll(ctx context.Context, immediate bool, fn ConditionFunc, interval time.Duration) error { + if immediate { + done, err := fn(ctx) + if err != nil { + return err + } + if done { + return nil + } } tick := time.NewTicker(interval) @@ -28,7 +39,7 @@ func PollUntil(ctx context.Context, fn ConditionFunc, interval time.Duration) er case <-ctx.Done(): return ctx.Err() case <-tick.C: - done, err = fn(ctx) + done, err := fn(ctx) if err != nil { return err } diff --git a/wait/wait_test.go b/wait/wait_test.go index ef1d7f3..dbfa1fc 100644 --- a/wait/wait_test.go +++ b/wait/wait_test.go @@ -18,18 +18,17 @@ func TestPollUntil(t *testing.T) { go func() { defer close(called) - err := wait.PollUntil(ctx, func(context.Context) (done bool, err error) { + err := wait.PollImmediateUntil(ctx, func(context.Context) (done bool, err error) { called <- struct{}{} return false, nil - }, time.Microsecond) + }, 100*time.Microsecond) assert.ErrorIs(t, err, context.Canceled) }() - // Wait for the initial condition call, and the first tick + // Wait for the first tick // condition call. <-called - <-called // Stop waiting. cancel() @@ -75,22 +74,54 @@ func TestPollUntil_HandlesDone(t *testing.T) { require.NoError(t, err) } -func TestPollUntil_HandlesImmediateError(t *testing.T) { +func TestPollImmediateUntil(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + called := make(chan struct{}) + go func() { + defer close(called) + + err := wait.PollImmediateUntil(ctx, func(context.Context) (done bool, err error) { + called <- struct{}{} + return false, nil + }, 100*time.Microsecond) + + assert.ErrorIs(t, err, context.Canceled) + }() + + // Wait for the initial condition call, and the first tick + // condition call. + <-called + <-called + + // Stop waiting. + cancel() + + // Assert that the condition is not called more than once after + // canceling the context. + var calledCount int + for range called { + calledCount++ + } + assert.LessOrEqual(t, calledCount, 1) +} + +func TestPollImmediateUntil_HandlesImmediateError(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - err := wait.PollUntil(ctx, func(context.Context) (done bool, err error) { + err := wait.PollImmediateUntil(ctx, func(context.Context) (done bool, err error) { return false, errors.New("test") }, time.Microsecond) require.Error(t, err) } -func TestPollUntil_HandlesImmediateDone(t *testing.T) { +func TestPollImmediateUntil_HandlesImmediateDone(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - err := wait.PollUntil(ctx, func(context.Context) (done bool, err error) { + err := wait.PollImmediateUntil(ctx, func(context.Context) (done bool, err error) { return true, nil }, time.Microsecond)