Skip to content

Commit

Permalink
common/client: convert MultiNode to use *services.Engine
Browse files Browse the repository at this point in the history
  • Loading branch information
jmank88 committed Nov 14, 2024
1 parent c015da8 commit 5f4c1e6
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 142 deletions.
123 changes: 52 additions & 71 deletions common/client/multi_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package client

import (
"context"
"errors"
"fmt"
"math/big"
"sync"
Expand Down Expand Up @@ -32,7 +31,9 @@ type MultiNode[
CHAIN_ID types.ID,
RPC any,
] struct {
services.StateMachine
services.Service
eng *services.Engine

primaryNodes []Node[CHAIN_ID, RPC]
sendOnlyNodes []SendOnlyNode[CHAIN_ID, RPC]
chainID CHAIN_ID
Expand All @@ -47,9 +48,6 @@ type MultiNode[

activeMu sync.RWMutex
activeNode Node[CHAIN_ID, RPC]

chStop services.StopChan
wg sync.WaitGroup
}

func NewMultiNode[
Expand All @@ -73,15 +71,19 @@ func NewMultiNode[
primaryNodes: primaryNodes,
sendOnlyNodes: sendOnlyNodes,
chainID: chainID,
lggr: logger.Sugared(lggr).Named("MultiNode").With("chainID", chainID.String()),
selectionMode: selectionMode,
nodeSelector: nodeSelector,
chStop: make(services.StopChan),
leaseDuration: leaseDuration,
chainFamily: chainFamily,
reportInterval: reportInterval,
deathDeclarationDelay: deathDeclarationDelay,
}
c.Service, c.eng = services.Config{
Name: "MultiNode",
Start: c.start,
Close: c.close,
}.NewServiceEngine(logger.With(lggr, "chainID", chainID.String()))
c.lggr = c.eng.SugaredLogger

c.lggr.Debugf("The MultiNode is configured to use NodeSelectionMode: %s", selectionMode)

Expand All @@ -93,16 +95,12 @@ func (c *MultiNode[CHAIN_ID, RPC]) ChainID() CHAIN_ID {
}

func (c *MultiNode[CHAIN_ID, RPC]) DoAll(ctx context.Context, do func(ctx context.Context, rpc RPC, isSendOnly bool)) error {
var err error
ok := c.IfNotStopped(func() {
ctx, _ = c.chStop.Ctx(ctx)

return c.eng.IfNotStopped(func() error {
callsCompleted := 0
for _, n := range c.primaryNodes {
select {
case <-ctx.Done():
err = ctx.Err()
return
return ctx.Err()
default:
if n.State() != nodeStateAlive {
continue
Expand All @@ -111,27 +109,23 @@ func (c *MultiNode[CHAIN_ID, RPC]) DoAll(ctx context.Context, do func(ctx contex
callsCompleted++
}
}
if callsCompleted == 0 {
err = ErroringNodeError
}

for _, n := range c.sendOnlyNodes {
select {
case <-ctx.Done():
err = ctx.Err()
return
return ctx.Err()
default:
if n.State() != nodeStateAlive {
continue
}
do(ctx, n.RPC(), true)
}
}
if callsCompleted == 0 {
return ErroringNodeError
}
return nil
})
if !ok {
return errors.New("MultiNode is stopped")
}
return err
}

func (c *MultiNode[CHAIN_ID, RPC]) NodeStates() map[string]string {
Expand All @@ -149,53 +143,44 @@ func (c *MultiNode[CHAIN_ID, RPC]) NodeStates() map[string]string {
//
// Nodes handle their own redialing and runloops, so this function does not
// return any error if the nodes aren't available
func (c *MultiNode[CHAIN_ID, RPC]) Start(ctx context.Context) error {
return c.StartOnce("MultiNode", func() (merr error) {
if len(c.primaryNodes) == 0 {
return fmt.Errorf("no available nodes for chain %s", c.chainID.String())
func (c *MultiNode[CHAIN_ID, RPC]) start(ctx context.Context) error {
if len(c.primaryNodes) == 0 {
return fmt.Errorf("no available nodes for chain %s", c.chainID.String())
}
var ms services.MultiStart
for _, n := range c.primaryNodes {
if n.ConfiguredChainID().String() != c.chainID.String() {
return ms.CloseBecause(fmt.Errorf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", n.String(), n.ConfiguredChainID().String(), c.chainID.String()))
}
var ms services.MultiStart
for _, n := range c.primaryNodes {
if n.ConfiguredChainID().String() != c.chainID.String() {
return ms.CloseBecause(fmt.Errorf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", n.String(), n.ConfiguredChainID().String(), c.chainID.String()))
}
n.SetPoolChainInfoProvider(c)
// node will handle its own redialing and automatic recovery
if err := ms.Start(ctx, n); err != nil {
return err
}
n.SetPoolChainInfoProvider(c)
// node will handle its own redialing and automatic recovery
if err := ms.Start(ctx, n); err != nil {
return err
}
for _, s := range c.sendOnlyNodes {
if s.ConfiguredChainID().String() != c.chainID.String() {
return ms.CloseBecause(fmt.Errorf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", s.String(), s.ConfiguredChainID().String(), c.chainID.String()))
}
if err := ms.Start(ctx, s); err != nil {
return err
}
}
for _, s := range c.sendOnlyNodes {
if s.ConfiguredChainID().String() != c.chainID.String() {
return ms.CloseBecause(fmt.Errorf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", s.String(), s.ConfiguredChainID().String(), c.chainID.String()))
}
c.wg.Add(1)
go c.runLoop()

if c.leaseDuration.Seconds() > 0 && c.selectionMode != NodeSelectionModeRoundRobin {
c.lggr.Infof("The MultiNode will switch to best node every %s", c.leaseDuration.String())
c.wg.Add(1)
go c.checkLeaseLoop()
} else {
c.lggr.Info("Best node switching is disabled")
if err := ms.Start(ctx, s); err != nil {
return err
}
}
c.eng.Go(c.runLoop)

return nil
})
if c.leaseDuration.Seconds() > 0 && c.selectionMode != NodeSelectionModeRoundRobin {
c.lggr.Infof("The MultiNode will switch to best node every %s", c.leaseDuration.String())
c.eng.Go(c.checkLeaseLoop)
} else {
c.lggr.Info("Best node switching is disabled")
}

return nil
}

// Close tears down the MultiNode and closes all nodes
func (c *MultiNode[CHAIN_ID, RPC]) Close() error {
return c.StopOnce("MultiNode", func() error {
close(c.chStop)
c.wg.Wait()

return services.CloseAll(services.MultiCloser(c.primaryNodes), services.MultiCloser(c.sendOnlyNodes))
})
func (c *MultiNode[CHAIN_ID, RPC]) close() error {
return services.CloseAll(services.MultiCloser(c.primaryNodes), services.MultiCloser(c.sendOnlyNodes))
}

// SelectRPC returns an RPC of an active node. If there are no active nodes it returns an error.
Expand Down Expand Up @@ -233,8 +218,7 @@ func (c *MultiNode[CHAIN_ID, RPC]) selectNode() (node Node[CHAIN_ID, RPC], err e
c.activeNode = c.nodeSelector.Select()
if c.activeNode == nil {
c.lggr.Criticalw("No live RPC nodes available", "NodeSelectionMode", c.nodeSelector.Name())
errmsg := fmt.Errorf("no live nodes available for chain %s", c.chainID.String())
c.SvcErrBuffer.Append(errmsg)
c.eng.EmitHealthErr(fmt.Errorf("no live nodes available for chain %s", c.chainID.String()))
return nil, ErroringNodeError
}

Expand Down Expand Up @@ -296,24 +280,21 @@ func (c *MultiNode[CHAIN_ID, RPC]) checkLease() {
}
}

func (c *MultiNode[CHAIN_ID, RPC]) checkLeaseLoop() {
defer c.wg.Done()
func (c *MultiNode[CHAIN_ID, RPC]) checkLeaseLoop(ctx context.Context) {
c.leaseTicker = time.NewTicker(c.leaseDuration)
defer c.leaseTicker.Stop()

for {
select {
case <-c.leaseTicker.C:
c.checkLease()
case <-c.chStop:
case <-ctx.Done():
return
}
}
}

func (c *MultiNode[CHAIN_ID, RPC]) runLoop() {
defer c.wg.Done()

func (c *MultiNode[CHAIN_ID, RPC]) runLoop(ctx context.Context) {
nodeStates := make([]nodeWithState, len(c.primaryNodes))
for i, n := range c.primaryNodes {
nodeStates[i] = nodeWithState{
Expand All @@ -332,7 +313,7 @@ func (c *MultiNode[CHAIN_ID, RPC]) runLoop() {
select {
case <-monitor.C:
c.report(nodeStates)
case <-c.chStop:
case <-ctx.Done():
return
}
}
Expand Down Expand Up @@ -376,7 +357,7 @@ func (c *MultiNode[CHAIN_ID, RPC]) report(nodesStateInfo []nodeWithState) {
if total == dead {
rerr := fmt.Errorf("no primary nodes available: 0/%d nodes are alive", total)
c.lggr.Criticalw(rerr.Error(), "nodeStates", nodesStateInfo)
c.SvcErrBuffer.Append(rerr)
c.eng.EmitHealthErr(rerr)
} else if dead > 0 {
c.lggr.Errorw(fmt.Sprintf("At least one primary node is dead: %d/%d nodes are alive", live, total), "nodeStates", nodesStateInfo)
}
Expand Down
39 changes: 14 additions & 25 deletions common/client/multi_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/zap"

"github.com/smartcontractkit/chainlink-common/pkg/services/servicetest"
"github.com/smartcontractkit/chainlink-common/pkg/utils/tests"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
Expand Down Expand Up @@ -76,7 +77,7 @@ func TestMultiNode_Dial(t *testing.T) {
chainID: types.RandomID(),
})
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, fmt.Sprintf("no available nodes for chain %s", mn.chainID.String()))
assert.ErrorContains(t, err, fmt.Sprintf("no available nodes for chain %s", mn.chainID))
})
t.Run("Fails with wrong node's chainID", func(t *testing.T) {
t.Parallel()
Expand All @@ -92,7 +93,7 @@ func TestMultiNode_Dial(t *testing.T) {
nodes: []Node[types.ID, multiNodeRPCClient]{node},
})
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, fmt.Sprintf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", nodeName, nodeChainID, mn.chainID))
assert.ErrorContains(t, err, fmt.Sprintf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", nodeName, nodeChainID, mn.chainID))
})
t.Run("Fails if node fails", func(t *testing.T) {
t.Parallel()
Expand All @@ -108,7 +109,7 @@ func TestMultiNode_Dial(t *testing.T) {
nodes: []Node[types.ID, multiNodeRPCClient]{node},
})
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, expectedError.Error())
assert.ErrorIs(t, err, expectedError)
})

t.Run("Closes started nodes on failure", func(t *testing.T) {
Expand All @@ -127,7 +128,7 @@ func TestMultiNode_Dial(t *testing.T) {
nodes: []Node[types.ID, multiNodeRPCClient]{node1, node2},
})
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, expectedError.Error())
assert.ErrorIs(t, err, expectedError)
})
t.Run("Fails with wrong send only node's chainID", func(t *testing.T) {
t.Parallel()
Expand All @@ -146,7 +147,7 @@ func TestMultiNode_Dial(t *testing.T) {
sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{sendOnly},
})
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, fmt.Sprintf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", sendOnlyName, sendOnlyChainID, mn.chainID))
assert.ErrorContains(t, err, fmt.Sprintf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", sendOnlyName, sendOnlyChainID, mn.chainID))
})

newHealthySendOnly := func(t *testing.T, chainID types.ID) *mockSendOnlyNode[types.ID, multiNodeRPCClient] {
Expand All @@ -173,7 +174,7 @@ func TestMultiNode_Dial(t *testing.T) {
sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{sendOnly1, sendOnly2},
})
err := mn.Start(tests.Context(t))
assert.EqualError(t, err, expectedError.Error())
assert.ErrorIs(t, err, expectedError)
})
t.Run("Starts successfully with healthy nodes", func(t *testing.T) {
t.Parallel()
Expand All @@ -185,9 +186,7 @@ func TestMultiNode_Dial(t *testing.T) {
nodes: []Node[types.ID, multiNodeRPCClient]{node},
sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{newHealthySendOnly(t, chainID)},
})
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Start(tests.Context(t))
require.NoError(t, err)
servicetest.Run(t, mn)
selectedNode, err := mn.selectNode()
require.NoError(t, err)
assert.Equal(t, node, selectedNode)
Expand All @@ -210,9 +209,7 @@ func TestMultiNode_Report(t *testing.T) {
})
mn.reportInterval = tests.TestInterval
mn.deathDeclarationDelay = tests.TestInterval
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Start(tests.Context(t))
require.NoError(t, err)
servicetest.Run(t, mn)
tests.AssertLogCountEventually(t, observedLogs, "At least one primary node is dead: 1/2 nodes are alive", 2)
})
t.Run("Report critical error on all node failure", func(t *testing.T) {
Expand All @@ -228,11 +225,9 @@ func TestMultiNode_Report(t *testing.T) {
})
mn.reportInterval = tests.TestInterval
mn.deathDeclarationDelay = tests.TestInterval
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Start(tests.Context(t))
require.NoError(t, err)
servicetest.Run(t, mn)
tests.AssertLogCountEventually(t, observedLogs, "no primary nodes available: 0/1 nodes are alive", 2)
err = mn.Healthy()
err := mn.HealthReport()["MultiNode"]
require.Error(t, err)
assert.Contains(t, err.Error(), "no primary nodes available: 0/1 nodes are alive")
})
Expand All @@ -251,9 +246,7 @@ func TestMultiNode_CheckLease(t *testing.T) {
logger: lggr,
nodes: []Node[types.ID, multiNodeRPCClient]{node},
})
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Start(tests.Context(t))
require.NoError(t, err)
servicetest.Run(t, mn)
tests.RequireLogMessage(t, observedLogs, "Best node switching is disabled")
})
t.Run("Misconfigured lease check period won't start", func(t *testing.T) {
Expand All @@ -268,9 +261,7 @@ func TestMultiNode_CheckLease(t *testing.T) {
nodes: []Node[types.ID, multiNodeRPCClient]{node},
leaseDuration: 0,
})
defer func() { assert.NoError(t, mn.Close()) }()
err := mn.Start(tests.Context(t))
require.NoError(t, err)
servicetest.Run(t, mn)
tests.RequireLogMessage(t, observedLogs, "Best node switching is disabled")
})
t.Run("Lease check updates active node", func(t *testing.T) {
Expand All @@ -289,10 +280,8 @@ func TestMultiNode_CheckLease(t *testing.T) {
nodes: []Node[types.ID, multiNodeRPCClient]{node, bestNode},
leaseDuration: tests.TestInterval,
})
defer func() { assert.NoError(t, mn.Close()) }()
mn.nodeSelector = nodeSelector
err := mn.Start(tests.Context(t))
require.NoError(t, err)
servicetest.Run(t, mn)
tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("Switching to best node from %q to %q", node.String(), bestNode.String()))
tests.AssertEventually(t, func() bool {
mn.activeMu.RLock()
Expand Down
2 changes: 1 addition & 1 deletion common/client/node_lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) {
rpc.On("SubscribeToHeads", mock.Anything).Return(make(<-chan Head), sub, nil).Once()
expectedError := errors.New("failed to subscribe to finalized heads")
rpc.On("SubscribeToFinalizedHeads", mock.Anything).Return(nil, sub, expectedError).Once()
lggr, _ := logger.TestObserved(t, zap.DebugLevel)
lggr := logger.Test(t)
node := newDialedNode(t, testNodeOpts{
config: testNodeConfig{
finalizedBlockPollInterval: tests.TestInterval,
Expand Down
Loading

0 comments on commit 5f4c1e6

Please sign in to comment.