diff --git a/go.mod b/go.mod index df918aa..5f3f073 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 // indirect github.com/eapache/queue v1.1.0 // indirect github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect - github.com/magiconair/properties v1.8.0 github.com/msales/pkg/v3 v3.1.0 github.com/pierrec/lz4 v0.0.0-20181005164709-635575b42742 // indirect github.com/pkg/errors v0.8.0 diff --git a/mocks_test.go b/mocks_test.go index 0d32570..9c99ac8 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -215,3 +215,28 @@ func (s *MockSource) Close() error { args := s.Called() return args.Error(0) } + +type MockTask struct { + mock.Mock + + startCalled time.Time + onErrorCalled time.Time + closeCalled time.Time +} + +func (t *MockTask) Start() error { + t.startCalled = time.Now() + + return t.Called().Error(0) +} + +func (t *MockTask) OnError(fn streams.ErrorFunc) { + t.onErrorCalled = time.Now() + t.Called(fn) +} + +func (t *MockTask) Close() error { + t.closeCalled = time.Now() + + return t.Called().Error(0) +} diff --git a/task.go b/task.go index 1a3ed56..d360073 100644 --- a/task.go +++ b/task.go @@ -173,3 +173,59 @@ func (t *streamTask) handleError(err error) { func (t *streamTask) OnError(fn ErrorFunc) { t.errorFn = fn } + +// Tasks represents a slice of tasks. +// This is a utility type that makes it easier to work with multiple tasks. +type Tasks []Task + +// Start starts the streams processors. +func (tasks Tasks) Start() error { + err := tasks.each(func(t Task) error { + return t.Start() + }) + + return err +} + +// OnError sets the error handler. +func (tasks Tasks) OnError(fn ErrorFunc) { + _ = tasks.each(func(t Task) error { + t.OnError(fn) + return nil + }) +} + +// Close stops and closes the streams processors. +// This function operates on the tasks in the reversed order. +func (tasks Tasks) Close() error { + err := tasks.eachRev(func(t Task) error { + return t.Close() + }) + + return err + +} + +// each executes a passed function with every task in the slice. +func (tasks Tasks) each(fn func(t Task) error) error { + for _, t := range tasks { + err := fn(t) + if err != nil { + return err + } + } + + return nil +} + +// eachRev executes a passed function with every task in the slice, in the reversed order. +func (tasks Tasks) eachRev(fn func(t Task) error) error { + for i := len(tasks) - 1; i >= 0; i-- { + err := fn(tasks[i]) + if err != nil { + return err + } + } + + return nil +} diff --git a/task_test.go b/task_test.go index 5674bf8..6113b11 100644 --- a/task_test.go +++ b/task_test.go @@ -211,6 +211,90 @@ func TestStreamTask_HandleCloseWithSourceError(t *testing.T) { assert.Error(t, err) } +func TestTasks_Start(t *testing.T) { + t1, t2, t3 := new(MockTask), new(MockTask), new(MockTask) + t1.On("Start").Return(nil) + t2.On("Start").Return(nil) + t3.On("Start").Return(nil) + + tasks := streams.Tasks{t1, t2, t3} + + err := tasks.Start() + + assert.NoError(t, err) + t1.AssertExpectations(t) + t2.AssertExpectations(t) + t3.AssertExpectations(t) + assert.True(t, t1.startCalled.Before(t2.startCalled)) + assert.True(t, t2.startCalled.Before(t3.startCalled)) +} + +func TestTasks_Start_WithError(t *testing.T) { + t1, t2, t3 := new(MockTask), new(MockTask), new(MockTask) + t1.On("Start").Return(nil) + t2.On("Start").Return(errors.New("test error")) + + tasks := streams.Tasks{t1, t2, t3} + + err := tasks.Start() + + assert.Error(t, err) + t1.AssertExpectations(t) + t2.AssertExpectations(t) + t3.AssertNotCalled(t, "Start") +} + +func TestTasks_OnError(t *testing.T) { + fn := streams.ErrorFunc(func(_ error) {}) + t1, t2, t3 := new(MockTask), new(MockTask), new(MockTask) + t1.On("OnError", mock.AnythingOfType("streams.ErrorFunc")).Return() + t2.On("OnError", mock.AnythingOfType("streams.ErrorFunc")).Return() + t3.On("OnError", mock.AnythingOfType("streams.ErrorFunc")).Return() + + tasks := streams.Tasks{t1, t2, t3} + + tasks.OnError(fn) + + t1.AssertExpectations(t) + t2.AssertExpectations(t) + t3.AssertExpectations(t) + assert.True(t, t1.onErrorCalled.Before(t2.onErrorCalled)) + assert.True(t, t2.onErrorCalled.Before(t3.onErrorCalled)) +} + +func TestTasks_Close(t *testing.T) { + t1, t2, t3 := new(MockTask), new(MockTask), new(MockTask) + t1.On("Close").Return(nil) + t2.On("Close").Return(nil) + t3.On("Close").Return(nil) + + tasks := streams.Tasks{t1, t2, t3} + + err := tasks.Close() + + assert.NoError(t, err) + t1.AssertExpectations(t) + t2.AssertExpectations(t) + t3.AssertExpectations(t) + assert.True(t, t1.closeCalled.After(t2.closeCalled)) + assert.True(t, t2.closeCalled.After(t3.closeCalled)) +} + +func TestTasks_Close_WithError(t *testing.T) { + t1, t2, t3 := new(MockTask), new(MockTask), new(MockTask) + t2.On("Close").Return(errors.New("test error")) + t3.On("Close").Return(nil) + + tasks := streams.Tasks{t1, t2, t3} + + err := tasks.Close() + + assert.Error(t, err) + t1.AssertNotCalled(t, "Close") + t2.AssertExpectations(t) + t3.AssertExpectations(t) +} + type chanSource struct { msgs chan *streams.Message }