Skip to content

Commit

Permalink
Merge pull request #102 from grddev/limit-response-poll
Browse files Browse the repository at this point in the history
Use ControlledPoll for PollForResponse and PollForErrorResponse
  • Loading branch information
ethanf authored Jun 5, 2024
2 parents f6902ab + 4ef3a85 commit 6932b0c
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 11 deletions.
4 changes: 3 additions & 1 deletion archive/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/lirm/aeron-go/aeron"
"github.com/lirm/aeron-go/aeron/atomic"
"github.com/lirm/aeron-go/aeron/logbuffer"
"github.com/lirm/aeron-go/aeron/logbuffer/term"
"github.com/lirm/aeron-go/aeron/logging"
"github.com/lirm/aeron-go/archive/codecs"
)
Expand Down Expand Up @@ -318,8 +319,9 @@ func NewArchive(options *Options, context *aeron.Context) (*Archive, error) {

for archive.Control.State.state != ControlStateConnected && archive.Control.State.err == nil {
fragments := archive.Control.poll(
func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) {
func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) term.ControlledPollAction {
ConnectionControlFragmentHandler(&pollContext, buf, offset, length, header)
return term.ControlledPollActionContinue
}, 1)
if fragments > 0 {
logger.Debugf("Read %d fragment(s)", fragments)
Expand Down
41 changes: 31 additions & 10 deletions archive/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,19 @@ func init() {
codecIds.recordingStopped = recordingStopped.SbeTemplateId()
}

func controlFragmentHandler(context interface{}, buffer *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) {
func controlFragmentHandler(context interface{}, buffer *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) (action term.ControlledPollAction) {
action = term.ControlledPollActionContinue

pollContext, ok := context.(*PollContext)
if !ok {
logger.Errorf("context conversion failed")
return
}

if pollContext.control.Results.IsPollComplete {
return term.ControlledPollActionAbort
}

logger.Debugf("controlFragmentHandler: correlationID:%d offset:%d length:%d header:%#v", pollContext.correlationID, offset, length, header)
var hdr codecs.SbeGoMessageHeader

Expand Down Expand Up @@ -170,6 +176,8 @@ func controlFragmentHandler(context interface{}, buffer *atomic.Buffer, offset i
logger.Debugf("controlFragmentHandler/controlResponse: received for sessionID:%d, correlationID:%d", controlResponse.ControlSessionId, controlResponse.CorrelationId)
control.Results.ControlResponse = controlResponse
control.Results.IsPollComplete = true

return term.ControlledPollActionBreak
} else {
logger.Debugf("controlFragmentHandler/controlResponse ignoring sessionID:%d, correlationID:%d", controlResponse.ControlSessionId, controlResponse.CorrelationId)
}
Expand Down Expand Up @@ -198,6 +206,7 @@ func controlFragmentHandler(context interface{}, buffer *atomic.Buffer, offset i
// This can happen when testing/adding new functionality
fmt.Printf("controlFragmentHandler: Unexpected message type %d\n", hdr.TemplateId)
}
return
}

// ConnectionControlFragmentHandler is the connection handling specific fragment handler.
Expand Down Expand Up @@ -349,12 +358,14 @@ func (control *Control) PollForErrorResponse() (int, error) {
context := PollContext{control, 0}
received := 0

control.Results.ErrorResponse = nil

// Poll for async events, errors etc until the queue is drained
for {
ret := control.poll(
func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) {
errorResponseFragmentHandler(&context, buf, offset, length, header)
}, 1)
func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) term.ControlledPollAction {
return errorResponseFragmentHandler(&context, buf, offset, length, header)
}, 10)
received += ret

// If we received a response with an error then return it
Expand All @@ -377,13 +388,19 @@ func (control *Control) PollForErrorResponse() (int, error) {
// ignore messages not on our session ID
// process recordingSignalEvents
// Log a warning if we have interrupted a synchronous event
func errorResponseFragmentHandler(context interface{}, buffer *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) {
func errorResponseFragmentHandler(context interface{}, buffer *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) (action term.ControlledPollAction) {
action = term.ControlledPollActionContinue

pollContext, ok := context.(*PollContext)
if !ok {
logger.Errorf("context conversion failed")
return
}

if pollContext.control.Results.ErrorResponse != nil {
return term.ControlledPollActionAbort
}

logger.Debugf("errorResponseFragmentHandler: offset:%d length: %d", offset, length)

var hdr codecs.SbeGoMessageHeader
Expand Down Expand Up @@ -419,6 +436,7 @@ func errorResponseFragmentHandler(context interface{}, buffer *atomic.Buffer, of
if controlResponse.ControlSessionId == pollContext.control.archive.SessionID {
if controlResponse.Code == codecs.ControlResponseCode.ERROR {
pollContext.control.Results.ErrorResponse = fmt.Errorf("PollForErrorResponse received a ControlResponse (correlationId:%d Code:ERROR error=\"%s\"", controlResponse.CorrelationId, controlResponse.ErrorMessage)
return term.ControlledPollActionBreak
}
}
return
Expand Down Expand Up @@ -500,19 +518,21 @@ func errorResponseFragmentHandler(context interface{}, buffer *atomic.Buffer, of
default:
fmt.Printf("errorResponseFragmentHandler: Insert decoder for type: %d", hdr.TemplateId)
}

return
}

// poll provides the control response poller using local state to pass
// back data from the underlying subscription
func (control *Control) poll(handler term.FragmentHandler, fragmentLimit int) int {
func (control *Control) poll(handler term.ControlledFragmentHandler, fragmentLimit int) int {

// Update our globals in case they've changed so we use the current state in our callback
rangeChecking = control.archive.Options.RangeChecking

control.Results.ControlResponse = nil // Clear old results
control.Results.IsPollComplete = false // Clear completion flag

return control.Subscription.Poll(handler, fragmentLimit)
return control.Subscription.ControlledPoll(handler, fragmentLimit)
}

// Poll for control response events. Returns number of fragments read during the operation.
Expand Down Expand Up @@ -605,8 +625,8 @@ func (control *Control) PollForResponse(correlationID int64, sessionID int64) (i
start := time.Now()
context := PollContext{control, correlationID}

handler := aeron.NewFragmentAssembler(func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) {
controlFragmentHandler(&context, buf, offset, length, header)
handler := aeron.NewControlledFragmentAssembler(func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) term.ControlledPollAction {
return controlFragmentHandler(&context, buf, offset, length, header)
}, aeron.DefaultFragmentAssemblyBufferLength)
for {
ret := control.poll(handler.OnFragment, 10)
Expand Down Expand Up @@ -797,8 +817,9 @@ func (control *Control) PollForDescriptors(correlationID int64, sessionID int64,
for !control.Results.IsPollComplete {
logger.Debugf("PollForDescriptors(%d:%d, %d)", correlationID, sessionID, int(fragmentsWanted)-descriptorCount)
fragments := control.poll(
func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) {
func(buf *atomic.Buffer, offset int32, length int32, header *logbuffer.Header) term.ControlledPollAction {
DescriptorFragmentHandler(&pollContext, buf, offset, length, header)
return term.ControlledPollActionContinue
}, int(fragmentsWanted)-descriptorCount)
logger.Debugf("Poll(%d:%d) returned %d fragments", correlationID, sessionID, fragments)
descriptorCount = len(control.Results.RecordingDescriptors) + len(control.Results.RecordingSubscriptionDescriptors)
Expand Down
230 changes: 230 additions & 0 deletions archive/control_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
package archive

import (
"bytes"
"io"
"reflect"
"testing"
"time"
"unsafe"

"github.com/lirm/aeron-go/aeron"
"github.com/lirm/aeron-go/aeron/atomic"
"github.com/lirm/aeron-go/aeron/idlestrategy"
"github.com/lirm/aeron-go/aeron/logbuffer"
"github.com/lirm/aeron-go/aeron/logbuffer/term"
"github.com/lirm/aeron-go/archive/codecs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestControl_PollForResponse(t *testing.T) {
t.Run("times out when nothing to poll", func(t *testing.T) {
control, image := newTestControl(t)
mockPollResponses(t, image)
correlations.Store(int64(1), control)
id, err := control.PollForResponse(1, 0)
assert.Zero(t, id)
assert.EqualError(t, err, `timeout waiting for correlationID 1`)
})

t.Run("returns the error response", func(t *testing.T) {
control, image := newTestControl(t)
mockPollResponses(t, image,
&codecs.ControlResponse{
Code: codecs.ControlResponseCode.ERROR,
CorrelationId: 1,
ErrorMessage: []byte(`b0rk`),
RelevantId: 3,
},
)
correlations.Store(int64(1), control)
id, err := control.PollForResponse(1, 0)
assert.EqualValues(t, 3, id)
assert.EqualError(t, err, `Control Response failure: b0rk`)
})

t.Run("discards all responses preceding the first error", func(t *testing.T) {
control, image := newTestControl(t)
mockPollResponses(t, image,
&codecs.ControlResponse{Code: codecs.ControlResponseCode.OK},
&codecs.ControlResponse{Code: codecs.ControlResponseCode.OK},
&codecs.ControlResponse{
Code: codecs.ControlResponseCode.ERROR,
CorrelationId: 1,
ErrorMessage: []byte(`b0rk`),
RelevantId: 3,
},
)
correlations.Store(int64(1), control)
id, err := control.PollForResponse(1, 0)
assert.EqualValues(t, 3, id)
assert.EqualError(t, err, `Control Response failure: b0rk`)
})

t.Run("does not process messages after result", func(t *testing.T) {
control, image := newTestControl(t)
mockPollResponses(t, image,
&codecs.ControlResponse{
Code: codecs.ControlResponseCode.ERROR,
CorrelationId: 1,
ErrorMessage: []byte(`b0rk`),
RelevantId: 3,
},
&codecs.ControlResponse{Code: codecs.ControlResponseCode.OK},
)
correlations.Store(int64(1), control)
id, err := control.PollForResponse(1, 0)
assert.EqualValues(t, 3, id)
assert.EqualError(t, err, `Control Response failure: b0rk`)
fragments := image.Poll(func(buffer *atomic.Buffer, offset, length int32, header *logbuffer.Header) {}, 1)
assert.EqualValues(t, 1, fragments)
})
}

func TestControl_PollForErrorResponse(t *testing.T) {
t.Run("returns zero when nothing to poll", func(t *testing.T) {
control, image := newTestControl(t)
mockPollResponses(t, image)
cnt, err := control.PollForErrorResponse()
assert.Zero(t, cnt)
assert.NoError(t, err)
})

t.Run("returns the error response", func(t *testing.T) {
control, image := newTestControl(t)
mockPollResponses(t, image,
&codecs.ControlResponse{
Code: codecs.ControlResponseCode.ERROR,
ErrorMessage: []byte(`b0rk`),
},
)
cnt, err := control.PollForErrorResponse()
assert.EqualValues(t, 1, cnt)
assert.EqualError(t, err, `PollForErrorResponse received a ControlResponse (correlationId:0 Code:ERROR error="b0rk"`)
})

t.Run("discards all queued responses unless error", func(t *testing.T) {
control, image := newTestControl(t)
mockPollResponses(t, image,
&codecs.ControlResponse{Code: codecs.ControlResponseCode.OK},
&codecs.ControlResponse{Code: codecs.ControlResponseCode.OK},
)
cnt, err := control.PollForErrorResponse()
assert.EqualValues(t, 2, cnt)
assert.NoError(t, err)
})

t.Run("does not process messages after error", func(t *testing.T) {
control, image := newTestControl(t)
mockPollResponses(t, image,
&codecs.ControlResponse{
Code: codecs.ControlResponseCode.ERROR,
ErrorMessage: []byte(`b0rk`),
},
&codecs.ControlResponse{Code: codecs.ControlResponseCode.OK},
)
cnt, err := control.PollForErrorResponse()
assert.EqualValues(t, 1, cnt)
assert.Error(t, err)
cnt, err = control.PollForErrorResponse()
assert.EqualValues(t, 1, cnt)
assert.NoError(t, err)
})
}

func mockPollResponses(t *testing.T, image *aeron.MockImage, responses ...encodable) {
poll := image.On("Poll", mock.Anything, mock.Anything)
poll.Maybe()
poll.Run(func(args mock.Arguments) {
handler := args.Get(0).(term.FragmentHandler)
fragmentCount := args.Get(1).(int)
count := image.ControlledPoll(func(buffer *atomic.Buffer, offset, length int32, header *logbuffer.Header) term.ControlledPollAction {
handler(buffer, offset, length, header)
return term.ControlledPollActionContinue
}, fragmentCount)
poll.Return(count)
})
controlledPoll := image.On("ControlledPoll", mock.Anything, mock.Anything)
controlledPoll.Run(func(args mock.Arguments) {
fragmentCount := args.Get(1).(int)
count := 0
for count < fragmentCount {
if len(responses) == 0 {
break
}
buffer := encode(t, responses[0])
action := args.Get(0).(term.ControlledFragmentHandler)(buffer, 0, buffer.Capacity(), newTestHeader())
if action == term.ControlledPollActionAbort {
break
}
responses = responses[1:]
count++
if action == term.ControlledPollActionBreak {
break
}
}
controlledPoll.Return(count)
})
}

type encodable interface {
SbeBlockLength() uint16
SbeTemplateId() uint16
SbeSchemaId() uint16
SbeSchemaVersion() uint16
Encode(*codecs.SbeGoMarshaller, io.Writer, bool) error
}

func encode(t *testing.T, data encodable) *atomic.Buffer {
m := codecs.NewSbeGoMarshaller()
buf := new(bytes.Buffer)
header := codecs.MessageHeader{
BlockLength: data.SbeBlockLength(),
TemplateId: data.SbeTemplateId(),
SchemaId: data.SbeSchemaId(),
Version: data.SbeSchemaVersion(),
}
if !assert.NoError(t, header.Encode(m, buf)) {
return nil
}
if !assert.NoError(t, data.Encode(m, buf, false)) {
return nil
}
return atomic.MakeBuffer(buf.Bytes())
}

func newTestControl(t *testing.T) (*Control, *aeron.MockImage) {
image := aeron.NewMockImage(t)
c := &Control{
Subscription: newTestSub(image),
}
c.archive = &Archive{
Listeners: &ArchiveListeners{},
Options: &Options{
Timeout: 100*time.Millisecond,
IdleStrategy: idlestrategy.Yielding{},
},
}
c.fragmentAssembler = aeron.NewControlledFragmentAssembler(
c.onFragment, aeron.DefaultFragmentAssemblyBufferLength,
)
return c, image
}

func newTestSub(image aeron.Image) *aeron.Subscription {
images := aeron.NewImageList()
images.Set([]aeron.Image{image})
sub := aeron.NewSubscription(nil, "", 1, 2, 3, nil, nil)
rsub := reflect.ValueOf(sub)
rfimages := rsub.Elem().FieldByName("images")
rfimages = reflect.NewAt(rfimages.Type(), unsafe.Pointer(rfimages.UnsafeAddr())).Elem()
rfimages.Set(reflect.ValueOf(images))
return sub
}

func newTestHeader() *logbuffer.Header {
buffer := atomic.MakeBuffer(make([]byte, logbuffer.DataFrameHeader.Length))
buffer.PutUInt8(logbuffer.DataFrameHeader.FlagsFieldOffset, 0xc0) // unfragmented
return new(logbuffer.Header).Wrap(buffer.Ptr(), buffer.Capacity())
}

0 comments on commit 6932b0c

Please sign in to comment.