diff --git a/docs/features/wait/walk.md b/docs/features/wait/walk.md new file mode 100644 index 0000000000..d69c942631 --- /dev/null +++ b/docs/features/wait/walk.md @@ -0,0 +1,12 @@ +# Walk + +Walk walks the strategies tree and calls the visit function for each node. + +If visit function returns `wait.VisitStop`, the walk stops. +If visit function returns `wait.VisitRemove`, the current node is removed. + +## Walk removing entries + + +[Walk a strategy and remove the all FileStrategy entries found](../../../wait/walk_test.go) inside_block:walkRemoveFileStrategy + diff --git a/wait/walk.go b/wait/walk.go new file mode 100644 index 0000000000..4685e50088 --- /dev/null +++ b/wait/walk.go @@ -0,0 +1,74 @@ +package wait + +import ( + "errors" +) + +var ( + // VisitStop is used as a return value from [VisitFunc] to stop the walk. + // It is not returned as an error by any function. + VisitStop = errors.New("stop the walk") + + // VisitRemove is used as a return value from [VisitFunc] to have the current node removed. + // It is not returned as an error by any function. + VisitRemove = errors.New("remove this strategy") +) + +// VisitFunc is a function that visits a strategy node. +// If it returns [VisitStop], the walk stops. +// If it returns [VisitRemove], the current node is removed. +type VisitFunc func(root Strategy) error + +// Walk walks the strategies tree and calls the visit function for each node. +func Walk(root *Strategy, visit VisitFunc) error { + if root == nil { + return errors.New("root strategy is nil") + } + + if err := walk(root, visit); err != nil { + if errors.Is(err, VisitRemove) || errors.Is(err, VisitStop) { + return nil + } + return err + } + + return nil +} + +// walk walks the strategies tree and calls the visit function for each node. +// It returns an error if the visit function returns an error. +func walk(root *Strategy, visit VisitFunc) error { + if *root == nil { + // No strategy. + return nil + } + + // Allow the visit function to customize the behaviour of the walk before visiting the children. + if err := visit(*root); err != nil { + if errors.Is(err, VisitRemove) { + *root = nil + } + + return err + } + + if s, ok := (*root).(*MultiStrategy); ok { + var i int + for range s.Strategies { + if err := walk(&s.Strategies[i], visit); err != nil { + if errors.Is(err, VisitRemove) { + s.Strategies = append(s.Strategies[:i], s.Strategies[i+1:]...) + if errors.Is(err, VisitStop) { + return VisitStop + } + continue + } + + return err + } + i++ + } + } + + return nil +} diff --git a/wait/walk_test.go b/wait/walk_test.go new file mode 100644 index 0000000000..e8f8df2f2b --- /dev/null +++ b/wait/walk_test.go @@ -0,0 +1,127 @@ +package wait_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +func TestWalk(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForAll( + wait.ForFile("/tmp/file"), + wait.ForHTTP("/health"), + wait.ForAll( + wait.ForFile("/tmp/other"), + ), + ), + } + + t.Run("walk", func(t *testing.T) { + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return nil + }) + require.NoError(t, err) + require.Equal(t, 5, count) + }) + + t.Run("stop", func(t *testing.T) { + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return wait.VisitStop + }) + require.NoError(t, err) + require.Equal(t, 1, count) + }) + + t.Run("remove", func(t *testing.T) { + // walkRemoveFileStrategy { + var count, matched int + err := wait.Walk(&req.WaitingFor, func(s wait.Strategy) error { + count++ + if _, ok := s.(*wait.FileStrategy); ok { + matched++ + return wait.VisitRemove + } + + return nil + }) + // } + require.NoError(t, err) + require.Equal(t, 5, count) + require.Equal(t, 2, matched) + + count = 0 + matched = 0 + err = wait.Walk(&req.WaitingFor, func(s wait.Strategy) error { + count++ + if _, ok := s.(*wait.FileStrategy); ok { + matched++ + } + return nil + }) + require.NoError(t, err) + require.Equal(t, 3, count) + require.Zero(t, matched) + }) + + t.Run("remove-stop", func(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForAll( + wait.ForFile("/tmp/file"), + wait.ForHTTP("/health"), + ), + } + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return errors.Join(wait.VisitRemove, wait.VisitStop) + }) + require.NoError(t, err) + require.Equal(t, 1, count) + require.Nil(t, req.WaitingFor) + }) + + t.Run("nil-root", func(t *testing.T) { + err := wait.Walk(nil, func(_ wait.Strategy) error { + return nil + }) + require.EqualError(t, err, "root strategy is nil") + }) + + t.Run("direct-single", func(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForFile("/tmp/file"), + } + requireVisits(t, req, 1) + }) + + t.Run("for-all-single", func(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForAll( + wait.ForFile("/tmp/file"), + ), + } + requireVisits(t, req, 2) + }) +} + +// requireVisits validates the number of visits for a given request. +func requireVisits(t *testing.T, req testcontainers.ContainerRequest, expected int) { + t.Helper() + + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return nil + }) + require.NoError(t, err) + require.Equal(t, expected, count) +}