diff --git a/README.md b/README.md
index ce9c227..0d1dee8 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,9 @@ cl, err := hasql.NewCluster(
 if err != nil { ... }
 
 node := cl.Primary()
-if node == nil { ... }
+if node == nil {
+	err := cl.Err() // most recent errors for all nodes in the cluster
+}
 
 // Do anything you like
 fmt.Println("Node address", node.Addr)
diff --git a/check_nodes.go b/check_nodes.go
index b1c4357..22fe219 100644
--- a/check_nodes.go
+++ b/check_nodes.go
@@ -93,7 +93,7 @@ type checkExecutorFunc func(ctx context.Context, node Node) (bool, time.Duration
 
 // checkNodes takes slice of nodes, checks them in parallel and returns the alive ones.
 // Accepts customizable executor which enables time-independent tests for node sorting based on 'latency'.
-func checkNodes(ctx context.Context, nodes []Node, executor checkExecutorFunc, tracer Tracer) AliveNodes {
+func checkNodes(ctx context.Context, nodes []Node, executor checkExecutorFunc, tracer Tracer, errCollector *errorsCollector) AliveNodes {
 	checkedNodes := groupedCheckedNodes{
 		Primaries: make(checkedNodesList, 0, len(nodes)),
 		Standbys:  make(checkedNodesList, 0, len(nodes)),
@@ -111,9 +111,14 @@ func checkNodes(ctx context.Context, nodes []Node, executor checkExecutorFunc, t
 			if tracer.NodeDead != nil {
 				tracer.NodeDead(node, err)
 			}
-
+			if errCollector != nil {
+				errCollector.Add(node.Addr(), err, time.Now())
+			}
 			return
 		}
+		if errCollector != nil {
+			errCollector.Remove(node.Addr())
+		}
 
 		if tracer.NodeAlive != nil {
 			tracer.NodeAlive(node)
diff --git a/check_nodes_test.go b/check_nodes_test.go
index 1fec932..0293f24 100644
--- a/check_nodes_test.go
+++ b/check_nodes_test.go
@@ -145,8 +145,60 @@ func TestCheckNodes(t *testing.T) {
 		return false, 0, errors.New("node not found")
 	}
-	alive := checkNodes(context.Background(), nodes, executor, Tracer{})
+	errCollector := newErrorsCollector()
+	alive := checkNodes(context.Background(), nodes, executor, Tracer{}, &errCollector)
+
+	assert.NoError(t, errCollector.Err())
 
 	assert.Equal(t, expected.Primaries, alive.Primaries)
 	assert.Equal(t, expected.Standbys, alive.Standbys)
 	assert.Equal(t, expected.Alive, alive.Alive)
 }
+
+func TestCheckNodesWithErrors(t *testing.T) {
+	const count = 5
+	var nodes []Node
+	for i := 0; i < count; i++ {
+		db, _, err := sqlmock.New()
+		require.NoError(t, err)
+		require.NotNil(t, db)
+		nodes = append(nodes, NewNode(uuid.Must(uuid.NewV4()).String(), db))
+	}
+
+	executor := func(ctx context.Context, node Node) (bool, time.Duration, error) {
+		return false, 0, errors.New("node not found")
+	}
+
+	errCollector := newErrorsCollector()
+	checkNodes(context.Background(), nodes, executor, Tracer{}, &errCollector)
+
+	err := errCollector.Err()
+	for i := 0; i < count; i++ {
+		assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", nodes[i].Addr()))
+	}
+	assert.ErrorContains(t, err, "node not found")
+}
+
+func TestCheckNodesWithErrorsWhenNodesBecameAlive(t *testing.T) {
+	const count = 5
+	var nodes []Node
+	for i := 0; i < count; i++ {
+		db, _, err := sqlmock.New()
+		require.NoError(t, err)
+		require.NotNil(t, db)
+		nodes = append(nodes, NewNode(uuid.Must(uuid.NewV4()).String(), db))
+	}
+
+	executor := func(ctx context.Context, node Node) (bool, time.Duration, error) {
+		return false, 0, errors.New("node not found")
+	}
+
+	errCollector := newErrorsCollector()
+	checkNodes(context.Background(), nodes, executor, Tracer{}, &errCollector)
+	require.Error(t, errCollector.Err())
+
+	executor = func(ctx context.Context, node Node) (bool, time.Duration, error) {
+		return true, 1, nil
+	}
+	checkNodes(context.Background(), nodes, executor, Tracer{}, &errCollector)
+	require.NoError(t, errCollector.Err())
+}
diff --git a/cluster.go b/cluster.go
index b7ad40a..3b4e0ae 100644
--- a/cluster.go
+++ b/cluster.go
@@ -58,6 +58,7 @@ type Cluster struct {
 	updateStopper chan struct{}
 	aliveNodes    atomic.Value
 	nodes         []Node
+	errCollector  errorsCollector
 
 	// Notification
 	muWaiters sync.Mutex
@@ -89,6 +90,7 @@ func NewCluster(nodes []Node, checker NodeChecker, opts ...ClusterOption) (*Clus
 		checker:       checker,
 		picker:        PickNodeRandom(),
 		nodes:         nodes,
+		errCollector:  newErrorsCollector(),
 	}
 
 	// Apply options
@@ -291,6 +293,12 @@ func (cl *Cluster) node(nodes AliveNodes, criteria NodeStateCriteria) Node {
 	}
 }
 
+// Err returns the combined error that includes the most recent errors for all nodes.
+// The returned error is either *CollectedErrors or nil.
+func (cl *Cluster) Err() error {
+	return cl.errCollector.Err()
+}
+
 // backgroundNodesUpdate periodically updates list of live db nodes
 func (cl *Cluster) backgroundNodesUpdate() {
 	// Initial update
@@ -318,7 +326,7 @@ func (cl *Cluster) updateNodes() {
 	ctx, cancel := context.WithTimeout(context.Background(), cl.updateTimeout)
 	defer cancel()
 
-	alive := checkNodes(ctx, cl.nodes, checkExecutor(cl.checker), cl.tracer)
+	alive := checkNodes(ctx, cl.nodes, checkExecutor(cl.checker), cl.tracer, &cl.errCollector)
 	cl.aliveNodes.Store(alive)
 
 	if cl.tracer.UpdatedNodes != nil {
diff --git a/cluster_test.go b/cluster_test.go
index 10c22eb..923f166 100644
--- a/cluster_test.go
+++ b/cluster_test.go
@@ -543,6 +543,76 @@ func TestCluster_WaitForStandbyPreferred(t *testing.T) {
 	}
 }
 
+func TestCluster_Err(t *testing.T) {
+	inputs := []struct {
+		Name    string
+		Fixture *fixture
+		Test    func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster)
+	}{
+		{
+			Name:    "AllAlive",
+			Fixture: newFixture(t, 2),
+			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
+				f.Nodes[0].setStatus(nodeStatusStandby)
+				f.Nodes[1].setStatus(nodeStatusPrimary)
+				waitForNode(t, o, cl.WaitForPrimary, f.Nodes[1].Node)
+
+				require.NoError(t, cl.Err())
+			},
+		},
+		{
+			Name:    "AllDead",
+			Fixture: newFixture(t, 2),
+			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
+				waitForNode(t, o, cl.WaitForPrimary, nil)
+
+				err := cl.Err()
+				require.Error(t, err)
+				assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", f.Nodes[0].Node.Addr()))
+				assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", f.Nodes[1].Node.Addr()))
+			},
+		},
+		{
+			Name:    "PrimaryAliveOtherDead",
+			Fixture: newFixture(t, 2),
+			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
+				f.Nodes[1].setStatus(nodeStatusPrimary)
+				waitForNode(t, o, cl.WaitForPrimary, f.Nodes[1].Node)
+
+				err := cl.Err()
+				require.Error(t, err)
+				assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", f.Nodes[0].Node.Addr()))
+				assert.NotContains(t, err.Error(), fmt.Sprintf("%q node error occurred at", f.Nodes[1].Node.Addr()))
+			},
+		},
+		{
+			Name:    "PrimaryDeadOtherAlive",
+			Fixture: newFixture(t, 2),
+			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
+				f.Nodes[0].setStatus(nodeStatusStandby)
+				waitForNode(t, o, cl.WaitForPrimary, nil)
+
+				err := cl.Err()
+				require.Error(t, err)
+				assert.NotContains(t, err.Error(), fmt.Sprintf("%q node error occurred at", f.Nodes[0].Node.Addr()))
+				assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", f.Nodes[1].Node.Addr()))
+			},
+		},
+	}
+
+	for _, input := range inputs {
+		t.Run(input.Name, func(t *testing.T) {
+			defer input.Fixture.AssertExpectations(t)
+
+			var o nodeUpdateObserver
+			cl := setupCluster(t, input.Fixture, o.Tracer())
+			defer func() { require.NoError(t, cl.Close()) }()
+
+			input.Test(t, input.Fixture, &o, cl)
+		})
+	}
+}
+
 type nodeStatus int64
 
 const (
diff --git a/errors_collector.go b/errors_collector.go
new file mode 100644
index 0000000..def511e
--- /dev/null
+++ b/errors_collector.go
@@ -0,0 +1,105 @@
+/*
+   Copyright 2020 YANDEX LLC
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package hasql
+
+import (
+	"fmt"
+	"sort"
+	"strings"
+	"sync"
+	"time"
+)
+
+// CollectedErrors are errors collected when checking node statuses
+type CollectedErrors struct {
+	Errors []NodeError
+}
+
+func (e *CollectedErrors) Error() string {
+	if len(e.Errors) == 1 {
+		return e.Errors[0].Error()
+	}
+
+	errs := make([]string, len(e.Errors))
+	for i, ne := range e.Errors {
+		errs[i] = ne.Error()
+	}
+	/*
+		There is no single 'best' join separator that fits all cases (CLI output, JSON, etc.),
+		so we join with a newline, just as errors.Join does.
+		In more complex cases (as suggested in https://github.com/yandex/go-hasql/pull/14),
+		the user should be able to receive the "raw" errors and format them as needed.
+	*/
+	return strings.Join(errs, "\n")
+}
+
+// NodeError is an error that a background goroutine encountered while checking the given node
+type NodeError struct {
+	Addr       string
+	Err        error
+	OccurredAt time.Time
+}
+
+func (e *NodeError) Error() string {
+	// 'foo.db' node error occurred at '2009-11-10..': FATAL: terminating connection due to ...
+	return fmt.Sprintf("%q node error occurred at %q: %s", e.Addr, e.OccurredAt, e.Err)
+}
+
+type errorsCollector struct {
+	store map[string]NodeError
+	mu    sync.Mutex
+}
+
+func newErrorsCollector() errorsCollector {
+	return errorsCollector{store: make(map[string]NodeError)}
+}
+
+func (e *errorsCollector) Add(addr string, err error, occurredAt time.Time) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	e.store[addr] = NodeError{
+		Addr:       addr,
+		Err:        err,
+		OccurredAt: occurredAt,
+	}
+}
+
+func (e *errorsCollector) Remove(addr string) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	delete(e.store, addr)
+}
+
+func (e *errorsCollector) Err() error {
+	e.mu.Lock()
+	errList := make([]NodeError, 0, len(e.store))
+	for _, nErr := range e.store {
+		errList = append(errList, nErr)
+	}
+	e.mu.Unlock()
+
+	if len(errList) == 0 {
+		return nil
+	}
+
+	sort.Slice(errList, func(i, j int) bool {
+		return errList[i].OccurredAt.Before(errList[j].OccurredAt)
+	})
+	return &CollectedErrors{Errors: errList}
+}
diff --git a/errors_collector_test.go b/errors_collector_test.go
new file mode 100644
index 0000000..4124e6a
--- /dev/null
+++ b/errors_collector_test.go
@@ -0,0 +1,74 @@
+/*
+   Copyright 2020 YANDEX LLC
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package hasql
+
+import (
+	"errors"
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestErrorsCollector(t *testing.T) {
+	nodesCount := 10
+	errCollector := newErrorsCollector()
+	require.NoError(t, errCollector.Err())
+
+	connErr := errors.New("node connection error")
+	occurredAt := time.Now()
+
+	var wg sync.WaitGroup
+	wg.Add(nodesCount)
+	for i := 1; i <= nodesCount; i++ {
+		go func(i int) {
+			defer wg.Done()
+			errCollector.Add(
+				fmt.Sprintf("node-%d", i),
+				connErr,
+				occurredAt,
+			)
+		}(i)
+	}
+
+	errCollectDone := make(chan struct{})
+	go func() {
+		for {
+			select {
+			case <-errCollectDone:
+				return
+			default:
+				// There are no assertions here because this logic is expected to run with -race;
+				// otherwise it doesn't test anything and just burns CPU.
+				_ = errCollector.Err()
+			}
+		}
+	}()
+
+	wg.Wait()
+	close(errCollectDone)
+
+	err := errCollector.Err()
+	for i := 1; i <= nodesCount; i++ {
+		assert.ErrorContains(t, err, fmt.Sprintf("\"node-%d\" node error occurred at", i))
+	}
+	assert.ErrorContains(t, err, connErr.Error())
+
+}
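
Editor's note, not part of the patch above: a minimal sketch of how calling code might consume the API introduced by this diff. Cluster.Err, CollectedErrors and NodeError come from the change itself; the package name, the logNodeErrors helper and the golang.yandex/hasql import path are illustrative assumptions, not part of the diff.

package example

import (
	"errors"
	"log"

	"golang.yandex/hasql"
)

// logNodeErrors is a hypothetical helper: when no primary is available, it
// unwraps the error returned by Cluster.Err into the individual per-node
// errors collected by the background health checks.
func logNodeErrors(cl *hasql.Cluster) {
	var collected *hasql.CollectedErrors
	if errors.As(cl.Err(), &collected) {
		for _, ne := range collected.Errors {
			// Each NodeError keeps the node address, the last error and the time it occurred.
			log.Printf("node %s failed at %s: %v", ne.Addr, ne.OccurredAt, ne.Err)
		}
	}
}

errors.As is used instead of a plain type assertion so the sketch keeps working even if the returned error is wrapped by caller code.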