diff --git a/multinode/ctx.go b/multinode/ctx.go
index ed9bc32..0684e09 100644
--- a/multinode/ctx.go
+++ b/multinode/ctx.go
@@ -12,6 +12,6 @@ func CtxAddHealthCheckFlag(ctx context.Context) context.Context {
 	return context.WithValue(ctx, contextKeyHeathCheckRequest, struct{}{})
 }
 
-func CtxIsHeathCheckRequest(ctx context.Context) bool {
+func CtxIsHealthCheckRequest(ctx context.Context) bool {
 	return ctx.Value(contextKeyHeathCheckRequest) != nil
 }
diff --git a/multinode/ctx_test.go b/multinode/ctx_test.go
index c8d46e9..8466325 100644
--- a/multinode/ctx_test.go
+++ b/multinode/ctx_test.go
@@ -10,7 +10,7 @@ import (
 
 func TestContext(t *testing.T) {
 	ctx := tests.Context(t)
-	assert.False(t, CtxIsHeathCheckRequest(ctx), "expected false for test context")
+	assert.False(t, CtxIsHealthCheckRequest(ctx), "expected false for test context")
 	ctx = CtxAddHealthCheckFlag(ctx)
-	assert.True(t, CtxIsHeathCheckRequest(ctx), "expected context to contain the healthcheck flag")
+	assert.True(t, CtxIsHealthCheckRequest(ctx), "expected context to contain the healthcheck flag")
 }
diff --git a/multinode/mock_head_test.go b/multinode/mock_head_test.go
index cf83998..bd3d414 100644
--- a/multinode/mock_head_test.go
+++ b/multinode/mock_head_test.go
@@ -113,6 +113,53 @@ func (_c *mockHead_BlockNumber_Call) RunAndReturn(run func() int64) *mockHead_Bl
 	return _c
 }
 
+// GetTotalDifficulty provides a mock function with given fields:
+func (_m *mockHead) GetTotalDifficulty() *big.Int {
+	ret := _m.Called()
+
+	if len(ret) == 0 {
+		panic("no return value specified for GetTotalDifficulty")
+	}
+
+	var r0 *big.Int
+	if rf, ok := ret.Get(0).(func() *big.Int); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*big.Int)
+		}
+	}
+
+	return r0
+}
+
+// mockHead_GetTotalDifficulty_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTotalDifficulty'
+type mockHead_GetTotalDifficulty_Call struct {
+	*mock.Call
+}
+
+// GetTotalDifficulty is a helper method to define mock.On call
+func (_e *mockHead_Expecter) GetTotalDifficulty() *mockHead_GetTotalDifficulty_Call {
+	return &mockHead_GetTotalDifficulty_Call{Call: _e.mock.On("GetTotalDifficulty")}
+}
+
+func (_c *mockHead_GetTotalDifficulty_Call) Run(run func()) *mockHead_GetTotalDifficulty_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run()
+	})
+	return _c
+}
+
+func (_c *mockHead_GetTotalDifficulty_Call) Return(_a0 *big.Int) *mockHead_GetTotalDifficulty_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *mockHead_GetTotalDifficulty_Call) RunAndReturn(run func() *big.Int) *mockHead_GetTotalDifficulty_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // IsValid provides a mock function with given fields:
 func (_m *mockHead) IsValid() bool {
 	ret := _m.Called()
diff --git a/multinode/node_lifecycle_test.go b/multinode/node_lifecycle_test.go
index bb289e0..86a0ddc 100644
--- a/multinode/node_lifecycle_test.go
+++ b/multinode/node_lifecycle_test.go
@@ -590,6 +590,7 @@ func (h head) ToMockHead(t *testing.T) *mockHead {
 	m := newMockHead(t)
 	m.On("BlockNumber").Return(h.BlockNumber).Maybe()
 	m.On("BlockDifficulty").Return(h.BlockDifficulty).Maybe()
+	m.On("GetTotalDifficulty").Return(h.BlockDifficulty).Maybe()
 	m.On("IsValid").Return(true).Maybe()
 	return m
 }
diff --git a/multinode/rpc_client_base.go b/multinode/rpc_client_base.go
new file mode 100644
index 0000000..b4a886c
--- /dev/null
+++ b/multinode/rpc_client_base.go
@@ -0,0 +1,298 @@
+package multinode
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/smartcontractkit/chainlink-common/pkg/logger"
+	"github.com/smartcontractkit/chainlink-common/pkg/services"
+)
+
+type RPCClientBaseConfig interface {
+	NewHeadsPollInterval() time.Duration
+	FinalizedBlockPollInterval() time.Duration
+}
+
+// RPCClientBase is used to integrate multinode into chain-specific clients.
+// For new MultiNode integrations, we wrap the chain-specific RPC client and embed the RPCClientBase
+// to get the required RPCClient methods and enable the use of MultiNode.
+//
+// The RPCClientBase provides chain-agnostic functionality such as head and finalized head
+// subscriptions, which are required in each Node lifecycle to execute various
+// health checks.
+type RPCClientBase[HEAD Head] struct {
+	cfg        RPCClientBaseConfig
+	log        logger.Logger
+	ctxTimeout time.Duration
+	subsMu     sync.RWMutex
+	subs       map[Subscription]struct{}
+
+	latestBlock          func(ctx context.Context) (HEAD, error)
+	latestFinalizedBlock func(ctx context.Context) (HEAD, error)
+
+	// lifeCycleCh can be closed to immediately cancel all in-flight requests on
+	// this RPC. Closing and replacing should be serialized through
+	// lifeCycleMu since it can happen on state transitions as well as RPCClientBase Close.
+	// Also closed when RPC is declared unhealthy.
+	lifeCycleMu sync.RWMutex
+	lifeCycleCh chan struct{}
+
+	// chainInfoLock protects highestUserObservations and latestChainInfo
+	chainInfoLock sync.RWMutex
+	// intercepted values seen by callers of the RPCClientBase, excluding health check calls.
+	// Needed to ensure MultiNode provides a repeatable read guarantee.
+	highestUserObservations ChainInfo
+	// most recent chain info observed during current lifecycle
+	latestChainInfo ChainInfo
+}
+
+func NewRPCClientBase[HEAD Head](
+	cfg RPCClientBaseConfig, ctxTimeout time.Duration, log logger.Logger,
+	latestBlock func(ctx context.Context) (HEAD, error),
+	latestFinalizedBlock func(ctx context.Context) (HEAD, error),
+) *RPCClientBase[HEAD] {
+	return &RPCClientBase[HEAD]{
+		cfg:                  cfg,
+		log:                  log,
+		ctxTimeout:           ctxTimeout,
+		latestBlock:          latestBlock,
+		latestFinalizedBlock: latestFinalizedBlock,
+		subs:                 make(map[Subscription]struct{}),
+		lifeCycleCh:          make(chan struct{}),
+	}
+}
+
+func (m *RPCClientBase[HEAD]) lenSubs() int {
+	m.subsMu.RLock()
+	defer m.subsMu.RUnlock()
+	return len(m.subs)
+}
+
+// RegisterSub adds the sub to the RPCClientBase list and returns a managed sub which is removed on unsubscribe
+func (m *RPCClientBase[HEAD]) RegisterSub(sub Subscription, lifeCycleCh chan struct{}) (*ManagedSubscription, error) {
+	// ensure that the `sub` belongs to the current life cycle of the `RPCClientBase` and was not killed by a
+	// previous `DisconnectAll` call.
+	select {
+	case <-lifeCycleCh:
+		sub.Unsubscribe()
+		return nil, fmt.Errorf("failed to register subscription - all in-flight requests were canceled")
+	default:
+	}
+	m.subsMu.Lock()
+	defer m.subsMu.Unlock()
+	managedSub := &ManagedSubscription{
+		sub,
+		m.removeSub,
+	}
+	m.subs[managedSub] = struct{}{}
+	return managedSub, nil
+}
+
+func (m *RPCClientBase[HEAD]) removeSub(sub Subscription) {
+	m.subsMu.Lock()
+	defer m.subsMu.Unlock()
+	delete(m.subs, sub)
+}
+
+func (m *RPCClientBase[HEAD]) SubscribeToHeads(ctx context.Context) (<-chan HEAD, Subscription, error) {
+	ctx, cancel, lifeCycleCh := m.AcquireQueryCtx(ctx, m.ctxTimeout)
+	defer cancel()
+
+	pollInterval := m.cfg.NewHeadsPollInterval()
+	if pollInterval == 0 {
+		return nil, nil, errors.New("PollInterval is 0")
+	}
+	timeout := pollInterval
+	poller, channel := NewPoller[HEAD](pollInterval, func(pollRequestCtx context.Context) (HEAD, error) {
+		if CtxIsHealthCheckRequest(ctx) {
+			pollRequestCtx = CtxAddHealthCheckFlag(pollRequestCtx)
+		}
+		return m.LatestBlock(pollRequestCtx)
+	}, timeout, m.log)
+
+	if err := poller.Start(ctx); err != nil {
+		return nil, nil, err
+	}
+
+	sub, err := m.RegisterSub(&poller, lifeCycleCh)
+	if err != nil {
+		poller.Unsubscribe()
+		return nil, nil, err
+	}
+	return channel, sub, nil
+}
+
+func (m *RPCClientBase[HEAD]) SubscribeToFinalizedHeads(ctx context.Context) (<-chan HEAD, Subscription, error) {
+	ctx, cancel, lifeCycleCh := m.AcquireQueryCtx(ctx, m.ctxTimeout)
+	defer cancel()
+
+	finalizedBlockPollInterval := m.cfg.FinalizedBlockPollInterval()
+	if finalizedBlockPollInterval == 0 {
+		return nil, nil, errors.New("FinalizedBlockPollInterval is 0")
+	}
+	timeout := finalizedBlockPollInterval
+	poller, channel := NewPoller[HEAD](finalizedBlockPollInterval, func(pollRequestCtx context.Context) (HEAD, error) {
+		if CtxIsHealthCheckRequest(ctx) {
+			pollRequestCtx = CtxAddHealthCheckFlag(pollRequestCtx)
+		}
+		return m.LatestFinalizedBlock(pollRequestCtx)
+	}, timeout, m.log)
+	if err := poller.Start(ctx); err != nil {
+		return nil, nil, err
+	}
+
+	sub, err := m.RegisterSub(&poller, lifeCycleCh)
+	if err != nil {
+		poller.Unsubscribe()
+		return nil, nil, err
+	}
+	return channel, sub, nil
+}
+
+func (m *RPCClientBase[HEAD]) LatestBlock(ctx context.Context) (HEAD, error) {
+	// capture lifeCycleCh to ensure we are not updating chainInfo with observations related to previous life cycle
+	ctx, cancel, lifeCycleCh := m.AcquireQueryCtx(ctx, m.ctxTimeout)
+	defer cancel()
+
+	head, err := m.latestBlock(ctx)
+	if err != nil {
+		return head, err
+	}
+
+	if !head.IsValid() {
+		return head, errors.New("invalid head")
+	}
+
+	m.OnNewHead(ctx, lifeCycleCh, head)
+	return head, nil
+}
+
+func (m *RPCClientBase[HEAD]) LatestFinalizedBlock(ctx context.Context) (HEAD, error) {
+	ctx, cancel, lifeCycleCh := m.AcquireQueryCtx(ctx, m.ctxTimeout)
+	defer cancel()
+
+	head, err := m.latestFinalizedBlock(ctx)
+	if err != nil {
+		return head, err
+	}
+
+	if !head.IsValid() {
+		return head, errors.New("invalid head")
+	}
+
+	m.OnNewFinalizedHead(ctx, lifeCycleCh, head)
+	return head, nil
+}
+
+func (m *RPCClientBase[HEAD]) OnNewHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) {
+	if !head.IsValid() {
+		return
+	}
+
+	m.chainInfoLock.Lock()
+	defer m.chainInfoLock.Unlock()
+	blockNumber := head.BlockNumber()
+	totalDifficulty := head.GetTotalDifficulty()
+	if !CtxIsHealthCheckRequest(ctx) {
+		m.highestUserObservations.BlockNumber = max(m.highestUserObservations.BlockNumber, blockNumber)
+		m.highestUserObservations.TotalDifficulty = MaxTotalDifficulty(m.highestUserObservations.TotalDifficulty, totalDifficulty)
+	}
+	select {
+	case <-requestCh: // no need to update latestChainInfo, as RPCClientBase already started new life cycle
+		return
+	default:
+		m.latestChainInfo.BlockNumber = blockNumber
+		m.latestChainInfo.TotalDifficulty = totalDifficulty
+	}
+}
+
+func (m *RPCClientBase[HEAD]) OnNewFinalizedHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) {
+	if !head.IsValid() {
+		return
+	}
+
+	m.chainInfoLock.Lock()
+	defer m.chainInfoLock.Unlock()
+	if !CtxIsHealthCheckRequest(ctx) {
+		m.highestUserObservations.FinalizedBlockNumber = max(m.highestUserObservations.FinalizedBlockNumber, head.BlockNumber())
+	}
+	select {
+	case <-requestCh: // no need to update latestChainInfo, as RPCClientBase already started new life cycle
+		return
+	default:
+		m.latestChainInfo.FinalizedBlockNumber = head.BlockNumber()
+	}
+}
+
+// makeQueryCtx returns a context that cancels if:
+// 1. Passed in ctx cancels
+// 2. Passed in channel is closed
+// 3. Default timeout is reached (queryTimeout)
+func makeQueryCtx(ctx context.Context, ch services.StopChan, timeout time.Duration) (context.Context, context.CancelFunc) {
+	var chCancel, timeoutCancel context.CancelFunc
+	ctx, chCancel = ch.Ctx(ctx)
+	ctx, timeoutCancel = context.WithTimeout(ctx, timeout)
+	cancel := func() {
+		chCancel()
+		timeoutCancel()
+	}
+	return ctx, cancel
+}
+
+func (m *RPCClientBase[HEAD]) AcquireQueryCtx(parentCtx context.Context, timeout time.Duration) (ctx context.Context, cancel context.CancelFunc,
+	lifeCycleCh chan struct{}) {
+	// Need to wrap in mutex because state transition can cancel and replace context
+	m.lifeCycleMu.RLock()
+	lifeCycleCh = m.lifeCycleCh
+	m.lifeCycleMu.RUnlock()
+	ctx, cancel = makeQueryCtx(parentCtx, lifeCycleCh, timeout)
+	return
+}
+
+func (m *RPCClientBase[HEAD]) UnsubscribeAllExcept(subs ...Subscription) {
+	m.subsMu.Lock()
+	keepSubs := map[Subscription]struct{}{}
+	for _, sub := range subs {
+		keepSubs[sub] = struct{}{}
+	}
+
+	var unsubs []Subscription
+	for sub := range m.subs {
+		if _, keep := keepSubs[sub]; !keep {
+			unsubs = append(unsubs, sub)
+		}
+	}
+	m.subsMu.Unlock()
+
+	for _, sub := range unsubs {
+		sub.Unsubscribe()
+	}
+}
+
+// CancelLifeCycle closes and replaces the lifeCycleCh
+func (m *RPCClientBase[HEAD]) CancelLifeCycle() {
+	m.lifeCycleMu.Lock()
+	defer m.lifeCycleMu.Unlock()
+	close(m.lifeCycleCh)
+	m.lifeCycleCh = make(chan struct{})
+}
+
+func (m *RPCClientBase[HEAD]) resetLatestChainInfo() {
+	m.chainInfoLock.Lock()
+	m.latestChainInfo = ChainInfo{}
+	m.chainInfoLock.Unlock()
+}
+
+func (m *RPCClientBase[HEAD]) Close() {
+	m.CancelLifeCycle()
+	m.UnsubscribeAllExcept()
+	m.resetLatestChainInfo()
+}
+
+func (m *RPCClientBase[HEAD]) GetInterceptedChainInfo() (latest, highestUserObservations ChainInfo) {
+	m.chainInfoLock.RLock()
+	defer m.chainInfoLock.RUnlock()
+	return m.latestChainInfo, m.highestUserObservations
+}
diff --git a/multinode/rpc_client_base_test.go b/multinode/rpc_client_base_test.go
new file mode 100644
index 0000000..25afd4d
--- /dev/null
+++ b/multinode/rpc_client_base_test.go
@@ -0,0 +1,300 @@
+package multinode
+
+import (
+	"context"
+	"math/big"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	common "github.com/smartcontractkit/chainlink-common/pkg/config"
+	"github.com/smartcontractkit/chainlink-common/pkg/logger"
+	"github.com/smartcontractkit/chainlink-common/pkg/utils/tests"
+	"github.com/smartcontractkit/chainlink-framework/multinode/config"
+)
+
+type testRPC struct {
+	*RPCClientBase[*testHead]
+	latestBlockNum int64
+}
+
+// latestBlock simulates a chain-specific latestBlock function
+func (rpc *testRPC) latestBlock(ctx context.Context) (*testHead, error) {
+	rpc.latestBlockNum++
+	return &testHead{rpc.latestBlockNum}, nil
+}
+
+func (rpc *testRPC) Close() {
+	rpc.RPCClientBase.Close()
+}
+
+type testHead struct {
+	blockNumber int64
+}
+
+func (t *testHead) BlockNumber() int64           { return t.blockNumber }
+func (t *testHead) BlockDifficulty() *big.Int    { return nil }
+func (t *testHead) GetTotalDifficulty() *big.Int { return nil }
+func (t *testHead) IsValid() bool                { return true }
+
+func ptr[T any](t T) *T {
+	return &t
+}
+
+func newTestRPC(t *testing.T) *testRPC {
+	requestTimeout := 5 * time.Second
+	lggr := logger.Test(t)
+	cfg := &config.MultiNodeConfig{
+		MultiNode: config.MultiNode{
+			Enabled:                      ptr(true),
+			PollFailureThreshold:         ptr(uint32(5)),
+			PollInterval:                 common.MustNewDuration(15 * time.Second),
+			SelectionMode:                ptr(NodeSelectionModePriorityLevel),
+			SyncThreshold:                ptr(uint32(10)),
+			LeaseDuration:                common.MustNewDuration(time.Minute),
+			NodeIsSyncingEnabled:         ptr(false),
+			NewHeadsPollInterval:         common.MustNewDuration(5 * time.Second),
+			FinalizedBlockPollInterval:   common.MustNewDuration(5 * time.Second),
+			EnforceRepeatableRead:        ptr(true),
+			DeathDeclarationDelay:        common.MustNewDuration(20 * time.Second),
+			NodeNoNewHeadsThreshold:      common.MustNewDuration(20 * time.Second),
+			NoNewFinalizedHeadsThreshold: common.MustNewDuration(20 * time.Second),
+			FinalityTagEnabled:           ptr(true),
+			FinalityDepth:                ptr(uint32(0)),
+			FinalizedBlockOffset:         ptr(uint32(50)),
+		},
+	}
+
+	rpc := &testRPC{}
+	rpc.RPCClientBase = NewRPCClientBase[*testHead](cfg, requestTimeout, lggr, rpc.latestBlock, rpc.latestBlock)
+	t.Cleanup(rpc.Close)
+	return rpc
+}
+
+func TestAdapter_LatestBlock(t *testing.T) {
+	t.Run("LatestBlock", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		latestChainInfo, highestChainInfo := rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(0), latestChainInfo.BlockNumber)
+		require.Equal(t, int64(0), highestChainInfo.BlockNumber)
+		head, err := rpc.LatestBlock(tests.Context(t))
+		require.NoError(t, err)
+		require.True(t, head.IsValid())
+		latestChainInfo, highestChainInfo = rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(1), latestChainInfo.BlockNumber)
+		require.Equal(t, int64(1), highestChainInfo.BlockNumber)
+	})
+
+	t.Run("LatestFinalizedBlock", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		latestChainInfo, highestChainInfo := rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(0), latestChainInfo.FinalizedBlockNumber)
+		require.Equal(t, int64(0), highestChainInfo.FinalizedBlockNumber)
+		finalizedHead, err := rpc.LatestFinalizedBlock(tests.Context(t))
+		require.NoError(t, err)
+		require.True(t, finalizedHead.IsValid())
+		latestChainInfo, highestChainInfo = rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(1), latestChainInfo.FinalizedBlockNumber)
+		require.Equal(t, int64(1), highestChainInfo.FinalizedBlockNumber)
+	})
+}
+
+func TestAdapter_OnNewHeadFunctions(t *testing.T) {
+	timeout := 10 * time.Second
+	t.Run("OnNewHead and OnNewFinalizedHead updates chain info", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		latestChainInfo, highestChainInfo := rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(0), latestChainInfo.BlockNumber)
+		require.Equal(t, int64(0), latestChainInfo.FinalizedBlockNumber)
+		require.Equal(t, int64(0), highestChainInfo.BlockNumber)
+		require.Equal(t, int64(0), highestChainInfo.FinalizedBlockNumber)
+
+		ctx, cancel, lifeCycleCh := rpc.AcquireQueryCtx(tests.Context(t), timeout)
+		defer cancel()
+		rpc.OnNewHead(ctx, lifeCycleCh, &testHead{blockNumber: 10})
+		rpc.OnNewFinalizedHead(ctx, lifeCycleCh, &testHead{blockNumber: 3})
+		rpc.OnNewHead(ctx, lifeCycleCh, &testHead{blockNumber: 5})
+		rpc.OnNewFinalizedHead(ctx, lifeCycleCh, &testHead{blockNumber: 1})
+
+		latestChainInfo, highestChainInfo = rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(5), latestChainInfo.BlockNumber)
+		require.Equal(t, int64(1), latestChainInfo.FinalizedBlockNumber)
+		require.Equal(t, int64(10), highestChainInfo.BlockNumber)
+		require.Equal(t, int64(3), highestChainInfo.FinalizedBlockNumber)
+	})
+
+	t.Run("OnNewHead respects HealthCheckCtx", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		latestChainInfo, highestChainInfo := rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(0), latestChainInfo.BlockNumber)
+		require.Equal(t, int64(0), latestChainInfo.FinalizedBlockNumber)
+		require.Equal(t, int64(0), highestChainInfo.BlockNumber)
+		require.Equal(t, int64(0), highestChainInfo.FinalizedBlockNumber)
+
+		healthCheckCtx := CtxAddHealthCheckFlag(tests.Context(t))
+
+		ctx, cancel, lifeCycleCh := rpc.AcquireQueryCtx(healthCheckCtx, timeout)
+		defer cancel()
+		rpc.OnNewHead(ctx, lifeCycleCh, &testHead{blockNumber: 10})
+		rpc.OnNewFinalizedHead(ctx, lifeCycleCh, &testHead{blockNumber: 3})
+		rpc.OnNewHead(ctx, lifeCycleCh, &testHead{blockNumber: 5})
+		rpc.OnNewFinalizedHead(ctx, lifeCycleCh, &testHead{blockNumber: 1})
+
+		latestChainInfo, highestChainInfo = rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(5), latestChainInfo.BlockNumber)
+		require.Equal(t, int64(1), latestChainInfo.FinalizedBlockNumber)
+
+		// Highest chain info should not be set on health check requests
+		require.Equal(t, int64(0), highestChainInfo.BlockNumber)
+		require.Equal(t, int64(0), highestChainInfo.FinalizedBlockNumber)
+	})
+
+	t.Run("OnNewHead and OnNewFinalizedHead respects closure of requestCh", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		latestChainInfo, highestChainInfo := rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(0), latestChainInfo.BlockNumber)
+		require.Equal(t, int64(0), latestChainInfo.FinalizedBlockNumber)
+		require.Equal(t, int64(0), highestChainInfo.BlockNumber)
+		require.Equal(t, int64(0), highestChainInfo.FinalizedBlockNumber)
+
+		ctx, cancel, lifeCycleCh := rpc.AcquireQueryCtx(tests.Context(t), timeout)
+		defer cancel()
+		rpc.CancelLifeCycle()
+
+		rpc.OnNewHead(ctx, lifeCycleCh, &testHead{blockNumber: 10})
+		rpc.OnNewFinalizedHead(ctx, lifeCycleCh, &testHead{blockNumber: 3})
+		rpc.OnNewHead(ctx, lifeCycleCh, &testHead{blockNumber: 5})
+		rpc.OnNewFinalizedHead(ctx, lifeCycleCh, &testHead{blockNumber: 1})
+
+		// Latest chain info should not be set if life cycle is cancelled
+		latestChainInfo, highestChainInfo = rpc.GetInterceptedChainInfo()
+		require.Equal(t, int64(0), latestChainInfo.BlockNumber)
+		require.Equal(t, int64(0), latestChainInfo.FinalizedBlockNumber)
+
+		require.Equal(t, int64(10), highestChainInfo.BlockNumber)
+		require.Equal(t, int64(3), highestChainInfo.FinalizedBlockNumber)
+	})
+}
+
+func TestAdapter_HeadSubscriptions(t *testing.T) {
+	t.Run("SubscribeToHeads", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		ch, sub, err := rpc.SubscribeToHeads(tests.Context(t))
+		require.NoError(t, err)
+		defer sub.Unsubscribe()
+
+		ctx, cancel := context.WithTimeout(tests.Context(t), time.Minute)
+		defer cancel()
+		select {
+		case head := <-ch:
+			latest, _ := rpc.GetInterceptedChainInfo()
+			require.Equal(t, head.BlockNumber(), latest.BlockNumber)
+		case <-ctx.Done():
+			t.Fatal("failed to receive head: ", ctx.Err())
+		}
+	})
+
+	t.Run("SubscribeToFinalizedHeads", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		finalizedCh, finalizedSub, err := rpc.SubscribeToFinalizedHeads(tests.Context(t))
+		require.NoError(t, err)
+		defer finalizedSub.Unsubscribe()
+
+		ctx, cancel := context.WithTimeout(tests.Context(t), time.Minute)
+		defer cancel()
+		select {
+		case finalizedHead := <-finalizedCh:
+			latest, _ := rpc.GetInterceptedChainInfo()
+			require.Equal(t, finalizedHead.BlockNumber(), latest.FinalizedBlockNumber)
+		case <-ctx.Done():
+			t.Fatal("failed to receive finalized head: ", ctx.Err())
+		}
+	})
+
+	t.Run("Remove Subscription on Unsubscribe", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		_, sub1, err := rpc.SubscribeToHeads(tests.Context(t))
+		require.NoError(t, err)
+		require.Equal(t, 1, rpc.lenSubs())
+		_, sub2, err := rpc.SubscribeToFinalizedHeads(tests.Context(t))
+		require.NoError(t, err)
+		require.Equal(t, 2, rpc.lenSubs())
+
+		sub1.Unsubscribe()
+		require.Equal(t, 1, rpc.lenSubs())
+		sub2.Unsubscribe()
+		require.Equal(t, 0, rpc.lenSubs())
+	})
+
+	t.Run("Ensure no deadlock on UnsubscribeAll", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		_, _, err := rpc.SubscribeToHeads(tests.Context(t))
+		require.NoError(t, err)
+		require.Equal(t, 1, rpc.lenSubs())
+		_, _, err = rpc.SubscribeToFinalizedHeads(tests.Context(t))
+		require.NoError(t, err)
+		require.Equal(t, 2, rpc.lenSubs())
+		rpc.UnsubscribeAllExcept()
+		require.Equal(t, 0, rpc.lenSubs())
+	})
+}
+
+type mockSub struct {
+	unsubscribed bool
+}
+
+func newMockSub() *mockSub {
+	return &mockSub{unsubscribed: false}
+}
+
+func (s *mockSub) Unsubscribe() {
+	s.unsubscribed = true
+}
+func (s *mockSub) Err() <-chan error {
+	return nil
+}
+
+func TestMultiNodeClient_RegisterSubs(t *testing.T) {
+	t.Run("RegisterSub", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		mockSub := newMockSub()
+		sub, err := rpc.RegisterSub(mockSub, make(chan struct{}))
+		require.NoError(t, err)
+		require.NotNil(t, sub)
+		require.Equal(t, 1, rpc.lenSubs())
+		rpc.UnsubscribeAllExcept()
+	})
+
+	t.Run("lifeCycleCh returns error and unsubscribes", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		chStopInFlight := make(chan struct{})
+		close(chStopInFlight)
+		mockSub := newMockSub()
+		_, err := rpc.RegisterSub(mockSub, chStopInFlight)
+		require.Error(t, err)
+		require.True(t, mockSub.unsubscribed)
+	})
+
+	t.Run("UnsubscribeAllExcept", func(t *testing.T) {
+		rpc := newTestRPC(t)
+		chStopInFlight := make(chan struct{})
+		mockSub1 := newMockSub()
+		mockSub2 := newMockSub()
+		sub1, err := rpc.RegisterSub(mockSub1, chStopInFlight)
+		require.NoError(t, err)
+		_, err = rpc.RegisterSub(mockSub2, chStopInFlight)
+		require.NoError(t, err)
+		require.Equal(t, 2, rpc.lenSubs())
+
+		// Ensure passed sub is not removed
+		rpc.UnsubscribeAllExcept(sub1)
+		require.Equal(t, 1, rpc.lenSubs())
+		require.True(t, mockSub2.unsubscribed)
+		require.False(t, mockSub1.unsubscribed)
+
+		rpc.UnsubscribeAllExcept()
+		require.Equal(t, 0, rpc.lenSubs())
+		require.True(t, mockSub1.unsubscribed)
+	})
+}
diff --git a/multinode/types.go b/multinode/types.go
index d26c25e..f70c51c 100644
--- a/multinode/types.go
+++ b/multinode/types.go
@@ -32,6 +32,19 @@ type Subscription interface {
 	Err() <-chan error
 }
 
+// ManagedSubscription is a Subscription which contains an onUnsubscribe callback for cleanup
+type ManagedSubscription struct {
+	Subscription
+	onUnsubscribe func(sub Subscription)
+}
+
+func (w *ManagedSubscription) Unsubscribe() {
+	w.Subscription.Unsubscribe()
+	if w.onUnsubscribe != nil {
+		w.onUnsubscribe(w)
+	}
+}
+
 // RPCClient includes all the necessary generalized RPC methods used by Node to perform health checks
 type RPCClient[
 	CHAIN_ID ID,
@@ -54,8 +67,8 @@ type RPCClient[
 	// Close - closes all subscriptions and aborts all RPC calls
 	Close()
 	// GetInterceptedChainInfo - returns latest and highest observed by application layer ChainInfo.
-	// latest ChainInfo is the most recent value received within a NodeClient's current lifecycle between Dial and DisconnectAll.
-	// highestUserObservations ChainInfo is the highest ChainInfo observed excluding health checks calls.
+	// latestChainInfo is the most recent value received within a NodeClient's current lifecycle between Dial and DisconnectAll.
+	// highestUserObservations is the highest ChainInfo observed excluding health check calls.
 	// Its values must not be reset.
 	// The results of corresponding calls, to get the most recent head and the latest finalized head, must be
 	// intercepted and reflected in ChainInfo before being returned to a caller. Otherwise, MultiNode is not able to
@@ -70,6 +83,7 @@ type RPCClient[
 type Head interface {
 	BlockNumber() int64
 	BlockDifficulty() *big.Int
+	GetTotalDifficulty() *big.Int
 	IsValid() bool
 }
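
Usage sketch (illustrative only, not part of this change): a chain-specific integration embeds RPCClientBase and supplies just the two fetch callbacks, the same pattern rpc_client_base_test.go uses with testRPC. The evmHead/evmRPCClient names, the newEVMRPCClient constructor, and the eth_getBlockByNumber comments below are assumptions made for the example; only the multinode identifiers come from this diff.

package evmclient

import (
	"context"
	"math/big"
	"time"

	"github.com/smartcontractkit/chainlink-common/pkg/logger"

	"github.com/smartcontractkit/chainlink-framework/multinode"
)

// evmHead is a hypothetical chain-specific head satisfying multinode.Head.
type evmHead struct {
	number     int64
	difficulty *big.Int
}

func (h *evmHead) BlockNumber() int64           { return h.number }
func (h *evmHead) BlockDifficulty() *big.Int    { return h.difficulty }
func (h *evmHead) GetTotalDifficulty() *big.Int { return h.difficulty }
func (h *evmHead) IsValid() bool                { return h != nil }

// evmRPCClient embeds RPCClientBase, which supplies SubscribeToHeads,
// SubscribeToFinalizedHeads and GetInterceptedChainInfo for MultiNode health checks.
type evmRPCClient struct {
	*multinode.RPCClientBase[*evmHead]
}

func newEVMRPCClient(cfg multinode.RPCClientBaseConfig, lggr logger.Logger) *evmRPCClient {
	c := &evmRPCClient{}
	// The two callbacks are the only chain-specific pieces RPCClientBase needs:
	// how to fetch the latest head and the latest finalized head.
	c.RPCClientBase = multinode.NewRPCClientBase[*evmHead](
		cfg, 10*time.Second, lggr, c.latestBlock, c.latestFinalizedBlock)
	return c
}

func (c *evmRPCClient) latestBlock(ctx context.Context) (*evmHead, error) {
	// Placeholder: a real integration would issue its chain's "latest block" RPC here.
	return &evmHead{number: 1, difficulty: big.NewInt(1)}, nil
}

func (c *evmRPCClient) latestFinalizedBlock(ctx context.Context) (*evmHead, error) {
	// Placeholder: a real integration would issue its chain's "finalized block" RPC here.
	return &evmHead{number: 1, difficulty: big.NewInt(1)}, nil
}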