diff --git a/check_nodes.go b/check_nodes.go index 285eacd..a501434 100644 --- a/check_nodes.go +++ b/check_nodes.go @@ -87,7 +87,11 @@ func (nodes groupedCheckedNodes) Alive() []Node { return res } -func checkNodes(ctx context.Context, nodes []Node, checker NodeChecker, tracer Tracer) AliveNodes { +type checkExecutorFunc func(ctx context.Context, node Node) (bool, time.Duration, error) + +// checkNodes takes slice of nodes, checks them in parallel and returns the alive ones. +// Accepts customizable executor which enables time-independent tests for node sorting based on 'latency'. +func checkNodes(ctx context.Context, nodes []Node, executor checkExecutorFunc, tracer Tracer) AliveNodes { checkedNodes := groupedCheckedNodes{ Primaries: make(checkedNodesList, 0, len(nodes)), Standbys: make(checkedNodesList, 0, len(nodes)), @@ -100,9 +104,7 @@ func checkNodes(ctx context.Context, nodes []Node, checker NodeChecker, tracer T go func(node Node, wg *sync.WaitGroup) { defer wg.Done() - ts := time.Now() - primary, err := checkNode(ctx, node, checker) - d := time.Since(ts) + primary, duration, err := executor(ctx, node) if err != nil { if tracer.NodeDead != nil { tracer.NodeDead(node, err) @@ -115,7 +117,7 @@ func checkNodes(ctx context.Context, nodes []Node, checker NodeChecker, tracer T tracer.NodeAlive(node) } - nl := checkedNode{Node: node, Latency: d} + nl := checkedNode{Node: node, Latency: duration} mu.Lock() defer mu.Unlock() @@ -137,3 +139,17 @@ func checkNodes(ctx context.Context, nodes []Node, checker NodeChecker, tracer T Standbys: checkedNodes.Standbys.Nodes(), } } + +// checkExecutor returns checkExecutorFunc which can execute supplied check. +func checkExecutor(checker NodeChecker) checkExecutorFunc { + return func(ctx context.Context, node Node) (bool, time.Duration, error) { + ts := time.Now() + primary, err := checker(ctx, node.DB()) + d := time.Since(ts) + if err != nil { + return false, d, err + } + + return primary, d, nil + } +} diff --git a/check_nodes_test.go b/check_nodes_test.go index 4f3ce9c..bcda000 100644 --- a/check_nodes_test.go +++ b/check_nodes_test.go @@ -18,7 +18,6 @@ package hasql import ( "context" - "database/sql" "errors" "math/rand" "testing" @@ -31,7 +30,7 @@ import ( ) func TestCheckNodes(t *testing.T) { - const count = 10 + const count = 100 var nodes []Node expected := AliveNodes{Alive: make([]Node, count)} for i := 0; i < count; i++ { @@ -56,8 +55,8 @@ func TestCheckNodes(t *testing.T) { require.Len(t, expected.Alive, count) // Fill primaries and standbys - for _, node := range expected.Alive { - if rand.Intn(2) == 0 { + for i, node := range expected.Alive { + if i%2 == 0 { expected.Primaries = append(expected.Primaries, node) } else { expected.Standbys = append(expected.Standbys, node) @@ -68,30 +67,33 @@ func TestCheckNodes(t *testing.T) { require.NotEmpty(t, expected.Standbys) require.Equal(t, count, len(expected.Primaries)+len(expected.Standbys)) - checker := func(_ context.Context, db *sql.DB) (bool, error) { - for i, node := range expected.Alive { - if node.DB() == db { - // TODO: make test time-independent - time.Sleep(100 * time.Duration(i) * time.Millisecond) + executor := func(ctx context.Context, node Node) (bool, time.Duration, error) { + // Alive nodes set the expected 'order' (latency) of all available nodes. + // Return duration based on that order. + var duration time.Duration + for i, alive := range expected.Alive { + if alive == node { + duration = time.Duration(i) * time.Nanosecond + break } } - for _, node := range expected.Primaries { - if node.DB() == db { - return true, nil + for _, primary := range expected.Primaries { + if primary == node { + return true, duration, nil } } - for _, node := range expected.Standbys { - if node.DB() == db { - return false, nil + for _, standby := range expected.Standbys { + if standby == node { + return false, duration, nil } } - return false, errors.New("node not found") + return false, 0, errors.New("node not found") } - alive := checkNodes(context.Background(), nodes, checker, Tracer{}) + alive := checkNodes(context.Background(), nodes, executor, Tracer{}) assert.Equal(t, expected.Primaries, alive.Primaries) assert.Equal(t, expected.Standbys, alive.Standbys) assert.Equal(t, expected.Alive, alive.Alive) diff --git a/cluster.go b/cluster.go index f45753b..b7ad40a 100644 --- a/cluster.go +++ b/cluster.go @@ -318,7 +318,7 @@ func (cl *Cluster) updateNodes() { ctx, cancel := context.WithTimeout(context.Background(), cl.updateTimeout) defer cancel() - alive := checkNodes(ctx, cl.nodes, cl.checker, cl.tracer) + alive := checkNodes(ctx, cl.nodes, checkExecutor(cl.checker), cl.tracer) cl.aliveNodes.Store(alive) if cl.tracer.UpdatedNodes != nil { diff --git a/node.go b/node.go index 35822c7..a7b02ae 100644 --- a/node.go +++ b/node.go @@ -57,16 +57,6 @@ func (n *sqlNode) String() string { return n.addr } -// checkNode checks if the node is alive and whether it is primary or not -func checkNode(ctx context.Context, node Node, checker NodeChecker) (bool, error) { - primary, err := checker(ctx, node.DB()) - if err != nil { - return false, err - } - - return primary, nil -} - // NodeStateCriteria for choosing a node type NodeStateCriteria int