diff --git a/archive/archive.go b/archive/archive.go index 9bba9161..278f3c45 100644 --- a/archive/archive.go +++ b/archive/archive.go @@ -25,6 +25,7 @@ import ( "github.com/lirm/aeron-go/aeron" "github.com/lirm/aeron-go/aeron/atomic" "github.com/lirm/aeron-go/aeron/logbuffer" + "github.com/lirm/aeron-go/aeron/logbuffer/term" "github.com/lirm/aeron-go/aeron/logging" "github.com/lirm/aeron-go/archive/codecs" ) @@ -318,8 +319,9 @@ func NewArchive(options *Options, context *aeron.Context) (*Archive, error) { for archive.Control.State.state != ControlStateConnected && archive.Control.State.err == nil { fragments := archive.Control.poll( - func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) { + func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) term.ControlledPollAction { ConnectionControlFragmentHandler(&pollContext, buf, offset, length, header) + return term.ControlledPollActionContinue }, 1) if fragments > 0 { logger.Debugf("Read %d fragment(s)", fragments) diff --git a/archive/control.go b/archive/control.go index 64f0b059..27b9380b 100644 --- a/archive/control.go +++ b/archive/control.go @@ -113,13 +113,19 @@ func init() { codecIds.recordingStopped = recordingStopped.SbeTemplateId() } -func controlFragmentHandler(context interface{}, buffer *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) { +func controlFragmentHandler(context interface{}, buffer *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) (action term.ControlledPollAction) { + action = term.ControlledPollActionContinue + pollContext, ok := context.(*PollContext) if !ok { logger.Errorf("context conversion failed") return } + if pollContext.control.Results.IsPollComplete { + return term.ControlledPollActionAbort + } + logger.Debugf("controlFragmentHandler: correlationID:%d offset:%d length:%d header:%#v", pollContext.correlationID, offset, length, header) var hdr codecs.SbeGoMessageHeader @@ -170,6 +176,8 @@ func controlFragmentHandler(context interface{}, buffer *atomic.Buffer, offset i logger.Debugf("controlFragmentHandler/controlResponse: received for sessionID:%d, correlationID:%d", controlResponse.ControlSessionId, controlResponse.CorrelationId) control.Results.ControlResponse = controlResponse control.Results.IsPollComplete = true + + return term.ControlledPollActionBreak } else { logger.Debugf("controlFragmentHandler/controlResponse ignoring sessionID:%d, correlationID:%d", controlResponse.ControlSessionId, controlResponse.CorrelationId) } @@ -198,6 +206,7 @@ func controlFragmentHandler(context interface{}, buffer *atomic.Buffer, offset i // This can happen when testing/adding new functionality fmt.Printf("controlFragmentHandler: Unexpected message type %d\n", hdr.TemplateId) } + return } // ConnectionControlFragmentHandler is the connection handling specific fragment handler. @@ -349,12 +358,14 @@ func (control *Control) PollForErrorResponse() (int, error) { context := PollContext{control, 0} received := 0 + control.Results.ErrorResponse = nil + // Poll for async events, errors etc until the queue is drained for { ret := control.poll( - func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) { - errorResponseFragmentHandler(&context, buf, offset, length, header) - }, 1) + func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) term.ControlledPollAction { + return errorResponseFragmentHandler(&context, buf, offset, length, header) + }, 10) received += ret // If we received a response with an error then return it @@ -377,13 +388,19 @@ func (control *Control) PollForErrorResponse() (int, error) { // ignore messages not on our session ID // process recordingSignalEvents // Log a warning if we have interrupted a synchronous event -func errorResponseFragmentHandler(context interface{}, buffer *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) { +func errorResponseFragmentHandler(context interface{}, buffer *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) (action term.ControlledPollAction) { + action = term.ControlledPollActionContinue + pollContext, ok := context.(*PollContext) if !ok { logger.Errorf("context conversion failed") return } + if pollContext.control.Results.ErrorResponse != nil { + return term.ControlledPollActionAbort + } + logger.Debugf("errorResponseFragmentHandler: offset:%d length: %d", offset, length) var hdr codecs.SbeGoMessageHeader @@ -419,6 +436,7 @@ func errorResponseFragmentHandler(context interface{}, buffer *atomic.Buffer, of if controlResponse.ControlSessionId == pollContext.control.archive.SessionID { if controlResponse.Code == codecs.ControlResponseCode.ERROR { pollContext.control.Results.ErrorResponse = fmt.Errorf("PollForErrorResponse received a ControlResponse (correlationId:%d Code:ERROR error=\"%s\"", controlResponse.CorrelationId, controlResponse.ErrorMessage) + return term.ControlledPollActionBreak } } return @@ -500,11 +518,13 @@ func errorResponseFragmentHandler(context interface{}, buffer *atomic.Buffer, of default: fmt.Printf("errorResponseFragmentHandler: Insert decoder for type: %d", hdr.TemplateId) } + + return } // poll provides the control response poller using local state to pass // back data from the underlying subscription -func (control *Control) poll(handler term.FragmentHandler, fragmentLimit int) int { +func (control *Control) poll(handler term.ControlledFragmentHandler, fragmentLimit int) int { // Update our globals in case they've changed so we use the current state in our callback rangeChecking = control.archive.Options.RangeChecking @@ -512,7 +532,7 @@ func (control *Control) poll(handler term.FragmentHandler, fragmentLimit int) in control.Results.ControlResponse = nil // Clear old results control.Results.IsPollComplete = false // Clear completion flag - return control.Subscription.Poll(handler, fragmentLimit) + return control.Subscription.ControlledPoll(handler, fragmentLimit) } // Poll for control response events. Returns number of fragments read during the operation. @@ -605,8 +625,8 @@ func (control *Control) PollForResponse(correlationID int64, sessionID int64) (i start := time.Now() context := PollContext{control, correlationID} - handler := aeron.NewFragmentAssembler(func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) { - controlFragmentHandler(&context, buf, offset, length, header) + handler := aeron.NewControlledFragmentAssembler(func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) term.ControlledPollAction { + return controlFragmentHandler(&context, buf, offset, length, header) }, aeron.DefaultFragmentAssemblyBufferLength) for { ret := control.poll(handler.OnFragment, 10) @@ -797,8 +817,9 @@ func (control *Control) PollForDescriptors(correlationID int64, sessionID int64, for !control.Results.IsPollComplete { logger.Debugf("PollForDescriptors(%d:%d, %d)", correlationID, sessionID, int(fragmentsWanted)-descriptorCount) fragments := control.poll( - func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) { + func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) term.ControlledPollAction { DescriptorFragmentHandler(&pollContext, buf, offset, length, header) + return term.ControlledPollActionContinue }, int(fragmentsWanted)-descriptorCount) logger.Debugf("Poll(%d:%d) returned %d fragments", correlationID, sessionID, fragments) descriptorCount = len(control.Results.RecordingDescriptors) + len(control.Results.RecordingSubscriptionDescriptors) diff --git a/archive/control_test.go b/archive/control_test.go new file mode 100644 index 00000000..24f90a4a --- /dev/null +++ b/archive/control_test.go @@ -0,0 +1,230 @@ +package archive + +import ( + "bytes" + "io" + "reflect" + "testing" + "time" + "unsafe" + + "github.com/lirm/aeron-go/aeron" + "github.com/lirm/aeron-go/aeron/atomic" + "github.com/lirm/aeron-go/aeron/idlestrategy" + "github.com/lirm/aeron-go/aeron/logbuffer" + "github.com/lirm/aeron-go/aeron/logbuffer/term" + "github.com/lirm/aeron-go/archive/codecs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestControl_PollForResponse(t *testing.T) { + t.Run("times out when nothing to poll", func(t *testing.T) { + control, image := newTestControl(t) + mockPollResponses(t, image) + correlations.Store(int64(1), control) + id, err := control.PollForResponse(1, 0) + assert.Zero(t, id) + assert.EqualError(t, err, `timeout waiting for correlationID 1`) + }) + + t.Run("returns the error response", func(t *testing.T) { + control, image := newTestControl(t) + mockPollResponses(t, image, + &codecs.ControlResponse{ + Code: codecs.ControlResponseCode.ERROR, + CorrelationId: 1, + ErrorMessage: []byte(`b0rk`), + RelevantId: 3, + }, + ) + correlations.Store(int64(1), control) + id, err := control.PollForResponse(1, 0) + assert.EqualValues(t, 3, id) + assert.EqualError(t, err, `Control Response failure: b0rk`) + }) + + t.Run("discards all responses preceding the first error", func(t *testing.T) { + control, image := newTestControl(t) + mockPollResponses(t, image, + &codecs.ControlResponse{Code: codecs.ControlResponseCode.OK}, + &codecs.ControlResponse{Code: codecs.ControlResponseCode.OK}, + &codecs.ControlResponse{ + Code: codecs.ControlResponseCode.ERROR, + CorrelationId: 1, + ErrorMessage: []byte(`b0rk`), + RelevantId: 3, + }, + ) + correlations.Store(int64(1), control) + id, err := control.PollForResponse(1, 0) + assert.EqualValues(t, 3, id) + assert.EqualError(t, err, `Control Response failure: b0rk`) + }) + + t.Run("does not process messages after result", func(t *testing.T) { + control, image := newTestControl(t) + mockPollResponses(t, image, + &codecs.ControlResponse{ + Code: codecs.ControlResponseCode.ERROR, + CorrelationId: 1, + ErrorMessage: []byte(`b0rk`), + RelevantId: 3, + }, + &codecs.ControlResponse{Code: codecs.ControlResponseCode.OK}, + ) + correlations.Store(int64(1), control) + id, err := control.PollForResponse(1, 0) + assert.EqualValues(t, 3, id) + assert.EqualError(t, err, `Control Response failure: b0rk`) + fragments := image.Poll(func(buffer *atomic.Buffer, offset, length int32, header *logbuffer.Header) {}, 1) + assert.EqualValues(t, 1, fragments) + }) +} + +func TestControl_PollForErrorResponse(t *testing.T) { + t.Run("returns zero when nothing to poll", func(t *testing.T) { + control, image := newTestControl(t) + mockPollResponses(t, image) + cnt, err := control.PollForErrorResponse() + assert.Zero(t, cnt) + assert.NoError(t, err) + }) + + t.Run("returns the error response", func(t *testing.T) { + control, image := newTestControl(t) + mockPollResponses(t, image, + &codecs.ControlResponse{ + Code: codecs.ControlResponseCode.ERROR, + ErrorMessage: []byte(`b0rk`), + }, + ) + cnt, err := control.PollForErrorResponse() + assert.EqualValues(t, 1, cnt) + assert.EqualError(t, err, `PollForErrorResponse received a ControlResponse (correlationId:0 Code:ERROR error="b0rk"`) + }) + + t.Run("discards all queued responses unless error", func(t *testing.T) { + control, image := newTestControl(t) + mockPollResponses(t, image, + &codecs.ControlResponse{Code: codecs.ControlResponseCode.OK}, + &codecs.ControlResponse{Code: codecs.ControlResponseCode.OK}, + ) + cnt, err := control.PollForErrorResponse() + assert.EqualValues(t, 2, cnt) + assert.NoError(t, err) + }) + + t.Run("does not process messages after error", func(t *testing.T) { + control, image := newTestControl(t) + mockPollResponses(t, image, + &codecs.ControlResponse{ + Code: codecs.ControlResponseCode.ERROR, + ErrorMessage: []byte(`b0rk`), + }, + &codecs.ControlResponse{Code: codecs.ControlResponseCode.OK}, + ) + cnt, err := control.PollForErrorResponse() + assert.EqualValues(t, 1, cnt) + assert.Error(t, err) + cnt, err = control.PollForErrorResponse() + assert.EqualValues(t, 1, cnt) + assert.NoError(t, err) + }) +} + +func mockPollResponses(t *testing.T, image *aeron.MockImage, responses ...encodable) { + poll := image.On("Poll", mock.Anything, mock.Anything) + poll.Maybe() + poll.Run(func(args mock.Arguments) { + handler := args.Get(0).(term.FragmentHandler) + fragmentCount := args.Get(1).(int) + count := image.ControlledPoll(func(buffer *atomic.Buffer, offset, length int32, header *logbuffer.Header) term.ControlledPollAction { + handler(buffer, offset, length, header) + return term.ControlledPollActionContinue + }, fragmentCount) + poll.Return(count) + }) + controlledPoll := image.On("ControlledPoll", mock.Anything, mock.Anything) + controlledPoll.Run(func(args mock.Arguments) { + fragmentCount := args.Get(1).(int) + count := 0 + for count < fragmentCount { + if len(responses) == 0 { + break + } + buffer := encode(t, responses[0]) + action := args.Get(0).(term.ControlledFragmentHandler)(buffer, 0, buffer.Capacity(), newTestHeader()) + if action == term.ControlledPollActionAbort { + break + } + responses = responses[1:] + count++ + if action == term.ControlledPollActionBreak { + break + } + } + controlledPoll.Return(count) + }) +} + +type encodable interface { + SbeBlockLength() uint16 + SbeTemplateId() uint16 + SbeSchemaId() uint16 + SbeSchemaVersion() uint16 + Encode(*codecs.SbeGoMarshaller, io.Writer, bool) error +} + +func encode(t *testing.T, data encodable) *atomic.Buffer { + m := codecs.NewSbeGoMarshaller() + buf := new(bytes.Buffer) + header := codecs.MessageHeader{ + BlockLength: data.SbeBlockLength(), + TemplateId: data.SbeTemplateId(), + SchemaId: data.SbeSchemaId(), + Version: data.SbeSchemaVersion(), + } + if !assert.NoError(t, header.Encode(m, buf)) { + return nil + } + if !assert.NoError(t, data.Encode(m, buf, false)) { + return nil + } + return atomic.MakeBuffer(buf.Bytes()) +} + +func newTestControl(t *testing.T) (*Control, *aeron.MockImage) { + image := aeron.NewMockImage(t) + c := &Control{ + Subscription: newTestSub(image), + } + c.archive = &Archive{ + Listeners: &ArchiveListeners{}, + Options: &Options{ + Timeout: 100*time.Millisecond, + IdleStrategy: idlestrategy.Yielding{}, + }, + } + c.fragmentAssembler = aeron.NewControlledFragmentAssembler( + c.onFragment, aeron.DefaultFragmentAssemblyBufferLength, + ) + return c, image +} + +func newTestSub(image aeron.Image) *aeron.Subscription { + images := aeron.NewImageList() + images.Set([]aeron.Image{image}) + sub := aeron.NewSubscription(nil, "", 1, 2, 3, nil, nil) + rsub := reflect.ValueOf(sub) + rfimages := rsub.Elem().FieldByName("images") + rfimages = reflect.NewAt(rfimages.Type(), unsafe.Pointer(rfimages.UnsafeAddr())).Elem() + rfimages.Set(reflect.ValueOf(images)) + return sub +} + +func newTestHeader() *logbuffer.Header { + buffer := atomic.MakeBuffer(make([]byte, logbuffer.DataFrameHeader.Length)) + buffer.PutUInt8(logbuffer.DataFrameHeader.FlagsFieldOffset, 0xc0) // unfragmented + return new(logbuffer.Header).Wrap(buffer.Ptr(), buffer.Capacity()) +}