Skip to content

Commit

Permalink
Add managed subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
DylanTinianov committed Jan 10, 2025
1 parent e196d79 commit d3fcab1
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 24 deletions.
50 changes: 34 additions & 16 deletions multinode/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ func (m *MultiNodeAdapter[RPC, HEAD]) LenSubs() int {
}

// registerSub adds the sub to the rpcMultiNodeAdapter list
func (m *MultiNodeAdapter[RPC, HEAD]) registerSub(sub Subscription, stopInFLightCh chan struct{}) error {
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()
func (m *MultiNodeAdapter[RPC, HEAD]) registerSub(sub *ManagedSubscription, stopInFLightCh chan struct{}) error {
// ensure that the `sub` belongs to current life cycle of the `rpcMultiNodeAdapter` and it should not be killed due to
// previous `DisconnectAll` call.
select {
Expand All @@ -72,11 +70,18 @@ func (m *MultiNodeAdapter[RPC, HEAD]) registerSub(sub Subscription, stopInFLight
return fmt.Errorf("failed to register subscription - all in-flight requests were canceled")
default:
}
// TODO: BCI-3358 - delete sub when caller unsubscribes.
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()
m.subs[sub] = struct{}{}
return nil
}

func (m *MultiNodeAdapter[RPC, HEAD]) removeSub(sub Subscription) {
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()
delete(m.subs, sub)
}

func (m *MultiNodeAdapter[RPC, HEAD]) LatestBlock(ctx context.Context) (HEAD, error) {
// capture chStopInFlight to ensure we are not updating chainInfo with observations related to previous life cycle
ctx, cancel, chStopInFlight, rpc := m.AcquireQueryCtx(ctx, m.ctxTimeout)
Expand Down Expand Up @@ -108,7 +113,7 @@ func (m *MultiNodeAdapter[RPC, HEAD]) LatestFinalizedBlock(ctx context.Context)
return head, errors.New("invalid head")
}

m.OnNewFinalizedHead(ctx, chStopInFlight, head)
m.onNewFinalizedHead(ctx, chStopInFlight, head)
return head, nil
}

Expand All @@ -133,13 +138,18 @@ func (m *MultiNodeAdapter[RPC, HEAD]) SubscribeToHeads(ctx context.Context) (<-c
return nil, nil, err
}

err := m.registerSub(&poller, chStopInFlight)
sub := &ManagedSubscription{
Subscription: &poller,
onUnsubscribe: m.removeSub,
}

err := m.registerSub(sub, chStopInFlight)
if err != nil {
poller.Unsubscribe()
sub.Unsubscribe()
return nil, nil, err
}

return channel, &poller, nil
return channel, sub, nil
}

func (m *MultiNodeAdapter[RPC, HEAD]) SubscribeToFinalizedHeads(ctx context.Context) (<-chan HEAD, Subscription, error) {
Expand All @@ -161,13 +171,18 @@ func (m *MultiNodeAdapter[RPC, HEAD]) SubscribeToFinalizedHeads(ctx context.Cont
return nil, nil, err
}

err := m.registerSub(&poller, chStopInFlight)
sub := &ManagedSubscription{
Subscription: &poller,
onUnsubscribe: m.removeSub,
}

err := m.registerSub(sub, chStopInFlight)
if err != nil {
poller.Unsubscribe()
sub.Unsubscribe()
return nil, nil, err
}

return channel, &poller, nil
return channel, sub, nil
}

func (m *MultiNodeAdapter[RPC, HEAD]) onNewHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) {
Expand All @@ -188,7 +203,7 @@ func (m *MultiNodeAdapter[RPC, HEAD]) onNewHead(ctx context.Context, requestCh <
}
}

func (m *MultiNodeAdapter[RPC, HEAD]) OnNewFinalizedHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) {
func (m *MultiNodeAdapter[RPC, HEAD]) onNewFinalizedHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) {
if !head.IsValid() {
return
}
Expand Down Expand Up @@ -235,19 +250,22 @@ func (m *MultiNodeAdapter[RPC, HEAD]) AcquireQueryCtx(parentCtx context.Context,

func (m *MultiNodeAdapter[RPC, HEAD]) UnsubscribeAllExcept(subs ...Subscription) {
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()

keepSubs := map[Subscription]struct{}{}
for _, sub := range subs {
keepSubs[sub] = struct{}{}
}

var unsubs []Subscription
for sub := range m.subs {
if _, keep := keepSubs[sub]; !keep {
sub.Unsubscribe()
delete(m.subs, sub)
unsubs = append(unsubs, sub)
}
}
m.subsSliceMu.Unlock()

for _, sub := range unsubs {
sub.Unsubscribe()
}
}

// cancelInflightRequests closes and replaces the chStopInFlight
Expand Down
60 changes: 52 additions & 8 deletions multinode/adaptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ func (s *mockSub) Err() <-chan error {
func TestMultiNodeClient_RegisterSubs(t *testing.T) {
t.Run("registerSub", func(t *testing.T) {
c := newTestClient(t)
sub := newMockSub()
mockSub := newMockSub()
sub := &ManagedSubscription{
Subscription: mockSub,
onUnsubscribe: c.removeSub,
}
err := c.registerSub(sub, make(chan struct{}))
require.NoError(t, err)
require.Equal(t, 1, c.LenSubs())
Expand All @@ -144,30 +148,70 @@ func TestMultiNodeClient_RegisterSubs(t *testing.T) {
c := newTestClient(t)
chStopInFlight := make(chan struct{})
close(chStopInFlight)
sub := newMockSub()
mockSub := newMockSub()
sub := &ManagedSubscription{
Subscription: mockSub,
onUnsubscribe: c.removeSub,
}
err := c.registerSub(sub, chStopInFlight)
require.Error(t, err)
require.Equal(t, true, sub.unsubscribed)
require.Equal(t, true, mockSub.unsubscribed)

Check failure on line 158 in multinode/adaptor_test.go

View workflow job for this annotation

GitHub Actions / golangci-lint

bool-compare: use require.True (testifylint)
})

t.Run("UnsubscribeAllExcept", func(t *testing.T) {
c := newTestClient(t)
chStopInFlight := make(chan struct{})
sub1 := newMockSub()
sub2 := newMockSub()
mockSub1 := newMockSub()
sub1 := &ManagedSubscription{
Subscription: mockSub1,
onUnsubscribe: c.removeSub,
}
mockSub2 := newMockSub()
sub2 := &ManagedSubscription{
Subscription: mockSub2,
onUnsubscribe: c.removeSub,
}
err := c.registerSub(sub1, chStopInFlight)
require.NoError(t, err)
err = c.registerSub(sub2, chStopInFlight)
require.NoError(t, err)
require.Equal(t, 2, c.LenSubs())

// Ensure passed sub is not removed
c.UnsubscribeAllExcept(sub1)
require.Equal(t, 1, c.LenSubs())
require.Equal(t, true, sub2.unsubscribed)
require.Equal(t, false, sub1.unsubscribed)
require.Equal(t, true, mockSub2.unsubscribed)
require.Equal(t, false, mockSub1.unsubscribed)

Check failure on line 184 in multinode/adaptor_test.go

View workflow job for this annotation

GitHub Actions / golangci-lint

bool-compare: use require.False (testifylint)

c.UnsubscribeAllExcept()
require.Equal(t, 0, c.LenSubs())
require.Equal(t, true, mockSub1.unsubscribed)
})

t.Run("Remove Subscription on Unsubscribe", func(t *testing.T) {
c := newTestClient(t)
_, sub1, err := c.SubscribeToHeads(tests.Context(t))
require.NoError(t, err)
require.Equal(t, 1, c.LenSubs())
_, sub2, err := c.SubscribeToFinalizedHeads(tests.Context(t))
require.NoError(t, err)
require.Equal(t, 2, c.LenSubs())

sub1.Unsubscribe()
require.Equal(t, 1, c.LenSubs())
sub2.Unsubscribe()
require.Equal(t, 0, c.LenSubs())
})

t.Run("Ensure no deadlock on UnsubscribeAll", func(t *testing.T) {
c := newTestClient(t)
_, _, err := c.SubscribeToHeads(tests.Context(t))
require.NoError(t, err)
require.Equal(t, 1, c.LenSubs())
_, _, err = c.SubscribeToFinalizedHeads(tests.Context(t))
require.NoError(t, err)
require.Equal(t, 2, c.LenSubs())
c.UnsubscribeAllExcept()
require.Equal(t, 0, c.LenSubs())
require.Equal(t, true, sub1.unsubscribed)
})
}
13 changes: 13 additions & 0 deletions multinode/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@ type Subscription interface {
Err() <-chan error
}

// ManagedSubscription is a Subscription which contains an onUnsubscribe callback for cleanup
type ManagedSubscription struct {
Subscription
onUnsubscribe func(sub Subscription)
}

func (w *ManagedSubscription) Unsubscribe() {
w.Subscription.Unsubscribe()
if w.onUnsubscribe != nil {
w.onUnsubscribe(w)
}
}

// RPCClient includes all the necessary generalized RPC methods used by Node to perform health checks
type RPCClient[
CHAIN_ID ID,
Expand Down

0 comments on commit d3fcab1

Please sign in to comment.