diff --git a/archive/control_test.go b/archive/control_test.go index fb95fc65..1bf995c5 100644 --- a/archive/control_test.go +++ b/archive/control_test.go @@ -21,7 +21,7 @@ import ( func TestControl_PollForResponse(t *testing.T) { t.Run("times out when nothing to poll", func(t *testing.T) { control, image := newTestControl(t) - image.On("Poll", mock.Anything, mock.Anything).Return(0) + mockPollResponses(t, image) correlations.Store(int64(1), control) id, err := control.PollForResponse(1, 0) assert.Zero(t, id) @@ -85,7 +85,7 @@ func TestControl_PollForResponse(t *testing.T) { func TestControl_PollForErrorResponse(t *testing.T) { t.Run("returns zero when nothing to poll", func(t *testing.T) { control, image := newTestControl(t) - image.On("Poll", mock.Anything, mock.Anything).Return(0) + mockPollResponses(t, image) cnt, err := control.PollForErrorResponse() assert.Zero(t, cnt) assert.NoError(t, err) @@ -135,18 +135,36 @@ func TestControl_PollForErrorResponse(t *testing.T) { func mockPollResponses(t *testing.T, image *aeron.MockImage, responses ...encodable) { poll := image.On("Poll", mock.Anything, mock.Anything) + poll.Maybe() poll.Run(func(args mock.Arguments) { + handler := args.Get(0).(term.FragmentHandler) + fragmentCount := args.Get(1).(int) + count := image.ControlledPoll(func(buffer *atomic.Buffer, offset, length int32, header *logbuffer.Header) term.ControlledPollAction { + handler(buffer, offset, length, header) + return term.ControlledPollActionContinue + }, fragmentCount) + poll.Return(count) + }) + controlledPoll := image.On("ControlledPoll", mock.Anything, mock.Anything) + controlledPoll.Run(func(args mock.Arguments) { fragmentCount := args.Get(1).(int) count := 0 - for ; count < fragmentCount; count++ { + for count < fragmentCount { if len(responses) == 0 { break } buffer := encode(t, responses[0]) + action := args.Get(0).(term.ControlledFragmentHandler)(buffer, 0, buffer.Capacity(), newTestHeader()) + if action == term.ControlledPollActionAbort { + break + } responses = responses[1:] - args.Get(0).(term.FragmentHandler)(buffer, 0, buffer.Capacity(), newTestHeader()) + count++ + if action == term.ControlledPollActionBreak { + break + } } - poll.Return(count) + controlledPoll.Return(count) }) }