diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6f0bb86..7dd4601 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,6 +14,6 @@ jobs: steps: - uses: actions/checkout@v2 - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: - version: v1.52.2 + version: v1.62.0 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 02347d4..d9ea1f7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,7 +11,7 @@ jobs: test: strategy: matrix: - go-version: [1.19.x, 1.20.x] + go-version: [1.22.x, 1.23.x] os: [ubuntu-latest, macos-latest, windows-latest] env: OS: ${{ matrix.os }} @@ -19,16 +19,16 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Test run: go test -race ./... -coverprofile=coverage.txt -covermode=atomic - name: Upload coverage to Codecov uses: codecov/codecov-action@v1 - if: matrix.os == 'ubuntu-latest' && matrix.go-version == '1.20.x' + if: matrix.os == 'ubuntu-latest' && matrix.go-version == '1.23.x' with: file: ./coverage.txt flags: unittests diff --git a/.golangci.yml b/.golangci.yml index d761778..4f7691a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -100,25 +100,20 @@ linters-settings: linters: disable-all: true enable: - - deadcode - - depguard - errcheck - goconst - gofmt # On why gofmt when goimports is enabled - https://github.com/golang/go/issues/21476 - goimports - - golint - gosimple - govet - ineffassign - - maligned - misspell + - revive - staticcheck - - structcheck - typecheck - unconvert - unparam - unused - - varcheck issues: # List of regexps of issue texts to exclude, empty list by default. diff --git a/README.md b/README.md index 0d1dee8..464c42f 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ With [Go module](https://github.com/golang/go/wiki/Modules) support, simply add the following import ```go -import "golang.yandex/hasql" +import "golang.yandex/hasql/v2" ``` to your code, and then `go [build|run|test]` will automatically fetch the @@ -30,19 +30,22 @@ necessary dependencies. Otherwise, to install the `hasql` package, run the following command: ```console -$ go get -u golang.yandex/hasql +$ go get -u golang.yandex/hasql/v2 ``` ## How does it work -`hasql` operates using standard `database/sql` connection pool objects. User creates `*sql.DB` objects for each node of database cluster and passes them to constructor. Library keeps up to date information on state of each node by 'pinging' them periodically. User is provided with a set of interfaces to retrieve `*sql.DB` object suitable for required operation. +`hasql` operates using standard `database/sql` connection pool objects. The user creates `*sql.DB`-compatible objects for each node of a database cluster and passes them to the constructor. The library keeps up-to-date information on the state of each node by 'pinging' them periodically. The user is provided with a set of interfaces to retrieve a database node object suitable for the required operation.
```go dbFoo, _ := sql.Open("pgx", "host=foo") dbBar, _ := sql.Open("pgx", "host=bar") -cl, err := hasql.NewCluster( - []hasql.Node{hasql.NewNode("foo", dbFoo), hasql.NewNode("bar", dbBar) }, - checkers.PostgreSQL, + +discoverer := NewStaticNodeDiscoverer( + NewNode("foo", dbFoo), + NewNode("bar", dbBar), ) + +cl, err := hasql.NewCluster(discoverer, hasql.PostgreSQLChecker) if err != nil { ... } node := cl.Primary() @@ -50,62 +53,71 @@ if node == nil { err := cl.Err() // most recent errors for all nodes in the cluster } -// Do anything you like -fmt.Println("Node address", node.Addr) +fmt.Printf("got node %s\n", node) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() + if err = node.DB().PingContext(ctx); err != nil { ... } ``` `hasql` does not configure provided connection pools in any way. It is user's job to set them up properly. Library does handle their lifetime though - pools are closed when `Cluster` object is closed. +### Concepts and entities + +**Cluster** is a set of `database/sql`-compatible nodes that tracks their lifespan and provides access to each individual node. + +**Node** is a single database instance in a high-availability cluster. + +**Node discoverer** provides node objects to the cluster. This abstraction allows the user to change the set of cluster nodes dynamically, for example by collecting the node list via service discovery (etcd, Consul); a sketch of a custom discoverer appears at the end of the node-picker section below. + +**Node checker** collects information about the current state of an individual node in the cluster, such as its cluster role, network latency, replication lag, etc. + +**Node picker** picks a node from the cluster by a given criterion using a predefined algorithm: random, round-robin, lowest latency, etc. + ### Supported criteria -_Alive primary_|_Alive standby_|_Any alive_ node, or _none_ otherwise +_Alive primary_|_Alive standby_|_Any alive_ node or _none_ otherwise ```go -node := c.Primary() -if node == nil { ... } +node := c.Node(hasql.Alive) ``` -_Alive primary_|_Alive standby_, or _any alive_ node, or _none_ otherwise +_Alive primary_ or _none_ otherwise ```go -node := c.PreferPrimary() -if node == nil { ... } +node := c.Node(hasql.Primary) ``` -### Ways of accessing nodes -Any of _currently alive_ nodes _satisfying criteria_, or _none_ otherwise +_Alive standby_ or _none_ otherwise ```go -node := c.Primary() -if node == nil { ... } +node := c.Node(hasql.Standby) ``` -Any of _currently alive_ nodes _satisfying criteria_, or _wait_ for one to become _alive_ +_Alive primary_|_Alive standby_ or _none_ otherwise ```go -ctx, cancel := context.WithTimeout(context.Background(), time.Second) -defer cancel() -node, err := c.WaitForPrimary(ctx) -if err == nil { ... } +node := c.Node(hasql.PreferPrimary) +``` + +_Alive standby_|_Alive primary_ or _none_ otherwise +```go +node := c.Node(hasql.PreferStandby) ``` ### Node pickers When user asks `Cluster` object for a node a random one from a list of suitable nodes is returned. User can override this behavior by providing a custom node picker. -Library provides a couple of predefined pickers. For example if user wants 'closest' node (with lowest latency) `PickNodeClosest` picker should be used. +The library provides a couple of predefined pickers. For example, if the user wants the 'closest' node (the one with the lowest latency), `LatencyNodePicker` should be used. Sketches of a custom picker and a custom discoverer follow; the example after them configures a predefined picker.
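A minimal sketch of a custom picker, assuming the `NodePicker` interface shape implied by how pickers are used elsewhere in this changeset (`PickNode` selects one node from a non-empty list of suitable checked nodes, `CompareNodes` orders checked nodes after each state update); `FirstNodePicker` is a hypothetical name:

```go
// FirstNodePicker is a hypothetical custom picker that always takes the
// first suitable node and keeps checked nodes in discovery order.
type FirstNodePicker[T hasql.Querier] struct{}

// PickNode returns the node to use; the cluster calls it only with a
// non-empty list of nodes satisfying the requested criterion.
func (FirstNodePicker[T]) PickNode(nodes []hasql.CheckedNode[T]) hasql.CheckedNode[T] {
	return nodes[0]
}

// CompareNodes reports the relative order of two checked nodes;
// returning 0 keeps them in discovery order.
func (FirstNodePicker[T]) CompareNodes(a, b hasql.CheckedNode[T]) int {
	return 0
}
```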
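And, as mentioned in the concepts section above, nodes may come from service discovery rather than a static list. A minimal sketch of a custom discoverer, assuming `NodeDiscoverer` is satisfied by the single method `DiscoverNodes(ctx) ([]*hasql.Node[T], error)` seen in this changeset; the etcd/Consul client itself is left out:

```go
// serviceDiscoverer is a hypothetical discoverer that returns whatever
// node list a background service-discovery watcher last reported.
type serviceDiscoverer struct {
	mu    sync.Mutex
	nodes []*hasql.Node[*sql.DB]
}

// DiscoverNodes returns the current node list; a watcher goroutine is
// expected to replace d.nodes as cluster membership changes.
func (d *serviceDiscoverer) DiscoverNodes(_ context.Context) ([]*hasql.Node[*sql.DB], error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.nodes, nil
}
```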
```go cl, err := hasql.NewCluster( - []hasql.Node{hasql.NewNode("foo", dbFoo), hasql.NewNode("bar", dbBar) }, - checkers.PostgreSQL, - hasql.WithNodePicker(hasql.PickNodeClosest()), + hasql.NewStaticNodeDiscoverer(hasql.NewNode("foo", dbFoo), hasql.NewNode("bar", dbBar)), + hasql.PostgreSQLChecker, + hasql.WithNodePicker(new(hasql.LatencyNodePicker[*sql.DB])), ) -if err != nil { ... } ``` ## Supported databases -Since library works over standard `database/sql` it supports any database that has a `database/sql` driver. All it requires is a database-specific checker function that can tell if node is primary or standby. +Since the library requires the `Querier` interface, which describes a subset of `database/sql.DB` methods, it supports any database that has a `database/sql` driver. All it requires is a database-specific checker function that can provide node state info. -Check out `golang.yandex/hasql/checkers` package for more information. +Check out the `node_checker.go` file for more information. ### Caveats Node's state is transient at any given time. If `Primary()` returns a node it does not mean that node is still primary when you execute statement on it. All it means is that it was primary when it was last checked. Nodes can change their state at a whim or even go offline and `hasql` can't control it in any manner. @@ -115,8 +127,3 @@ This is one of the reasons why nodes do not expose their perceived state to user. ## Extensions ### Instrumentation You can add instrumentation via `Tracer` object similar to [httptrace](https://godoc.org/net/http/httptrace) in standard library. - -### sqlx -`hasql` can operate over `database/sql` pools wrapped with [sqlx](https://github.com/jmoiron/sqlx). It works the same as with standard library but requires user to import `golang.yandex/hasql/sqlx` instead. - -Refer to `golang.yandex/hasql/sqlx` package for more information. diff --git a/check_nodes.go b/check_nodes.go deleted file mode 100644 index 22fe219..0000000 --- a/check_nodes.go +++ /dev/null @@ -1,162 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package hasql - -import ( - "context" - "sort" - "sync" - "time" -) - -type checkedNode struct { - Node Node - Latency time.Duration -} - -type checkedNodesList []checkedNode - -var _ sort.Interface = checkedNodesList{} - -func (list checkedNodesList) Len() int { - return len(list) -} - -func (list checkedNodesList) Less(i, j int) bool { - return list[i].Latency < list[j].Latency -} - -func (list checkedNodesList) Swap(i, j int) { - list[i], list[j] = list[j], list[i] -} - -func (list checkedNodesList) Nodes() []Node { - res := make([]Node, 0, len(list)) - for _, node := range list { - res = append(res, node.Node) - } - - return res -} - -type groupedCheckedNodes struct { - Primaries checkedNodesList - Standbys checkedNodesList -} - -// Alive returns merged primaries and standbys sorted by latency. Primaries and standbys are expected to be -// sorted beforehand.
-func (nodes groupedCheckedNodes) Alive() []Node { - res := make([]Node, len(nodes.Primaries)+len(nodes.Standbys)) - - var i int - for len(nodes.Primaries) > 0 && len(nodes.Standbys) > 0 { - if nodes.Primaries[0].Latency < nodes.Standbys[0].Latency { - res[i] = nodes.Primaries[0].Node - nodes.Primaries = nodes.Primaries[1:] - } else { - res[i] = nodes.Standbys[0].Node - nodes.Standbys = nodes.Standbys[1:] - } - - i++ - } - - for j := 0; j < len(nodes.Primaries); j++ { - res[i] = nodes.Primaries[j].Node - i++ - } - - for j := 0; j < len(nodes.Standbys); j++ { - res[i] = nodes.Standbys[j].Node - i++ - } - - return res -} - -type checkExecutorFunc func(ctx context.Context, node Node) (bool, time.Duration, error) - -// checkNodes takes slice of nodes, checks them in parallel and returns the alive ones. -// Accepts customizable executor which enables time-independent tests for node sorting based on 'latency'. -func checkNodes(ctx context.Context, nodes []Node, executor checkExecutorFunc, tracer Tracer, errCollector *errorsCollector) AliveNodes { - checkedNodes := groupedCheckedNodes{ - Primaries: make(checkedNodesList, 0, len(nodes)), - Standbys: make(checkedNodesList, 0, len(nodes)), - } - - var mu sync.Mutex - var wg sync.WaitGroup - wg.Add(len(nodes)) - for _, node := range nodes { - go func(node Node, wg *sync.WaitGroup) { - defer wg.Done() - - primary, duration, err := executor(ctx, node) - if err != nil { - if tracer.NodeDead != nil { - tracer.NodeDead(node, err) - } - if errCollector != nil { - errCollector.Add(node.Addr(), err, time.Now()) - } - return - } - if errCollector != nil { - errCollector.Remove(node.Addr()) - } - - if tracer.NodeAlive != nil { - tracer.NodeAlive(node) - } - - nl := checkedNode{Node: node, Latency: duration} - - mu.Lock() - defer mu.Unlock() - if primary { - checkedNodes.Primaries = append(checkedNodes.Primaries, nl) - } else { - checkedNodes.Standbys = append(checkedNodes.Standbys, nl) - } - }(node, &wg) - } - wg.Wait() - - sort.Sort(checkedNodes.Primaries) - sort.Sort(checkedNodes.Standbys) - - return AliveNodes{ - Alive: checkedNodes.Alive(), - Primaries: checkedNodes.Primaries.Nodes(), - Standbys: checkedNodes.Standbys.Nodes(), - } -} - -// checkExecutor returns checkExecutorFunc which can execute supplied check. -func checkExecutor(checker NodeChecker) checkExecutorFunc { - return func(ctx context.Context, node Node) (bool, time.Duration, error) { - ts := time.Now() - primary, err := checker(ctx, node.DB()) - d := time.Since(ts) - if err != nil { - return false, d, err - } - - return primary, d, nil - } -} diff --git a/check_nodes_test.go b/check_nodes_test.go deleted file mode 100644 index 0293f24..0000000 --- a/check_nodes_test.go +++ /dev/null @@ -1,204 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ - -package hasql - -import ( - "context" - "errors" - "fmt" - "math/rand" - "sort" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/gofrs/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCheckedNodesList_Len(t *testing.T) { - nodes := checkedNodesList{checkedNode{}, checkedNode{}, checkedNode{}} - require.Equal(t, 3, nodes.Len()) -} - -func TestCheckedNodesList_Less(t *testing.T) { - nodes := checkedNodesList{checkedNode{Latency: time.Nanosecond}, checkedNode{Latency: 2 * time.Nanosecond}} - require.True(t, nodes.Less(0, 1)) - require.False(t, nodes.Less(1, 0)) -} - -func TestCheckedNodesList_Swap(t *testing.T) { - nodes := checkedNodesList{checkedNode{Latency: time.Nanosecond}, checkedNode{Latency: 2 * time.Nanosecond}} - nodes.Swap(0, 1) - assert.Equal(t, 2*time.Nanosecond, nodes[0].Latency) - assert.Equal(t, time.Nanosecond, nodes[1].Latency) -} - -func TestCheckedNodesList_Sort(t *testing.T) { - nodes := checkedNodesList{checkedNode{Latency: 2 * time.Nanosecond}, checkedNode{Latency: 3 * time.Nanosecond}, checkedNode{Latency: time.Nanosecond}} - sort.Sort(nodes) - for i := range nodes { - assert.Equal(t, time.Duration(i+1)*time.Nanosecond, nodes[i].Latency) - } -} - -func TestGroupedCheckedNodes_Alive(t *testing.T) { - // TODO: this test does not cover all the cases but better than nothing - const count = 10 - var expected []Node - var input groupedCheckedNodes - for i := 0; i < count; i++ { - node := checkedNode{Node: NewNode(fmt.Sprintf("%d", i), nil), Latency: time.Duration(i+1) * time.Nanosecond} - expected = append(expected, node.Node) - if i%2 == 0 { - input.Primaries = append(input.Primaries, node) - } else { - input.Standbys = append(input.Standbys, node) - } - } - require.Len(t, expected, count) - require.NotEmpty(t, input.Primaries) - require.NotEmpty(t, input.Standbys) - require.Equal(t, count, len(input.Primaries)+len(input.Standbys)) - - alive := input.Alive() - require.Len(t, alive, count) - require.Equal(t, expected, alive) -} - -func TestCheckNodes(t *testing.T) { - const count = 100 - var nodes []Node - expected := AliveNodes{Alive: make([]Node, count)} - for i := 0; i < count; i++ { - db, _, err := sqlmock.New() - require.NoError(t, err) - require.NotNil(t, db) - - node := NewNode(uuid.Must(uuid.NewV4()).String(), db) - - for { - // Randomize 'order' (latency) - pos := rand.Intn(count) - if expected.Alive[pos] == nil { - expected.Alive[pos] = node - break - } - } - - nodes = append(nodes, node) - } - - require.Len(t, expected.Alive, count) - - // Fill primaries and standbys - for i, node := range expected.Alive { - if i%2 == 0 { - expected.Primaries = append(expected.Primaries, node) - } else { - expected.Standbys = append(expected.Standbys, node) - } - } - - require.NotEmpty(t, expected.Primaries) - require.NotEmpty(t, expected.Standbys) - require.Equal(t, count, len(expected.Primaries)+len(expected.Standbys)) - - executor := func(ctx context.Context, node Node) (bool, time.Duration, error) { - // Alive nodes set the expected 'order' (latency) of all available nodes. - // Return duration based on that order. 
- var duration time.Duration - for i, alive := range expected.Alive { - if alive == node { - duration = time.Duration(i) * time.Nanosecond - break - } - } - - for _, primary := range expected.Primaries { - if primary == node { - return true, duration, nil - } - } - - for _, standby := range expected.Standbys { - if standby == node { - return false, duration, nil - } - } - - return false, 0, errors.New("node not found") - } - - errCollector := newErrorsCollector() - alive := checkNodes(context.Background(), nodes, executor, Tracer{}, &errCollector) - - assert.NoError(t, errCollector.Err()) - assert.Equal(t, expected.Primaries, alive.Primaries) - assert.Equal(t, expected.Standbys, alive.Standbys) - assert.Equal(t, expected.Alive, alive.Alive) -} - -func TestCheckNodesWithErrors(t *testing.T) { - const count = 5 - var nodes []Node - for i := 0; i < count; i++ { - db, _, err := sqlmock.New() - require.NoError(t, err) - require.NotNil(t, db) - nodes = append(nodes, NewNode(uuid.Must(uuid.NewV4()).String(), db)) - } - - executor := func(ctx context.Context, node Node) (bool, time.Duration, error) { - return false, 0, errors.New("node not found") - } - - errCollector := newErrorsCollector() - checkNodes(context.Background(), nodes, executor, Tracer{}, &errCollector) - - err := errCollector.Err() - for i := 0; i < count; i++ { - assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", nodes[i].Addr())) - } - assert.ErrorContains(t, err, "node not found") -} - -func TestCheckNodesWithErrorsWhenNodesBecameAlive(t *testing.T) { - const count = 5 - var nodes []Node - for i := 0; i < count; i++ { - db, _, err := sqlmock.New() - require.NoError(t, err) - require.NotNil(t, db) - nodes = append(nodes, NewNode(uuid.Must(uuid.NewV4()).String(), db)) - } - - executor := func(ctx context.Context, node Node) (bool, time.Duration, error) { - return false, 0, errors.New("node not found") - } - - errCollector := newErrorsCollector() - checkNodes(context.Background(), nodes, executor, Tracer{}, &errCollector) - require.Error(t, errCollector.Err()) - - executor = func(ctx context.Context, node Node) (bool, time.Duration, error) { - return true, 1, nil - } - checkNodes(context.Background(), nodes, executor, Tracer{}, &errCollector) - require.NoError(t, errCollector.Err()) -} diff --git a/checked_node.go b/checked_node.go new file mode 100644 index 0000000..74af20d --- /dev/null +++ b/checked_node.go @@ -0,0 +1,199 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package hasql + +import ( + "context" + "errors" + "fmt" + "slices" + "sync" +) + +// CheckedNodes holds references to all available cluster nodes +type CheckedNodes[T Querier] struct { + discovered []*Node[T] + alive []CheckedNode[T] + primaries []CheckedNode[T] + standbys []CheckedNode[T] + err error +} + +// Discovered returns a list of nodes discovered in the cluster +func (c CheckedNodes[T]) Discovered() []*Node[T] { + return c.discovered +} + +// Alive returns a list of all successfully checked nodes regardless of their cluster role +func (c CheckedNodes[T]) Alive() []CheckedNode[T] { + return c.alive +} + +// Primaries returns a list of all successfully checked nodes with the primary role +func (c CheckedNodes[T]) Primaries() []CheckedNode[T] { + return c.primaries +} + +// Standbys returns a list of all successfully checked nodes with the standby role +func (c CheckedNodes[T]) Standbys() []CheckedNode[T] { + return c.standbys +} + +// Err holds information about the cause of node check failures. +func (c CheckedNodes[T]) Err() error { + return c.err +} + +// CheckedNode contains the most recent state of a single cluster node +type CheckedNode[T Querier] struct { + Node *Node[T] + Info NodeInfoProvider +} + +// checkNodes discovers cluster nodes, checks them in parallel and returns the alive ones +func checkNodes[T Querier](ctx context.Context, discoverer NodeDiscoverer[T], checkFn NodeChecker, compareFn func(a, b CheckedNode[T]) int, tracer Tracer[T]) CheckedNodes[T] { + discoveredNodes, err := discoverer.DiscoverNodes(ctx) + if err != nil { + // error discovering nodes + return CheckedNodes[T]{ + err: fmt.Errorf("cannot discover cluster nodes: %w", err), + } + } + + var mu sync.Mutex + checked := make([]CheckedNode[T], 0, len(discoveredNodes)) + var errs NodeCheckErrors[T] + + var wg sync.WaitGroup + wg.Add(len(discoveredNodes)) + for _, node := range discoveredNodes { + go func(node *Node[T]) { + defer wg.Done() + + // check single node state + info, err := checkFn(ctx, node.DB()) + if err != nil { + cerr := NodeCheckError[T]{ + node: node, + err: err, + } + + // node is dead - make trace call + if tracer.NodeDead != nil { + tracer.NodeDead(cerr) + } + + // store node check error + mu.Lock() + defer mu.Unlock() + errs = append(errs, cerr) + return + } + + cn := CheckedNode[T]{ + Node: node, + Info: info, + } + + // make trace call about alive node + if tracer.NodeAlive != nil { + tracer.NodeAlive(cn) + } + + // store checked alive node + mu.Lock() + defer mu.Unlock() + checked = append(checked, cn) + }(node) + } + + // wait for all nodes to be checked + wg.Wait() + + // sort checked nodes + slices.SortFunc(checked, compareFn) + + // split checked nodes by roles + alive := make([]CheckedNode[T], 0, len(checked)) + // in almost all cases there is only one primary node in the cluster + primaries := make([]CheckedNode[T], 0, 1) + standbys := make([]CheckedNode[T], 0, len(checked)) + for _, cn := range checked { + switch cn.Info.Role() { + case NodeRolePrimary: + primaries = append(primaries, cn) + alive = append(alive, cn) + case NodeRoleStandby: + standbys = append(standbys, cn) + alive = append(alive, cn) + default: + // treat node with undetermined role as dead + cerr := NodeCheckError[T]{ + node: cn.Node, + err: errors.New("cannot determine node role"), + } + errs = append(errs, cerr) + + if tracer.NodeDead != nil { + tracer.NodeDead(cerr) + } + } + } + + res := CheckedNodes[T]{ + discovered: discoveredNodes, + alive: alive, + primaries: primaries, + standbys: standbys, + err: func() error { + if len(errs)
!= 0 { + return errs + } + return nil + }(), + } + + return res +} + +// pickNodeByCriterion is a helper function to pick a single node by given criterion +func pickNodeByCriterion[T Querier](nodes CheckedNodes[T], picker NodePicker[T], criterion NodeStateCriterion) *Node[T] { + var subset []CheckedNode[T] + + switch criterion { + case Alive: + subset = nodes.alive + case Primary: + subset = nodes.primaries + case Standby: + subset = nodes.standbys + case PreferPrimary: + if subset = nodes.primaries; len(subset) == 0 { + subset = nodes.standbys + } + case PreferStandby: + if subset = nodes.standbys; len(subset) == 0 { + subset = nodes.primaries + } + } + + if len(subset) == 0 { + return nil + } + + return picker.PickNode(subset).Node +} diff --git a/checked_node_test.go b/checked_node_test.go new file mode 100644 index 0000000..37ec45a --- /dev/null +++ b/checked_node_test.go @@ -0,0 +1,483 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package hasql + +import ( + "context" + "database/sql" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCheckNodes(t *testing.T) { + t.Run("discovery_error", func(t *testing.T) { + discoverer := mockNodeDiscoverer[*sql.DB]{ + err: io.EOF, + } + + nodes := checkNodes(context.Background(), discoverer, nil, nil, Tracer[*sql.DB]{}) + assert.Empty(t, nodes.discovered) + assert.Empty(t, nodes.alive) + assert.Empty(t, nodes.primaries) + assert.Empty(t, nodes.standbys) + assert.ErrorIs(t, nodes.err, io.EOF) + }) + + t.Run("all_nodes_alive", func(t *testing.T) { + node1 := &Node[*mockQuerier]{ + name: "shimba", + db: &mockQuerier{name: "primary"}, + } + node2 := &Node[*mockQuerier]{ + name: "boomba", + db: &mockQuerier{name: "standby1"}, + } + node3 := &Node[*mockQuerier]{ + name: "looken", + db: &mockQuerier{name: "standby2"}, + } + + discoverer := mockNodeDiscoverer[*mockQuerier]{ + nodes: []*Node[*mockQuerier]{node1, node2, node3}, + } + + // mock node checker func + checkFn := func(_ context.Context, q Querier) (NodeInfoProvider, error) { + mq, ok := q.(*mockQuerier) + if !ok { + return NodeInfo{}, nil + } + + switch mq.name { + case node1.db.name: + return NodeInfo{ClusterRole: NodeRolePrimary, NetworkLatency: 100}, nil + case node2.db.name: + return NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}, nil + case node3.db.name: + return NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 70}, nil + default: + return NodeInfo{}, nil + } + } + + var picker LatencyNodePicker[*mockQuerier] + var tracer Tracer[*mockQuerier] + + checked := checkNodes(context.Background(), discoverer, checkFn, picker.CompareNodes, tracer) + + expected := CheckedNodes[*mockQuerier]{ + discovered: []*Node[*mockQuerier]{node1, node2, node3}, + alive: []CheckedNode[*mockQuerier]{ + {Node: node2, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}}, + {Node: node3, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 70}}, + {Node: node1, Info: NodeInfo{ClusterRole: NodeRolePrimary, 
NetworkLatency: 100}}, + }, + primaries: []CheckedNode[*mockQuerier]{ + {Node: node1, Info: NodeInfo{ClusterRole: NodeRolePrimary, NetworkLatency: 100}}, + }, + standbys: []CheckedNode[*mockQuerier]{ + {Node: node2, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}}, + {Node: node3, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 70}}, + }, + } + + assert.Equal(t, expected, checked) + }) + + t.Run("all_nodes_dead", func(t *testing.T) { + node1 := &Node[*mockQuerier]{ + name: "shimba", + db: &mockQuerier{name: "primary"}, + } + node2 := &Node[*mockQuerier]{ + name: "boomba", + db: &mockQuerier{name: "standby1"}, + } + node3 := &Node[*mockQuerier]{ + name: "looken", + db: &mockQuerier{name: "standby2"}, + } + + discoverer := mockNodeDiscoverer[*mockQuerier]{ + nodes: []*Node[*mockQuerier]{node1, node2, node3}, + } + + // mock node checker func + checkFn := func(_ context.Context, _ Querier) (NodeInfoProvider, error) { + return nil, io.EOF + } + + var picker LatencyNodePicker[*mockQuerier] + var tracer Tracer[*mockQuerier] + + checked := checkNodes(context.Background(), discoverer, checkFn, picker.CompareNodes, tracer) + + expectedDiscovered := []*Node[*mockQuerier]{node1, node2, node3} + assert.Equal(t, expectedDiscovered, checked.discovered) + + assert.Empty(t, checked.alive) + assert.Empty(t, checked.primaries) + assert.Empty(t, checked.standbys) + + var cerrs NodeCheckErrors[*mockQuerier] + assert.ErrorAs(t, checked.err, &cerrs) + assert.Len(t, cerrs, 3) + for _, cerr := range cerrs { + assert.ErrorIs(t, cerr, io.EOF) + } + }) + + t.Run("one_standby_is_dead", func(t *testing.T) { + node1 := &Node[*mockQuerier]{ + name: "shimba", + db: &mockQuerier{name: "primary"}, + } + node2 := &Node[*mockQuerier]{ + name: "boomba", + db: &mockQuerier{name: "standby1"}, + } + node3 := &Node[*mockQuerier]{ + name: "looken", + db: &mockQuerier{name: "standby2"}, + } + + discoverer := mockNodeDiscoverer[*mockQuerier]{ + nodes: []*Node[*mockQuerier]{node1, node2, node3}, + } + + // mock node checker func + checkFn := func(_ context.Context, q Querier) (NodeInfoProvider, error) { + mq, ok := q.(*mockQuerier) + if !ok { + return NodeInfo{}, nil + } + + switch mq.name { + case node1.db.name: + return NodeInfo{ClusterRole: NodeRolePrimary, NetworkLatency: 100}, nil + case node2.db.name: + return NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}, nil + case node3.db.name: + return nil, io.EOF + default: + return NodeInfo{}, nil + } + } + + var picker LatencyNodePicker[*mockQuerier] + var tracer Tracer[*mockQuerier] + + checked := checkNodes(context.Background(), discoverer, checkFn, picker.CompareNodes, tracer) + + expected := CheckedNodes[*mockQuerier]{ + discovered: []*Node[*mockQuerier]{node1, node2, node3}, + alive: []CheckedNode[*mockQuerier]{ + {Node: node2, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}}, + {Node: node1, Info: NodeInfo{ClusterRole: NodeRolePrimary, NetworkLatency: 100}}, + }, + primaries: []CheckedNode[*mockQuerier]{ + {Node: node1, Info: NodeInfo{ClusterRole: NodeRolePrimary, NetworkLatency: 100}}, + }, + standbys: []CheckedNode[*mockQuerier]{ + {Node: node2, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}}, + }, + err: NodeCheckErrors[*mockQuerier]{ + {node: node3, err: io.EOF}, + }, + } + + assert.Equal(t, expected, checked) + }) + + t.Run("primary_is_dead", func(t *testing.T) { + node1 := &Node[*mockQuerier]{ + name: "shimba", + db: &mockQuerier{name: "primary"}, + } + node2 := &Node[*mockQuerier]{ + name: 
"boomba", + db: &mockQuerier{name: "standby1"}, + } + node3 := &Node[*mockQuerier]{ + name: "looken", + db: &mockQuerier{name: "standby2"}, + } + + discoverer := mockNodeDiscoverer[*mockQuerier]{ + nodes: []*Node[*mockQuerier]{node1, node2, node3}, + } + + // mock node checker func + checkFn := func(_ context.Context, q Querier) (NodeInfoProvider, error) { + mq, ok := q.(*mockQuerier) + if !ok { + return NodeInfo{}, nil + } + + switch mq.name { + case node1.db.name: + return nil, io.EOF + case node2.db.name: + return NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 70}, nil + case node3.db.name: + return NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}, nil + default: + return NodeInfo{}, nil + } + } + + var picker LatencyNodePicker[*mockQuerier] + var tracer Tracer[*mockQuerier] + + checked := checkNodes(context.Background(), discoverer, checkFn, picker.CompareNodes, tracer) + + expected := CheckedNodes[*mockQuerier]{ + discovered: []*Node[*mockQuerier]{node1, node2, node3}, + alive: []CheckedNode[*mockQuerier]{ + {Node: node3, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}}, + {Node: node2, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 70}}, + }, + primaries: []CheckedNode[*mockQuerier]{}, + standbys: []CheckedNode[*mockQuerier]{ + {Node: node3, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}}, + {Node: node2, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 70}}, + }, + err: NodeCheckErrors[*mockQuerier]{ + {node: node1, err: io.EOF}, + }, + } + + assert.Equal(t, expected, checked) + }) + + t.Run("node_with_unknown_role", func(t *testing.T) { + node1 := &Node[*mockQuerier]{ + name: "shimba", + db: &mockQuerier{name: "unknown"}, + } + node2 := &Node[*mockQuerier]{ + name: "boomba", + db: &mockQuerier{name: "primary"}, + } + node3 := &Node[*mockQuerier]{ + name: "looken", + db: &mockQuerier{name: "standby2"}, + } + + discoverer := mockNodeDiscoverer[*mockQuerier]{ + nodes: []*Node[*mockQuerier]{node1, node2, node3}, + } + + // mock node checker func + checkFn := func(_ context.Context, q Querier) (NodeInfoProvider, error) { + mq, ok := q.(*mockQuerier) + if !ok { + return NodeInfo{}, nil + } + + switch mq.name { + case node1.db.name: + return NodeInfo{}, nil + case node2.db.name: + return NodeInfo{ClusterRole: NodeRolePrimary, NetworkLatency: 20}, nil + case node3.db.name: + return NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}, nil + default: + return NodeInfo{}, nil + } + } + + var picker LatencyNodePicker[*mockQuerier] + var tracer Tracer[*mockQuerier] + + checked := checkNodes(context.Background(), discoverer, checkFn, picker.CompareNodes, tracer) + + expected := CheckedNodes[*mockQuerier]{ + discovered: []*Node[*mockQuerier]{node1, node2, node3}, + alive: []CheckedNode[*mockQuerier]{ + {Node: node2, Info: NodeInfo{ClusterRole: NodeRolePrimary, NetworkLatency: 20}}, + {Node: node3, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}}, + }, + primaries: []CheckedNode[*mockQuerier]{ + {Node: node2, Info: NodeInfo{ClusterRole: NodeRolePrimary, NetworkLatency: 20}}, + }, + standbys: []CheckedNode[*mockQuerier]{ + {Node: node3, Info: NodeInfo{ClusterRole: NodeRoleStandby, NetworkLatency: 50}}, + }, + err: NodeCheckErrors[*mockQuerier]{ + {node: node1, err: errors.New("cannot determine node role")}, + }, + } + + assert.Equal(t, expected, checked) + }) +} + +func TestPickNodeByCriterion(t *testing.T) { + t.Run("no_nodes", func(t *testing.T) { + nodes := CheckedNodes[*sql.DB]{} + picker := 
new(RandomNodePicker[*sql.DB]) + + // all criteria must return nil node + for i := Alive; i < maxNodeCriterion; i++ { + node := pickNodeByCriterion(nodes, picker, i) + assert.Nil(t, node) + } + }) + + t.Run("alive", func(t *testing.T) { + node := &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), + } + + nodes := CheckedNodes[*sql.DB]{ + alive: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + }, + }, + }, + } + picker := new(RandomNodePicker[*sql.DB]) + + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, Alive)) + }) + + t.Run("primary", func(t *testing.T) { + node := &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), + } + + nodes := CheckedNodes[*sql.DB]{ + primaries: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + }, + }, + }, + } + picker := new(RandomNodePicker[*sql.DB]) + + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, Primary)) + // the node is also returned for Prefer* criteria + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, PreferPrimary)) + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, PreferStandby)) + }) + + t.Run("standby", func(t *testing.T) { + node := &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), + } + + nodes := CheckedNodes[*sql.DB]{ + standbys: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRoleStandby, + }, + }, + }, + } + picker := new(RandomNodePicker[*sql.DB]) + + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, Standby)) + // the node is also returned for Prefer* criteria + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, PreferPrimary)) + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, PreferStandby)) + }) + + t.Run("prefer_primary", func(t *testing.T) { + node := &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), + } + + // must pick from primaries + nodes := CheckedNodes[*sql.DB]{ + primaries: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + }, + }, + }, + } + picker := new(RandomNodePicker[*sql.DB]) + + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, PreferPrimary)) + + // must pick from standbys + nodes = CheckedNodes[*sql.DB]{ + standbys: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRoleStandby, + }, + }, + }, + } + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, PreferPrimary)) + }) + + t.Run("prefer_standby", func(t *testing.T) { + node := &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), + } + + // must pick from standbys + nodes := CheckedNodes[*sql.DB]{ + standbys: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRoleStandby, + }, + }, + }, + } + picker := new(RandomNodePicker[*sql.DB]) + + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, PreferStandby)) + + // must pick from primaries + nodes = CheckedNodes[*sql.DB]{ + primaries: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + }, + }, + }, + } + assert.Equal(t, node, pickNodeByCriterion(nodes, picker, PreferStandby)) + }) +} diff --git a/checkers/check.go b/checkers/check.go deleted file mode 100644 index 66d9b45..0000000 --- a/checkers/check.go +++ /dev/null @@ -1,34 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License.
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package checkers - -import ( - "context" - "database/sql" -) - -// Check executes specified query on specified database pool. Query must return single boolean -// value that signals if that pool is connected to primary or not. All errors are returned as is. -func Check(ctx context.Context, db *sql.DB, query string) (bool, error) { - row := db.QueryRowContext(ctx, query) - var primary bool - if err := row.Scan(&primary); err != nil { - return false, err - } - - return primary, nil -} diff --git a/checkers/mssql.go b/checkers/mssql.go deleted file mode 100644 index 7a50fe1..0000000 --- a/checkers/mssql.go +++ /dev/null @@ -1,35 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package checkers - -import ( - "context" - "database/sql" -) - -// MSSQL checks whether MSSQL server is primary or not. -func MSSQL(ctx context.Context, db *sql.DB) (bool, error) { - var status bool - - if err := db.QueryRowContext(ctx, "SELECT IIF(count(database_guid) = 0, 'TRUE', 'FALSE') AS STATUS "+ - "FROM sys.database_recovery_status"+ - " WHERE database_guid IS NULL", - ).Scan(&status); err != nil { - return false, err - } - return status, nil -} diff --git a/checkers/postgresql.go b/checkers/postgresql.go deleted file mode 100644 index aba09d0..0000000 --- a/checkers/postgresql.go +++ /dev/null @@ -1,27 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package checkers - -import ( - "context" - "database/sql" -) - -// PostgreSQL checks whether PostgreSQL server is primary or not. -func PostgreSQL(ctx context.Context, db *sql.DB) (bool, error) { - return Check(ctx, db, "SELECT NOT pg_is_in_recovery()") -} diff --git a/cluster.go b/cluster.go index 3b4e0ae..10c5714 100644 --- a/cluster.go +++ b/cluster.go @@ -14,169 +14,121 @@ limitations under the License. */ +// Package hasql provides simple and reliable way to access high-availability database setups with multiple hosts. 
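+//
+// A minimal usage sketch (a hedged illustration using identifiers introduced in this changeset; dbFoo and dbBar stand for user-provided *sql.DB pools):
+//
+//	discoverer := hasql.NewStaticNodeDiscoverer(
+//		hasql.NewNode("foo", dbFoo),
+//		hasql.NewNode("bar", dbBar),
+//	)
+//	cl, err := hasql.NewCluster(discoverer, hasql.PostgreSQLChecker)
+//	node := cl.Node(hasql.Primary)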
package hasql import ( "context" "errors" - "fmt" + "io" "sync" "sync/atomic" "time" ) -// Default values for cluster config -const ( - DefaultUpdateInterval = time.Second * 5 - DefaultUpdateTimeout = time.Second -) - -type nodeWaiter struct { - Ch chan Node - StateCriteria NodeStateCriteria -} - -// AliveNodes of Cluster -type AliveNodes struct { - Alive []Node - Primaries []Node - Standbys []Node -} - // Cluster consists of number of 'nodes' of a single SQL database. -// Background goroutine periodically checks nodes and updates their status. -type Cluster struct { - tracer Tracer - - // Configuration +type Cluster[T Querier] struct { + // configuration updateInterval time.Duration updateTimeout time.Duration + discoverer NodeDiscoverer[T] checker NodeChecker - picker NodePicker + picker NodePicker[T] + tracer Tracer[T] - // Status - updateStopper chan struct{} - aliveNodes atomic.Value - nodes []Node - errCollector errorsCollector + // status + checkedNodes atomic.Value + stop context.CancelFunc - // Notification - muWaiters sync.Mutex - waiters []nodeWaiter + // broadcast + subscribersMu sync.Mutex + subscribers []updateSubscriber[T] } -// NewCluster constructs cluster object representing a single 'cluster' of SQL database. -// Close function must be called when cluster is not needed anymore. -func NewCluster(nodes []Node, checker NodeChecker, opts ...ClusterOption) (*Cluster, error) { - // Validate nodes - if len(nodes) == 0 { - return nil, errors.New("no nodes provided") +// NewCluster returns object representing a single 'cluster' of SQL databases +func NewCluster[T Querier](discoverer NodeDiscoverer[T], checker NodeChecker, opts ...ClusterOpt[T]) (*Cluster[T], error) { + if discoverer == nil { + return nil, errors.New("node discoverer required") } - for i, node := range nodes { - if node.Addr() == "" { - return nil, fmt.Errorf("node %d has no address", i) - } + // prepare internal 'stop' context + ctx, stopFn := context.WithCancel(context.Background()) - if node.DB() == nil { - return nil, fmt.Errorf("node %d (%q) has nil *sql.DB", i, node.Addr()) - } - } - - cl := &Cluster{ - updateStopper: make(chan struct{}), - updateInterval: DefaultUpdateInterval, - updateTimeout: DefaultUpdateTimeout, + cl := &Cluster[T]{ + updateInterval: 5 * time.Second, + updateTimeout: time.Second, + discoverer: discoverer, checker: checker, - picker: PickNodeRandom(), - nodes: nodes, - errCollector: newErrorsCollector(), + picker: new(RandomNodePicker[T]), + + stop: stopFn, } - // Apply options + // apply options for _, opt := range opts { opt(cl) } // Store initial nodes state - cl.aliveNodes.Store(AliveNodes{}) + cl.checkedNodes.Store(CheckedNodes[T]{}) // Start update routine - go cl.backgroundNodesUpdate() + go cl.backgroundNodesUpdate(ctx) return cl, nil } -// Close databases and stop node updates. -func (cl *Cluster) Close() error { - close(cl.updateStopper) - - var err error - for _, node := range cl.nodes { - if e := node.DB().Close(); e != nil { - // TODO: This is bad, we save only one error. Need multiple-error error package. - err = e +// Close stops node updates. +// Close function must be called when cluster is not needed anymore. 
+// It returns a combined error if multiple nodes returned errors +func (cl *Cluster[T]) Close() (err error) { + cl.stop() + + // close all nodes underlying connection pools + discovered := cl.checkedNodes.Load().(CheckedNodes[T]).discovered + for _, node := range discovered { + if closer, ok := any(node.DB()).(io.Closer); ok { + err = errors.Join(err, closer.Close()) + } + } - return err -} - -// Nodes returns list of all nodes -func (cl *Cluster) Nodes() []Node { - return cl.nodes -} + // discard any collected state of nodes + cl.checkedNodes.Store(CheckedNodes[T]{}) -func (cl *Cluster) nodesAlive() AliveNodes { - return cl.aliveNodes.Load().(AliveNodes) -} - -func (cl *Cluster) addUpdateWaiter(criteria NodeStateCriteria) <-chan Node { - // Buffered channel is essential. - // Read WaitForNode function for more information. - ch := make(chan Node, 1) - cl.muWaiters.Lock() - defer cl.muWaiters.Unlock() - cl.waiters = append(cl.waiters, nodeWaiter{Ch: ch, StateCriteria: criteria}) - return ch -} - -// WaitForAlive node to appear or until context is canceled -func (cl *Cluster) WaitForAlive(ctx context.Context) (Node, error) { - return cl.WaitForNode(ctx, Alive) -} - -// WaitForPrimary node to appear or until context is canceled -func (cl *Cluster) WaitForPrimary(ctx context.Context) (Node, error) { - return cl.WaitForNode(ctx, Primary) -} - -// WaitForStandby node to appear or until context is canceled -func (cl *Cluster) WaitForStandby(ctx context.Context) (Node, error) { - return cl.WaitForNode(ctx, Standby) + return err } -// WaitForPrimaryPreferred node to appear or until context is canceled -func (cl *Cluster) WaitForPrimaryPreferred(ctx context.Context) (Node, error) { - return cl.WaitForNode(ctx, PreferPrimary) +// Err returns the cause of the most recent node check failures. +// In most cases the error is a list of errors of type NodeCheckErrors; the original errors +// can be extracted using `errors.As`. +// Example: +// +// var cerrs NodeCheckErrors[*sql.DB] +// if errors.As(err, &cerrs) { +// for _, cerr := range cerrs { +// fmt.Printf("node: %s, err: %s\n", cerr.Node(), cerr.Err()) +// } +// } +func (cl *Cluster[T]) Err() error { + return cl.checkedNodes.Load().(CheckedNodes[T]).Err() } -// WaitForStandbyPreferred node to appear or until context is canceled -func (cl *Cluster) WaitForStandbyPreferred(ctx context.Context) (Node, error) { - return cl.WaitForNode(ctx, PreferStandby) +// Node returns a cluster node matching the specified criterion +func (cl *Cluster[T]) Node(criterion NodeStateCriterion) *Node[T] { + return pickNodeByCriterion(cl.checkedNodes.Load().(CheckedNodes[T]), cl.picker, criterion) } // WaitForNode with specified status to appear or until context is canceled -func (cl *Cluster) WaitForNode(ctx context.Context, criteria NodeStateCriteria) (Node, error) { +func (cl *Cluster[T]) WaitForNode(ctx context.Context, criterion NodeStateCriterion) (*Node[T], error) { // Node already exists?
- node := cl.Node(criteria) + node := cl.Node(criterion) if node != nil { return node, nil } - ch := cl.addUpdateWaiter(criteria) + ch := cl.addUpdateSubscriber(criterion) // Node might have appeared while we were adding waiter, recheck - node = cl.Node(criteria) + node = cl.Node(criterion) if node != nil { return node, nil } @@ -203,165 +155,44 @@ func (cl *Cluster) WaitForNode(ctx context.Context, criteria NodeStateCriteria) } } -// Alive returns node that is considered alive -func (cl *Cluster) Alive() Node { - return cl.alive(cl.nodesAlive()) -} - -func (cl *Cluster) alive(nodes AliveNodes) Node { - if len(nodes.Alive) == 0 { - return nil - } - - return cl.picker(nodes.Alive) -} - -// Primary returns first available node that is considered alive and is primary (able to execute write operations) -func (cl *Cluster) Primary() Node { - return cl.primary(cl.nodesAlive()) -} - -func (cl *Cluster) primary(nodes AliveNodes) Node { - if len(nodes.Primaries) == 0 { - return nil - } - - return cl.picker(nodes.Primaries) -} - -// Standby returns node that is considered alive and is standby (unable to execute write operations) -func (cl *Cluster) Standby() Node { - return cl.standby(cl.nodesAlive()) -} - -func (cl *Cluster) standby(nodes AliveNodes) Node { - if len(nodes.Standbys) == 0 { - return nil - } - - // select one of standbys - return cl.picker(nodes.Standbys) -} - -// PrimaryPreferred returns primary node if possible, standby otherwise -func (cl *Cluster) PrimaryPreferred() Node { - return cl.primaryPreferred(cl.nodesAlive()) -} - -func (cl *Cluster) primaryPreferred(nodes AliveNodes) Node { - node := cl.primary(nodes) - if node == nil { - node = cl.standby(nodes) - } - - return node -} - -// StandbyPreferred returns standby node if possible, primary otherwise -func (cl *Cluster) StandbyPreferred() Node { - return cl.standbyPreferred(cl.nodesAlive()) -} - -func (cl *Cluster) standbyPreferred(nodes AliveNodes) Node { - node := cl.standby(nodes) - if node == nil { - node = cl.primary(nodes) - } - - return node -} - -// Node returns cluster node with specified status. -func (cl *Cluster) Node(criteria NodeStateCriteria) Node { - return cl.node(cl.nodesAlive(), criteria) -} - -func (cl *Cluster) node(nodes AliveNodes, criteria NodeStateCriteria) Node { - switch criteria { - case Alive: - return cl.alive(nodes) - case Primary: - return cl.primary(nodes) - case Standby: - return cl.standby(nodes) - case PreferPrimary: - return cl.primaryPreferred(nodes) - case PreferStandby: - return cl.standbyPreferred(nodes) - default: - panic(fmt.Sprintf("unknown node state criteria: %d", criteria)) - } -} - -// Err returns the combined error including most recent errors for all nodes. -// This error is CollectedErrors or nil. 
-func (cl *Cluster) Err() error { - return cl.errCollector.Err() -} - -// backgroundNodesUpdate periodically updates list of live db nodes -func (cl *Cluster) backgroundNodesUpdate() { - // Initial update - cl.updateNodes() +// backgroundNodesUpdate periodically checks list of registered nodes +func (cl *Cluster[T]) backgroundNodesUpdate(ctx context.Context) { + // initial update + cl.updateNodes(ctx) ticker := time.NewTicker(cl.updateInterval) defer ticker.Stop() for { select { - case <-cl.updateStopper: + case <-ctx.Done(): return case <-ticker.C: - cl.updateNodes() + cl.updateNodes(ctx) } } } -// updateNodes pings all db nodes and stores alive ones in a separate slice -func (cl *Cluster) updateNodes() { +// updateNodes performs a new round of cluster state check +// and notifies all subscribers afterwards +func (cl *Cluster[T]) updateNodes(ctx context.Context) { if cl.tracer.UpdateNodes != nil { cl.tracer.UpdateNodes() } - ctx, cancel := context.WithTimeout(context.Background(), cl.updateTimeout) + ctx, cancel := context.WithTimeout(ctx, cl.updateTimeout) defer cancel() - alive := checkNodes(ctx, cl.nodes, checkExecutor(cl.checker), cl.tracer, &cl.errCollector) - cl.aliveNodes.Store(alive) + checked := checkNodes(ctx, cl.discoverer, cl.checker, cl.picker.CompareNodes, cl.tracer) + cl.checkedNodes.Store(checked) - if cl.tracer.UpdatedNodes != nil { - cl.tracer.UpdatedNodes(alive) + if cl.tracer.NodesUpdated != nil { + cl.tracer.NodesUpdated(checked) } - cl.notifyWaiters(alive) + cl.notifyUpdateSubscribers(checked) - if cl.tracer.NotifiedWaiters != nil { - cl.tracer.NotifiedWaiters() + if cl.tracer.WaitersNotified != nil { + cl.tracer.WaitersNotified() } } - -func (cl *Cluster) notifyWaiters(nodes AliveNodes) { - cl.muWaiters.Lock() - defer cl.muWaiters.Unlock() - - if len(cl.waiters) == 0 { - return - } - - var nodelessWaiters []nodeWaiter - // Notify all waiters - for _, waiter := range cl.waiters { - node := cl.node(nodes, waiter.StateCriteria) - if node == nil { - // Put waiter back - nodelessWaiters = append(nodelessWaiters, waiter) - continue - } - - // We won't block here, read addUpdateWaiter function for more information - waiter.Ch <- node - // No need to close channel since we write only once and forget it so does the 'client' - } - - cl.waiters = nodelessWaiters -} diff --git a/cluster_opts.go b/cluster_opts.go index 706d75c..1b5641f 100644 --- a/cluster_opts.go +++ b/cluster_opts.go @@ -18,33 +18,33 @@ package hasql import "time" -// ClusterOption is a functional option type for Cluster constructor -type ClusterOption func(*Cluster) +// ClusterOpt is a functional option type for Cluster constructor +type ClusterOpt[T Querier] func(*Cluster[T]) -// WithUpdateInterval sets interval between cluster node updates -func WithUpdateInterval(d time.Duration) ClusterOption { - return func(cl *Cluster) { +// WithUpdateInterval sets interval between cluster state updates +func WithUpdateInterval[T Querier](d time.Duration) ClusterOpt[T] { + return func(cl *Cluster[T]) { cl.updateInterval = d } } -// WithUpdateTimeout sets ping timeout for update of each node in cluster -func WithUpdateTimeout(d time.Duration) ClusterOption { - return func(cl *Cluster) { +// WithUpdateTimeout sets timeout for update of each node in cluster +func WithUpdateTimeout[T Querier](d time.Duration) ClusterOpt[T] { + return func(cl *Cluster[T]) { cl.updateTimeout = d } } // WithNodePicker sets algorithm for node selection (e.g. 
random, round robin etc) -func WithNodePicker(picker NodePicker) ClusterOption { - return func(cl *Cluster) { +func WithNodePicker[T Querier](picker NodePicker[T]) ClusterOpt[T] { + return func(cl *Cluster[T]) { cl.picker = picker } } // WithTracer sets tracer for actions happening in the background -func WithTracer(tracer Tracer) ClusterOption { - return func(cl *Cluster) { +func WithTracer[T Querier](tracer Tracer[T]) ClusterOpt[T] { + return func(cl *Cluster[T]) { cl.tracer = tracer } } diff --git a/cluster_opts_test.go b/cluster_opts_test.go deleted file mode 100644 index 6559db4..0000000 --- a/cluster_opts_test.go +++ /dev/null @@ -1,86 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package hasql - -import ( - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestClusterDefaults(t *testing.T) { - f := newFixture(t, 1) - c, err := NewCluster(f.ClusterNodes(), f.PrimaryChecker) - require.NoError(t, err) - defer func() { require.NoError(t, c.Close()) }() - - require.Equal(t, DefaultUpdateInterval, c.updateInterval) - require.Equal(t, DefaultUpdateTimeout, c.updateTimeout) -} - -func TestWithUpdateInterval(t *testing.T) { - f := newFixture(t, 1) - d := time.Hour - c, err := NewCluster(f.ClusterNodes(), f.PrimaryChecker, WithUpdateInterval(d)) - require.NoError(t, err) - defer func() { require.NoError(t, c.Close()) }() - - require.Equal(t, d, c.updateInterval) -} - -func TestWithUpdateTimeout(t *testing.T) { - f := newFixture(t, 1) - d := time.Hour - c, err := NewCluster(f.ClusterNodes(), f.PrimaryChecker, WithUpdateTimeout(d)) - require.NoError(t, err) - defer func() { require.NoError(t, c.Close()) }() - - require.Equal(t, d, c.updateTimeout) -} - -func TestWithNodePicker(t *testing.T) { - var called bool - picker := func([]Node) Node { - called = true - return nil - } - f := newFixture(t, 1) - c, err := NewCluster(f.ClusterNodes(), f.PrimaryChecker, WithNodePicker(picker)) - require.NoError(t, err) - defer func() { require.NoError(t, c.Close()) }() - - c.picker(nil) - require.True(t, called) -} - -func TestWithTracer(t *testing.T) { - var called int32 - tracer := Tracer{ - NotifiedWaiters: func() { - atomic.StoreInt32(&called, 1) - }, - } - f := newFixture(t, 1) - c, err := NewCluster(f.ClusterNodes(), f.PrimaryChecker, WithTracer(tracer)) - require.NoError(t, err) - defer func() { require.NoError(t, c.Close()) }() - - c.tracer.NotifiedWaiters() - require.Equal(t, int32(1), atomic.LoadInt32(&called)) -} diff --git a/cluster_test.go b/cluster_test.go index 923f166..da311cd 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -17,714 +17,266 @@ package hasql import ( - "context" "database/sql" - "errors" - "fmt" - "sync/atomic" + "io" "testing" - "time" "github.com/DATA-DOG/go-sqlmock" - "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewCluster(t *testing.T) { - fakeDB, _, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) - 
require.NotNil(t, fakeDB) - - inputs := []struct { - Name string - Fixture *fixture - Err bool - }{ - { - Name: "no nodes", - Fixture: newFixture(t, 0), - Err: true, - }, - { - Name: "invalid node (no address)", - Fixture: &fixture{Nodes: []*mockedNode{{Node: NewNode("", fakeDB)}}}, - Err: true, - }, - { - Name: "invalid node (no db)", - Fixture: &fixture{Nodes: []*mockedNode{{Node: NewNode("fake.addr", nil)}}}, - Err: true, - }, - { - Name: "valid node", - Fixture: newFixture(t, 1), - }, - } - - for _, input := range inputs { - t.Run(input.Name, func(t *testing.T) { - defer input.Fixture.AssertExpectations(t) - - cl, err := NewCluster(input.Fixture.ClusterNodes(), input.Fixture.PrimaryChecker) - if input.Err { - require.Error(t, err) - require.Nil(t, cl) - return - } - - require.NoError(t, err) - require.NotNil(t, cl) - defer func() { require.NoError(t, cl.Close()) }() - - require.Len(t, cl.Nodes(), len(input.Fixture.Nodes)) - }) - } -} + t.Run("no_nodes", func(t *testing.T) { + cl, err := NewCluster[*sql.DB](nil, PostgreSQLChecker) + assert.Nil(t, cl) + assert.EqualError(t, err, "node discoverer required") + }) + + t.Run("success", func(t *testing.T) { + db, _, err := sqlmock.New() + require.NoError(t, err) -const ( - // How often nodes are updated in the background - updateInterval = time.Millisecond * 20 - // Timeout of one WaitFor* iteration - // Being half of updateInterval allows for hitting updates during wait and not hitting them at all - waitTimeout = updateInterval / 2 -) + node := NewNode("shimba", db) -func setupCluster(t *testing.T, f *fixture, tracer Tracer) *Cluster { - cl, err := NewCluster( - f.ClusterNodes(), - f.PrimaryChecker, - WithUpdateInterval(updateInterval), - WithTracer(tracer), - ) - require.NoError(t, err) - require.NotNil(t, cl) - require.Len(t, cl.Nodes(), len(f.Nodes)) - return cl + cl, err := NewCluster(NewStaticNodeDiscoverer(node), PostgreSQLChecker) + assert.NoError(t, err) + assert.NotNil(t, cl) + }) } -func waitForNode(t *testing.T, o *nodeUpdateObserver, wait func(ctx context.Context) (Node, error), expected Node) { - o.StartObservation() +func TestCluster_Close(t *testing.T) { + t.Run("no_errors", func(t *testing.T) { + db1, dbmock1, err := sqlmock.New() + require.NoError(t, err) - var node Node - var err error - for { - ctx, cancel := context.WithTimeout(context.Background(), waitTimeout) + db2, dbmock2, err := sqlmock.New() + require.NoError(t, err) - node, err = wait(ctx) - if o.ObservedUpdates() { - cancel() - break - } + // expect database client to be closed + dbmock1.ExpectClose() + dbmock2.ExpectClose() - cancel() - } + node1 := NewNode("shimba", db1) + node2 := NewNode("boomba", db2) - if expected != nil { + cl, err := NewCluster(NewStaticNodeDiscoverer(node1, node2), PostgreSQLChecker) require.NoError(t, err) - require.Equal(t, expected, node) - } else { - require.Error(t, err) - require.Nil(t, node) - } -} -func waitForOneOfNodes(t *testing.T, o *nodeUpdateObserver, wait func(ctx context.Context) (Node, error), expected []Node) { - o.StartObservation() - - var node Node - var err error - for { - ctx, cancel := context.WithTimeout(context.Background(), waitTimeout) + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{ + discovered: []*Node[*sql.DB]{node1, node2}, + }) - node, err = wait(ctx) - if o.ObservedUpdates() { - cancel() - break - } + assert.NoError(t, cl.Close()) + }) - cancel() - } + t.Run("multiple_errors", func(t *testing.T) { + db1, dbmock1, err := sqlmock.New() + require.NoError(t, err) - require.NoError(t, err) - require.NotNil(t, 
node) + db2, dbmock2, err := sqlmock.New() + require.NoError(t, err) - for _, n := range expected { - if n == node { - return - } - } + // expect database client to be closed + dbmock1.ExpectClose().WillReturnError(io.EOF) + dbmock2.ExpectClose().WillReturnError(sql.ErrTxDone) - t.Fatalf("received node %+v but expected one of %+v", node, expected) -} + node1 := NewNode("shimba", db1) + node2 := NewNode("boomba", db2) -func TestCluster_WaitForAlive(t *testing.T) { - inputs := []struct { - Name string - Fixture *fixture - Test func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster, status nodeStatus) - }{ - { - Name: "Alive", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster, status nodeStatus) { - f.Nodes[0].setStatus(status) - waitForNode(t, o, cl.WaitForAlive, f.Nodes[0].Node) - }, - }, - { - Name: "Dead", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster, status nodeStatus) { - waitForNode(t, o, cl.WaitForAlive, nil) - }, - }, - { - Name: "AliveDeadAlive", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster, status nodeStatus) { - node := f.Nodes[0] + cl, err := NewCluster(NewStaticNodeDiscoverer(node1, node2), PostgreSQLChecker) + require.NoError(t, err) - node.setStatus(status) - waitForNode(t, o, cl.WaitForAlive, node.Node) + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{ + discovered: []*Node[*sql.DB]{node1, node2}, + }) - node.setStatus(nodeStatusUnknown) - waitForNode(t, o, cl.WaitForAlive, nil) + err = cl.Close() + assert.ErrorIs(t, err, io.EOF) + assert.ErrorIs(t, err, sql.ErrTxDone) + }) +} - node.setStatus(status) - waitForNode(t, o, cl.WaitForAlive, node.Node) +func TestCluster_Err(t *testing.T) { + t.Run("no_error", func(t *testing.T) { + cl := new(Cluster[*sql.DB]) + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{}) + assert.NoError(t, cl.Err()) + }) + + t.Run("has_error", func(t *testing.T) { + checkedNodes := CheckedNodes[*sql.DB]{ + err: NodeCheckErrors[*sql.DB]{ + { + node: &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), + }, + err: io.EOF, + }, }, - }, - { - Name: "DeadAliveDead", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster, status nodeStatus) { - node := f.Nodes[0] - - waitForNode(t, o, cl.WaitForAlive, nil) + } - node.setStatus(status) - waitForNode(t, o, cl.WaitForAlive, node.Node) + cl := new(Cluster[*sql.DB]) + cl.checkedNodes.Store(checkedNodes) - node.setStatus(nodeStatusUnknown) - waitForNode(t, o, cl.WaitForAlive, nil) - }, - }, - { - Name: "AllAlive", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster, status nodeStatus) { - f.Nodes[0].setStatus(status) - f.Nodes[1].setStatus(status) - f.Nodes[2].setStatus(status) - - waitForOneOfNodes(t, o, cl.WaitForAlive, []Node{f.Nodes[0].Node, f.Nodes[1].Node, f.Nodes[2].Node}) - }, - }, - { - Name: "AllDead", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster, status nodeStatus) { - waitForNode(t, o, cl.WaitForAlive, nil) - }, - }, - } + assert.ErrorIs(t, cl.Err(), io.EOF) + }) +} - for _, status := range []nodeStatus{nodeStatusPrimary, nodeStatusStandby} { - for _, input := range inputs { - t.Run(fmt.Sprintf("%s status %d", input.Name, status), func(t *testing.T) { - defer input.Fixture.AssertExpectations(t) +func TestCluster_Node(t *testing.T) { + t.Run("no_nodes", func(t *testing.T) { + cl := 
new(Cluster[*sql.DB]) + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{}) - var o nodeUpdateObserver - cl := setupCluster(t, input.Fixture, o.Tracer()) - defer func() { require.NoError(t, cl.Close()) }() + // all criteria must return nil node + for i := Alive; i < maxNodeCriterion; i++ { + node := cl.Node(i) + assert.Nil(t, node) + } + }) - input.Test(t, input.Fixture, &o, cl, status) - }) + t.Run("alive", func(t *testing.T) { + node := &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), } - } -} -func TestCluster_WaitForPrimary(t *testing.T) { - inputs := []struct { - Name string - Fixture *fixture - Test func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) - }{ - { - Name: "PrimaryAlive", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[0].setStatus(nodeStatusPrimary) - waitForNode(t, o, cl.WaitForPrimary, f.Nodes[0].Node) - }, - }, - { - Name: "PrimaryDead", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - waitForNode(t, o, cl.WaitForPrimary, nil) + cl := new(Cluster[*sql.DB]) + cl.picker = new(RandomNodePicker[*sql.DB]) + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{ + alive: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + }, + }, }, - }, - { - Name: "AllAlive", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[0].setStatus(nodeStatusStandby) - f.Nodes[1].setStatus(nodeStatusPrimary) - f.Nodes[2].setStatus(nodeStatusStandby) - waitForNode(t, o, cl.WaitForPrimary, f.Nodes[1].Node) - }, - }, - { - Name: "AllDead", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - waitForNode(t, o, cl.WaitForPrimary, nil) - }, - }, - { - Name: "PrimaryAliveOtherDead", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[1].setStatus(nodeStatusPrimary) - waitForNode(t, o, cl.WaitForPrimary, f.Nodes[1].Node) - }, - }, - { - Name: "PrimaryDeadOtherAlive", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[0].setStatus(nodeStatusStandby) - f.Nodes[2].setStatus(nodeStatusStandby) - waitForNode(t, o, cl.WaitForPrimary, nil) - }, - }, - } - - for _, input := range inputs { - t.Run(input.Name, func(t *testing.T) { - defer input.Fixture.AssertExpectations(t) - - var o nodeUpdateObserver - cl := setupCluster(t, input.Fixture, o.Tracer()) - defer func() { require.NoError(t, cl.Close()) }() - - input.Test(t, input.Fixture, &o, cl) }) - } -} -func TestCluster_WaitForStandby(t *testing.T) { - inputs := []struct { - Name string - Fixture *fixture - Test func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) - }{ - { - Name: "StandbyAlive", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[0].setStatus(nodeStatusStandby) - waitForNode(t, o, cl.WaitForStandby, f.Nodes[0].Node) - }, - }, - { - Name: "StandbyDead", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - waitForNode(t, o, cl.WaitForStandby, nil) - }, - }, - { - Name: "AllAlive", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[0].setStatus(nodeStatusStandby) - f.Nodes[1].setStatus(nodeStatusPrimary) - 
f.Nodes[2].setStatus(nodeStatusStandby) - waitForOneOfNodes(t, o, cl.WaitForStandby, []Node{f.Nodes[0].Node, f.Nodes[2].Node}) - }, - }, - { - Name: "AllDead", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - waitForNode(t, o, cl.WaitForStandby, nil) - }, - }, - { - Name: "StandbyAliveOtherDead", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[1].setStatus(nodeStatusStandby) - waitForNode(t, o, cl.WaitForStandby, f.Nodes[1].Node) - }, - }, - { - Name: "StandbysAliveOtherDead", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[1].setStatus(nodeStatusStandby) - f.Nodes[2].setStatus(nodeStatusStandby) - waitForOneOfNodes(t, o, cl.WaitForStandby, []Node{f.Nodes[1].Node, f.Nodes[2].Node}) - }, - }, - { - Name: "StandbyDeadOtherAlive", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[0].setStatus(nodeStatusPrimary) - f.Nodes[2].setStatus(nodeStatusPrimary) - waitForNode(t, o, cl.WaitForStandby, nil) - }, - }, - } + assert.Equal(t, node, cl.Node(Alive)) + }) - for _, input := range inputs { - t.Run(input.Name, func(t *testing.T) { - defer input.Fixture.AssertExpectations(t) - - var o nodeUpdateObserver - cl := setupCluster(t, input.Fixture, o.Tracer()) - defer func() { require.NoError(t, cl.Close()) }() - - input.Test(t, input.Fixture, &o, cl) - }) - } -} + t.Run("primary", func(t *testing.T) { + node := &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), + } -func TestCluster_WaitForPrimaryPreferred(t *testing.T) { - inputs := []struct { - Name string - Fixture *fixture - Test func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) - }{ - { - Name: "PrimaryAlive", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[0].setStatus(nodeStatusPrimary) - waitForNode(t, o, cl.WaitForPrimaryPreferred, f.Nodes[0].Node) - }, - }, - { - Name: "PrimaryDead", - Fixture: newFixture(t, 1), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - waitForNode(t, o, cl.WaitForPrimaryPreferred, nil) - }, - }, - { - Name: "AllAlive", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[0].setStatus(nodeStatusStandby) - f.Nodes[1].setStatus(nodeStatusPrimary) - f.Nodes[2].setStatus(nodeStatusStandby) - waitForNode(t, o, cl.WaitForPrimaryPreferred, f.Nodes[1].Node) + cl := new(Cluster[*sql.DB]) + cl.picker = new(RandomNodePicker[*sql.DB]) + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{ + primaries: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + }, + }, }, - }, - { - Name: "AllDead", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - waitForNode(t, o, cl.WaitForPrimary, nil) - }, - }, - { - Name: "PrimaryAliveOtherDead", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[1].setStatus(nodeStatusPrimary) - waitForNode(t, o, cl.WaitForPrimaryPreferred, f.Nodes[1].Node) - }, - }, - { - Name: "PrimaryDeadOtherAlive", - Fixture: newFixture(t, 3), - Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) { - f.Nodes[0].setStatus(nodeStatusStandby) - f.Nodes[2].setStatus(nodeStatusStandby) - waitForOneOfNodes(t, o, 
cl.WaitForPrimaryPreferred, []Node{f.Nodes[0].Node, f.Nodes[2].Node})
-			},
-		},
-	}
-
-	for _, input := range inputs {
-		t.Run(input.Name, func(t *testing.T) {
-			defer input.Fixture.AssertExpectations(t)
-
-			var o nodeUpdateObserver
-			cl := setupCluster(t, input.Fixture, o.Tracer())
-			defer func() { require.NoError(t, cl.Close()) }()
-
-			input.Test(t, input.Fixture, &o, cl)
 		})
-	}
-}
-
-func TestCluster_WaitForStandbyPreferred(t *testing.T) {
-	inputs := []struct {
-		Name    string
-		Fixture *fixture
-		Test    func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster)
-	}{
-		{
-			Name:    "StandbyAlive",
-			Fixture: newFixture(t, 1),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				f.Nodes[0].setStatus(nodeStatusStandby)
-				waitForNode(t, o, cl.WaitForStandbyPreferred, f.Nodes[0].Node)
-			},
-		},
-		{
-			Name:    "StandbyDead",
-			Fixture: newFixture(t, 1),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				waitForNode(t, o, cl.WaitForStandbyPreferred, nil)
-			},
-		},
-		{
-			Name:    "AllAlive",
-			Fixture: newFixture(t, 3),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				f.Nodes[0].setStatus(nodeStatusStandby)
-				f.Nodes[1].setStatus(nodeStatusPrimary)
-				f.Nodes[2].setStatus(nodeStatusStandby)
-				waitForOneOfNodes(t, o, cl.WaitForStandbyPreferred, []Node{f.Nodes[0].Node, f.Nodes[2].Node})
-			},
-		},
-		{
-			Name:    "AllDead",
-			Fixture: newFixture(t, 3),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				waitForNode(t, o, cl.WaitForStandbyPreferred, nil)
-			},
-		},
-		{
-			Name:    "StandbyAliveOtherDead",
-			Fixture: newFixture(t, 3),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				f.Nodes[1].setStatus(nodeStatusStandby)
-				waitForNode(t, o, cl.WaitForStandbyPreferred, f.Nodes[1].Node)
-			},
-		},
-		{
-			Name:    "StandbysAliveOtherDead",
-			Fixture: newFixture(t, 3),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				f.Nodes[1].setStatus(nodeStatusStandby)
-				f.Nodes[2].setStatus(nodeStatusStandby)
-				waitForOneOfNodes(t, o, cl.WaitForStandbyPreferred, []Node{f.Nodes[1].Node, f.Nodes[2].Node})
-			},
-		},
-		{
-			Name:    "StandbyDeadOtherAlive",
-			Fixture: newFixture(t, 3),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				f.Nodes[0].setStatus(nodeStatusPrimary)
-				f.Nodes[2].setStatus(nodeStatusPrimary)
-				waitForOneOfNodes(t, o, cl.WaitForStandbyPreferred, []Node{f.Nodes[0].Node, f.Nodes[2].Node})
-			},
-		},
-	}
-
-	for _, input := range inputs {
-		t.Run(input.Name, func(t *testing.T) {
-			defer input.Fixture.AssertExpectations(t)
+		assert.Equal(t, node, cl.Node(Primary))
+		// the node must also be returned for Prefer* criteria
+		assert.Equal(t, node, cl.Node(PreferPrimary))
+		assert.Equal(t, node, cl.Node(PreferStandby))
+	})
-
-			var o nodeUpdateObserver
-			cl := setupCluster(t, input.Fixture, o.Tracer())
-			defer func() { require.NoError(t, cl.Close()) }()
-
-			input.Test(t, input.Fixture, &o, cl)
-		})
-	}
-}
+	t.Run("standby", func(t *testing.T) {
+		node := &Node[*sql.DB]{
+			name: "shimba",
+			db:   new(sql.DB),
+		}
-
-func TestCluster_Err(t *testing.T) {
-	inputs := []struct {
-		Name    string
-		Fixture *fixture
-		Test    func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster)
-	}{
-		{
-			Name:    "AllAlive",
-			Fixture: newFixture(t, 2),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				f.Nodes[0].setStatus(nodeStatusStandby)
-				f.Nodes[1].setStatus(nodeStatusPrimary)
-				waitForNode(t, o, cl.WaitForPrimary, f.Nodes[1].Node)
-
-				require.NoError(t, cl.Err())
-			},
-		},
-		{
-			Name:    "AllDead",
-			Fixture: newFixture(t, 2),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				waitForNode(t, o, cl.WaitForPrimary, nil)
-
-				err := cl.Err()
-				require.Error(t, err)
-				assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", f.Nodes[0].Node.Addr()))
-				assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", f.Nodes[1].Node.Addr()))
-			},
-		},
-		{
-			Name:    "PrimaryAliveOtherDead",
-			Fixture: newFixture(t, 2),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				f.Nodes[1].setStatus(nodeStatusPrimary)
-				waitForNode(t, o, cl.WaitForPrimary, f.Nodes[1].Node)
-
-				err := cl.Err()
-				require.Error(t, err)
-				assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", f.Nodes[0].Node.Addr()))
-				assert.NotContains(t, err.Error(), fmt.Sprintf("%q node error occurred at", f.Nodes[1].Node.Addr()))
+		cl := new(Cluster[*sql.DB])
+		cl.picker = new(RandomNodePicker[*sql.DB])
+		cl.checkedNodes.Store(CheckedNodes[*sql.DB]{
+			standbys: []CheckedNode[*sql.DB]{
+				{
+					Node: node,
+					Info: NodeInfo{
+						ClusterRole: NodeRoleStandby,
+					},
+				},
 			},
-		},
-		{
-			Name:    "PrimaryDeadOtherAlive",
-			Fixture: newFixture(t, 2),
-			Test: func(t *testing.T, f *fixture, o *nodeUpdateObserver, cl *Cluster) {
-				f.Nodes[0].setStatus(nodeStatusStandby)
-				waitForNode(t, o, cl.WaitForPrimary, nil)
-
-				err := cl.Err()
-				require.Error(t, err)
-				assert.NotContains(t, err.Error(), fmt.Sprintf("%q node error occurred at", f.Nodes[0].Node.Addr()))
-				assert.ErrorContains(t, err, fmt.Sprintf("%q node error occurred at", f.Nodes[1].Node.Addr()))
-			},
-		},
-	}
-
-	for _, input := range inputs {
-		t.Run(input.Name, func(t *testing.T) {
-			defer input.Fixture.AssertExpectations(t)
-
-			var o nodeUpdateObserver
-			cl := setupCluster(t, input.Fixture, o.Tracer())
-			defer func() { require.NoError(t, cl.Close()) }()
-
-			input.Test(t, input.Fixture, &o, cl)
 		})
-	}
-}
-
-type nodeStatus int64
-
-const (
-	nodeStatusUnknown nodeStatus = iota
-	nodeStatusPrimary
-	nodeStatusStandby
-)
+		assert.Equal(t, node, cl.Node(Standby))
+		// the node must also be returned for Prefer* criteria
+		assert.Equal(t, node, cl.Node(PreferPrimary))
+		assert.Equal(t, node, cl.Node(PreferStandby))
+	})
-
-type mockedNode struct {
-	Node Node
-	Mock sqlmock.Sqlmock
-
-	st int64
-}
-
-func (n *mockedNode) setStatus(s nodeStatus) {
-	atomic.StoreInt64(&n.st, int64(s))
-}
-
-func (n *mockedNode) status() nodeStatus {
-	return nodeStatus(atomic.LoadInt64(&n.st))
-}
-
-type nodeUpdateObserver struct {
-	updatedNodes int64
-
-	updatedNodesAtStart int64
-	updatesObserved     int64
-}
-
-func (o *nodeUpdateObserver) StartObservation() {
-	o.updatedNodesAtStart = atomic.LoadInt64(&o.updatedNodes)
-	o.updatesObserved = 0
-}
-
-func (o *nodeUpdateObserver) ObservedUpdates() bool {
-	updatedNodes := atomic.LoadInt64(&o.updatedNodes)
-
-	// When we wait for a node, we are guaranteed to observe at least one update
-	// only when two actually happened
-	// TODO: its a mess, implement checker
-	if updatedNodes-o.updatedNodesAtStart >= 2*(o.updatesObserved+1) {
-		o.updatesObserved++
-	}
-
-	return o.updatesObserved >= 2
-}
-
-func (o *nodeUpdateObserver) Tracer() Tracer {
-	return Tracer{
-		UpdatedNodes: func(_ AliveNodes) { atomic.AddInt64(&o.updatedNodes, 1) },
-	}
-}
-
-type fixture struct {
-	TraceCounter *nodeUpdateObserver
-	Nodes        []*mockedNode
-}
-
-func newFixture(t *testing.T, count int) *fixture {
-	var f fixture
-	for i := count; i > 0; 
i-- { - db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) - require.NotNil(t, db) - - mock.ExpectClose() - - node := &mockedNode{ - Node: NewNode(uuid.Must(uuid.NewV4()).String(), db), - Mock: mock, - st: int64(nodeStatusUnknown), + t.Run("prefer_primary", func(t *testing.T) { + node := &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), } - require.NotNil(t, node.Node) - f.Nodes = append(f.Nodes, node) - } - - require.Len(t, f.Nodes, count) + cl := new(Cluster[*sql.DB]) + cl.picker = new(RandomNodePicker[*sql.DB]) - return &f -} - -func (f *fixture) ClusterNodes() []Node { - var nodes []Node - for _, node := range f.Nodes { - nodes = append(nodes, node.Node) - } - - return nodes -} + // must pick from primaries + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{ + primaries: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + }, + }, + }, + }) + assert.Equal(t, node, cl.Node(PreferPrimary)) + + // must pick from standbys + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{ + standbys: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRoleStandby, + }, + }, + }, + }) + assert.Equal(t, node, cl.Node(PreferPrimary)) + }) -func (f *fixture) PrimaryChecker(_ context.Context, db *sql.DB) (bool, error) { - for _, node := range f.Nodes { - if node.Node.DB() == db { - switch node.status() { - case nodeStatusPrimary: - return true, nil - case nodeStatusStandby: - return false, nil - default: - return false, errors.New("node is dead") - } + t.Run("prefer_standby", func(t *testing.T) { + node := &Node[*sql.DB]{ + name: "shimba", + db: new(sql.DB), } - } - return false, errors.New("node not found in fixture") -} + cl := new(Cluster[*sql.DB]) + cl.picker = new(RandomNodePicker[*sql.DB]) -func (f *fixture) AssertExpectations(t *testing.T) { - for _, node := range f.Nodes { - if node.Mock != nil { // We can use 'incomplete' fixture to test invalid cases - assert.NoError(t, node.Mock.ExpectationsWereMet()) - } - } + // must pick from standbys + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{ + standbys: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRoleStandby, + }, + }, + }, + }) + assert.Equal(t, node, cl.Node(PreferStandby)) + + // must pick from primaries + cl.checkedNodes.Store(CheckedNodes[*sql.DB]{ + primaries: []CheckedNode[*sql.DB]{ + { + Node: node, + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + }, + }, + }, + }) + assert.Equal(t, node, cl.Node(PreferStandby)) + }) } diff --git a/e2e_test.go b/e2e_test.go new file mode 100644 index 0000000..bf1fa8d --- /dev/null +++ b/e2e_test.go @@ -0,0 +1,443 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/
+
+package hasql_test
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"io"
+	"slices"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/DATA-DOG/go-sqlmock"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"golang.yandex/hasql/v2"
+)
+
+// TestEnd2End_AliveCluster sets up a 3-node cluster, waits for at least one
+// alive node, then picks primary and standby nodes. Nodes are always alive.
+func TestEnd2End_AliveCluster(t *testing.T) {
+	// create three database pools
+	db1, mock1, err := sqlmock.New()
+	require.NoError(t, err)
+	db2, mock2, err := sqlmock.New()
+	require.NoError(t, err)
+	db3, mock3, err := sqlmock.New()
+	require.NoError(t, err)
+
+	// set db1 to be primary node
+	mock1.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnRows(sqlmock.
+			NewRows([]string{"role", "lag"}).
+			AddRow(hasql.NodeRolePrimary, 0),
+		)
+
+	// set db2 and db3 to be standby nodes
+	mock2.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnRows(sqlmock.
+			NewRows([]string{"role", "lag"}).
+			AddRow(hasql.NodeRoleStandby, 0),
+		)
+	mock3.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnRows(sqlmock.
+			NewRows([]string{"role", "lag"}).
+			AddRow(hasql.NodeRoleStandby, 10),
+		)
+
+	// all pools must be closed in the end
+	mock1.ExpectClose()
+	mock2.ExpectClose()
+	mock3.ExpectClose()
+
+	// register pools as nodes
+	node1 := hasql.NewNode("ololo", db1)
+	node2 := hasql.NewNode("trololo", db2)
+	node3 := hasql.NewNode("shimba", db3)
+	discoverer := hasql.NewStaticNodeDiscoverer(node1, node2, node3)
+
+	// create test cluster
+	cl, err := hasql.NewCluster(discoverer, hasql.PostgreSQLChecker,
+		hasql.WithUpdateInterval[*sql.DB](10*time.Millisecond),
+	)
+	require.NoError(t, err)
+
+	// close cluster and all underlying pools in the end
+	defer func() {
+		assert.NoError(t, cl.Close())
+	}()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	// wait for any alive node
+	waitNode, err := cl.WaitForNode(ctx, hasql.Alive)
+	assert.NoError(t, err)
+	assert.Contains(t, []*hasql.Node[*sql.DB]{node1, node2, node3}, waitNode)
+
+	// pick primary node
+	primary := cl.Node(hasql.Primary)
+	assert.Same(t, node1, primary)
+
+	// pick standby node
+	standby := cl.Node(hasql.Standby)
+	assert.Contains(t, []*hasql.Node[*sql.DB]{node2, node3}, standby)
+}
+
+// TestEnd2End_SingleDeadNodeCluster sets up a 3-node cluster, waits for at least one
+// alive node, then picks primary and standby nodes. One node is always dead.
+func TestEnd2End_SingleDeadNodeCluster(t *testing.T) {
+	// create three database pools
+	db1, mock1, err := sqlmock.New()
+	require.NoError(t, err)
+	db2, mock2, err := sqlmock.New()
+	require.NoError(t, err)
+	db3, mock3, err := sqlmock.New()
+	require.NoError(t, err)
+
+	// set db1 to be primary node
+	mock1.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnRows(sqlmock.
+			NewRows([]string{"role", "lag"}).
+			AddRow(hasql.NodeRolePrimary, 0),
+		)
+	// set db2 to be standby node
+	mock2.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnRows(sqlmock.
+			NewRows([]string{"role", "lag"}).
+			AddRow(hasql.NodeRoleStandby, 0),
+		)
+	// db3 is always dead
+	mock3.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillDelayFor(time.Second).
+		WillReturnError(io.EOF)
+
+	// all pools must be closed in the end
+	mock1.ExpectClose()
+	mock2.ExpectClose()
+	mock3.ExpectClose()
+
+	// register pools as nodes
+	node1 := hasql.NewNode("ololo", db1)
+	node2 := hasql.NewNode("trololo", db2)
+	node3 := hasql.NewNode("shimba", db3)
+	discoverer := hasql.NewStaticNodeDiscoverer(node1, node2, node3)
+
+	// create test cluster
+	cl, err := hasql.NewCluster(discoverer, hasql.PostgreSQLChecker,
+		hasql.WithUpdateInterval[*sql.DB](10*time.Millisecond),
+		hasql.WithUpdateTimeout[*sql.DB](50*time.Millisecond),
+		// set node picker to round robin to guarantee iteration across all nodes
+		hasql.WithNodePicker(new(hasql.RoundRobinNodePicker[*sql.DB])),
+	)
+	require.NoError(t, err)
+
+	// close cluster and all underlying pools in the end
+	defer func() {
+		assert.NoError(t, cl.Close())
+	}()
+
+	// Set the wait context timeout to be greater than the cluster update interval plus update timeout.
+	// If the update timeout were greater than the wait context timeout,
+	// we would always receive a context.DeadlineExceeded error: the current update cycle
+	// would spend longer gathering info about the dead node (and thus refreshing the whole
+	// cluster state) than we are willing to wait for a node
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	// wait for any alive node
+	waitNode, err := cl.WaitForNode(ctx, hasql.Alive)
+	assert.NoError(t, err)
+	assert.Contains(t, []*hasql.Node[*sql.DB]{node1, node2}, waitNode)
+
+	// pick primary node
+	primary := cl.Node(hasql.Primary)
+	assert.Same(t, node1, primary)
+
+	// pick standby node multiple times to ensure
+	// we always get alive standby node
+	for range 100 {
+		standby := cl.Node(hasql.Standby)
+		assert.Same(t, node2, standby)
+	}
+}
+
+// TestEnd2End_NoPrimaryCluster sets up a 3-node cluster, waits for at least one
+// alive node, then picks primary and standby nodes. No alive primary node is present.
+func TestEnd2End_NoPrimaryCluster(t *testing.T) {
+	// create three database pools
+	db1, mock1, err := sqlmock.New()
+	require.NoError(t, err)
+	db2, mock2, err := sqlmock.New()
+	require.NoError(t, err)
+	db3, mock3, err := sqlmock.New()
+	require.NoError(t, err)
+
+	// db1 is always dead
+	mock1.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnError(io.EOF)
+	// set db2 to be standby node
+	mock2.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnRows(sqlmock.
+			NewRows([]string{"role", "lag"}).
+			AddRow(hasql.NodeRoleStandby, 10),
+		)
+	// set db3 to be standby node
+	mock3.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnRows(sqlmock.
+			NewRows([]string{"role", "lag"}).
+			AddRow(hasql.NodeRoleStandby, 0),
+		)
+
+	// all pools must be closed in the end
+	mock1.ExpectClose()
+	mock2.ExpectClose()
+	mock3.ExpectClose()
+
+	// register pools as nodes
+	node1 := hasql.NewNode("ololo", db1)
+	node2 := hasql.NewNode("trololo", db2)
+	node3 := hasql.NewNode("shimba", db3)
+	discoverer := hasql.NewStaticNodeDiscoverer(node1, node2, node3)
+
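The timing constraint the e2e tests above rely on deserves to be made explicit: a caller's wait deadline must outlive at least one full background update cycle, otherwise `WaitForNode` can expire before the cluster has finished re-checking a dead node. A minimal, runnable sketch of that arithmetic, using the same values as the tests; the constant names are this example's, not library API:

```go
package main

import "time"

// Values mirror the e2e tests above; the names are illustrative only.
const (
	updateInterval = 10 * time.Millisecond  // how often node checks start
	updateTimeout  = 50 * time.Millisecond  // per-cycle budget, eaten whole by a dead node
	waitDeadline   = 100 * time.Millisecond // caller's WaitForNode context timeout
)

func main() {
	// The deadline covers one full cycle even in the worst case, so
	// WaitForNode is guaranteed to observe at least one completed refresh.
	if waitDeadline > updateInterval+updateTimeout {
		println("wait deadline outlives one full update cycle")
	}
}
```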
+	// create test cluster
+	cl, err := hasql.NewCluster(discoverer, hasql.PostgreSQLChecker,
+		hasql.WithUpdateInterval[*sql.DB](10*time.Millisecond),
+	)
+	require.NoError(t, err)
+
+	// close cluster and all underlying pools in the end
+	defer func() {
+		assert.NoError(t, cl.Close())
+	}()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	// wait for any alive node
+	waitNode, err := cl.WaitForNode(ctx, hasql.Alive)
+	assert.NoError(t, err)
+	assert.Contains(t, []*hasql.Node[*sql.DB]{node2, node3}, waitNode)
+
+	// pick primary node
+	primary := cl.Node(hasql.Primary)
+	assert.Nil(t, primary)
+
+	// pick standby node
+	standby := cl.Node(hasql.Standby)
+	assert.Contains(t, []*hasql.Node[*sql.DB]{node2, node3}, standby)
+
+	// cluster must keep last check error
+	assert.ErrorIs(t, cl.Err(), io.EOF)
+}
+
+// TestEnd2End_DeadCluster sets up a 3-node cluster. No node is alive.
+func TestEnd2End_DeadCluster(t *testing.T) {
+	// create three database pools
+	db1, mock1, err := sqlmock.New()
+	require.NoError(t, err)
+	db2, mock2, err := sqlmock.New()
+	require.NoError(t, err)
+	db3, mock3, err := sqlmock.New()
+	require.NoError(t, err)
+
+	// db1 is always dead
+	mock1.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnError(io.EOF)
+	// db2 is always dead
+	mock2.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnError(io.EOF)
+	// db3 is always dead
+	mock3.ExpectQuery(`SELECT.*pg_is_in_recovery`).
+		WillReturnError(io.EOF)
+
+	// all pools must be closed in the end
+	mock1.ExpectClose()
+	mock2.ExpectClose()
+	mock3.ExpectClose()
+
+	// register pools as nodes
+	node1 := hasql.NewNode("ololo", db1)
+	node2 := hasql.NewNode("trololo", db2)
+	node3 := hasql.NewNode("shimba", db3)
+	discoverer := hasql.NewStaticNodeDiscoverer(node1, node2, node3)
+
+	// create test cluster
+	cl, err := hasql.NewCluster(discoverer, hasql.PostgreSQLChecker,
+		hasql.WithUpdateInterval[*sql.DB](10*time.Millisecond),
+		hasql.WithUpdateTimeout[*sql.DB](50*time.Millisecond),
+	)
+	require.NoError(t, err)
+
+	// close cluster and all underlying pools in the end
+	defer func() {
+		assert.NoError(t, cl.Close())
+	}()
+
+	// set context expiration to be greater than cluster refresh interval and timeout
+	// to guarantee at least one cycle of state refresh
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	// wait for any alive node
+	waitNode, err := cl.WaitForNode(ctx, hasql.Alive)
+	assert.ErrorIs(t, err, context.DeadlineExceeded)
+	assert.Nil(t, waitNode)
+
+	// pick primary node
+	primary := cl.Node(hasql.Primary)
+	assert.Nil(t, primary)
+
+	// pick standby node
+	standby := cl.Node(hasql.Standby)
+	assert.Nil(t, standby)
+}
+
+// TestEnd2End_FlakyCluster sets up a 3-node cluster, waits for at least one
+// alive node, then picks primary and standby nodes.
+// One node fails to report its state between refresh intervals.
+func TestEnd2End_FlakyCluster(t *testing.T) { + errIsPrimary := errors.New("primary node") + errIsStandby := errors.New("standby node") + + sentinelErrChecker := func(ctx context.Context, q hasql.Querier) (hasql.NodeInfoProvider, error) { + _, err := q.QueryContext(ctx, "report node pls") + if errors.Is(err, errIsPrimary) { + return hasql.NodeInfo{ClusterRole: hasql.NodeRolePrimary}, nil + } + if errors.Is(err, errIsStandby) { + return hasql.NodeInfo{ClusterRole: hasql.NodeRoleStandby}, nil + } + return nil, err + } + + // set db1 to be primary node + // it will fail with error on every second attempt to query state + var attempts uint32 + db1 := &mockQuerier{ + queryFn: func(_ context.Context, _ string, _ ...any) (*sql.Rows, error) { + call := atomic.AddUint32(&attempts, 1) + if call%2 == 0 { + return nil, io.EOF + } + return nil, errIsPrimary + }, + } + + // set db2 and db3 to be standbys + db2 := &mockQuerier{ + queryFn: func(_ context.Context, _ string, _ ...any) (*sql.Rows, error) { + return nil, errIsStandby + }, + } + db3 := &mockQuerier{ + queryFn: func(_ context.Context, _ string, _ ...any) (*sql.Rows, error) { + return nil, errIsStandby + }, + } + + // register pools as nodes + node1 := hasql.NewNode("ololo", db1) + node2 := hasql.NewNode("trololo", db2) + node3 := hasql.NewNode("shimba", db3) + discoverer := hasql.NewStaticNodeDiscoverer(node1, node2, node3) + + // create test cluster + cl, err := hasql.NewCluster(discoverer, sentinelErrChecker, + hasql.WithUpdateInterval[*mockQuerier](50*time.Millisecond), + ) + require.NoError(t, err) + + // close cluster and all underlying pools in the end + defer func() { + assert.NoError(t, cl.Close()) + }() + + // wait for a long time + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + // wait for any alive node + waitNode, err := cl.WaitForNode(ctx, hasql.Alive) + assert.NoError(t, err) + assert.Contains(t, []*hasql.Node[*mockQuerier]{node1, node2, node3}, waitNode) + + // fetch nodes often + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + var primaryStates []string + for { + select { + case <-ctx.Done(): + // check that primary node has changed its state at least once + expected := []string{"ololo", "", "ololo", "", "ololo", "", "ololo", "", "ololo", ""} + assert.Equal(t, expected, slices.Compact(primaryStates)) + // end test + return + case <-ticker.C: + // pick primary node + primary := cl.Node(hasql.Primary) + // store current state for further checks + var name string + if primary != nil { + name = primary.String() + } + primaryStates = append(primaryStates, name) + + // pick and check standby node + standby := cl.Node(hasql.Standby) + assert.Contains(t, []*hasql.Node[*mockQuerier]{node2, node3}, standby) + } + } + +} + +var _ hasql.Querier = (*mockQuerier)(nil) +var _ io.Closer = (*mockQuerier)(nil) + +// mockQuerier returns fake SQL results to tests +type mockQuerier struct { + queryFn func(ctx context.Context, query string, args ...any) (*sql.Rows, error) + queryRowFn func(ctx context.Context, query string, args ...any) *sql.Row + closeFn func() error +} + +func (m *mockQuerier) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + if m.queryFn != nil { + return m.queryFn(ctx, query, args...) + } + return nil, nil +} + +func (m *mockQuerier) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + if m.queryRowFn != nil { + return m.queryRowFn(ctx, query, args...) 
+ } + return nil +} + +func (m *mockQuerier) Close() error { + if m.closeFn != nil { + return m.closeFn() + } + return nil +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..1a42ac4 --- /dev/null +++ b/error.go @@ -0,0 +1,69 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package hasql + +import ( + "strings" +) + +// NodeCheckErrors is a set of checked nodes errors. +// This type can be used in errors.As/Is as it implements errors.Unwrap method +type NodeCheckErrors[T Querier] []NodeCheckError[T] + +func (n NodeCheckErrors[T]) Error() string { + var b strings.Builder + for i, err := range n { + if i > 0 { + b.WriteByte('\n') + } + b.WriteString(err.Error()) + } + return b.String() +} + +// Unwrap is a helper for errors.Is/errors.As functions +func (n NodeCheckErrors[T]) Unwrap() []error { + errs := make([]error, len(n)) + for i, err := range n { + errs[i] = err + } + return errs +} + +// NodeCheckError implements `error` and contains information about unsuccessful node check +type NodeCheckError[T Querier] struct { + node *Node[T] + err error +} + +// Node returns dead node instance +func (n NodeCheckError[T]) Node() *Node[T] { + return n.node +} + +// Error implements `error` interface +func (n NodeCheckError[T]) Error() string { + if n.err == nil { + return "" + } + return n.err.Error() +} + +// Unwrap returns underlying error +func (n NodeCheckError[T]) Unwrap() error { + return n.err +} diff --git a/errors_collector.go b/errors_collector.go deleted file mode 100644 index def511e..0000000 --- a/errors_collector.go +++ /dev/null @@ -1,105 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package hasql - -import ( - "fmt" - "sort" - "strings" - "sync" - "time" -) - -// CollectedErrors are errors collected when checking node statuses -type CollectedErrors struct { - Errors []NodeError -} - -func (e *CollectedErrors) Error() string { - if len(e.Errors) == 1 { - return e.Errors[0].Error() - } - - errs := make([]string, len(e.Errors)) - for i, ne := range e.Errors { - errs[i] = ne.Error() - } - /* - I don't believe there exist 'best join separator' that fit all cases (cli output, JSON, .. etc), - so we use newline as error.Join did it. - In difficult cases (as suggested in https://github.com/yandex/go-hasql/pull/14), - the user should be able to receive "raw" errors and format them as it suits him. 
- */ - return strings.Join(errs, "\n") -} - -// NodeError is error that background goroutine got while check given node -type NodeError struct { - Addr string - Err error - OccurredAt time.Time -} - -func (e *NodeError) Error() string { - // 'foo.db' node error occurred at '2009-11-10..': FATAL: terminating connection due to ... - return fmt.Sprintf("%q node error occurred at %q: %s", e.Addr, e.OccurredAt, e.Err) -} - -type errorsCollector struct { - store map[string]NodeError - mu sync.Mutex -} - -func newErrorsCollector() errorsCollector { - return errorsCollector{store: make(map[string]NodeError)} -} - -func (e *errorsCollector) Add(addr string, err error, occurredAt time.Time) { - e.mu.Lock() - defer e.mu.Unlock() - - e.store[addr] = NodeError{ - Addr: addr, - Err: err, - OccurredAt: occurredAt, - } -} - -func (e *errorsCollector) Remove(addr string) { - e.mu.Lock() - defer e.mu.Unlock() - - delete(e.store, addr) -} - -func (e *errorsCollector) Err() error { - e.mu.Lock() - errList := make([]NodeError, 0, len(e.store)) - for _, nErr := range e.store { - errList = append(errList, nErr) - } - e.mu.Unlock() - - if len(errList) == 0 { - return nil - } - - sort.Slice(errList, func(i, j int) bool { - return errList[i].OccurredAt.Before(errList[j].OccurredAt) - }) - return &CollectedErrors{Errors: errList} -} diff --git a/errors_collector_test.go b/errors_collector_test.go deleted file mode 100644 index 4124e6a..0000000 --- a/errors_collector_test.go +++ /dev/null @@ -1,74 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package hasql - -import ( - "errors" - "fmt" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestErrorsCollector(t *testing.T) { - nodesCount := 10 - errCollector := newErrorsCollector() - require.NoError(t, errCollector.Err()) - - connErr := errors.New("node connection error") - occurredAt := time.Now() - - var wg sync.WaitGroup - wg.Add(nodesCount) - for i := 1; i <= nodesCount; i++ { - go func(i int) { - defer wg.Done() - errCollector.Add( - fmt.Sprintf("node-%d", i), - connErr, - occurredAt, - ) - }(i) - } - - errCollectDone := make(chan struct{}) - go func() { - for { - select { - case <-errCollectDone: - return - default: - // there are no assertions here, because that logic expected to run with -race, - // otherwise it doesn't test anything, just eat CPU. 
- _ = errCollector.Err() - } - } - }() - - wg.Wait() - close(errCollectDone) - - err := errCollector.Err() - for i := 1; i <= nodesCount; i++ { - assert.ErrorContains(t, err, fmt.Sprintf("\"node-%d\" node error occurred at", i)) - } - assert.ErrorContains(t, err, connErr.Error()) - -} diff --git a/example_cluster_test.go b/example_cluster_test.go deleted file mode 100644 index b6f6303..0000000 --- a/example_cluster_test.go +++ /dev/null @@ -1,133 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package hasql_test - -import ( - "context" - "database/sql" - "time" - - // This example assumes you use pgx driver, but you can use anything that supports database/sql - // _ "github.com/jackc/pgx/v4/stdlib" - - "golang.yandex/hasql" - "golang.yandex/hasql/checkers" -) - -func ExampleNewCluster() { - // cluster hosts - hosts := []struct { - Addr string - Connstring string - }{ - { - Addr: "host1.example.com", - Connstring: "host=host1.example.com", - }, - { - Addr: "host2.example.com", - Connstring: "host=host2.example.com", - }, - { - Addr: "host3.example.com", - Connstring: "host=host3.example.com", - }, - } - - // Construct cluster nodes - nodes := make([]hasql.Node, 0, len(hosts)) - for _, host := range hosts { - // Create database pools for each node - db, err := sql.Open("pgx", host.Connstring) - if err != nil { - panic(err) - } - nodes = append(nodes, hasql.NewNode(host.Addr, db)) - } - - // Use options to fine-tune cluster behavior - opts := []hasql.ClusterOption{ - hasql.WithUpdateInterval(2 * time.Second), // set custom update interval - hasql.WithNodePicker(hasql.PickNodeRoundRobin()), // set desired nodes selection algorithm - } - - // Create cluster handler - c, err := hasql.NewCluster(nodes, checkers.PostgreSQL, opts...) 
-	if err != nil {
-		panic(err)
-	}
-	defer func() { _ = c.Close() }() // close cluster when it is not needed
-
-	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-	defer cancel()
-
-	// Wait for current primary
-	node, err := c.WaitForPrimary(ctx)
-	if err != nil {
-		panic(err)
-	}
-
-	// Wait for any alive standby
-	node, err = c.WaitForStandby(ctx)
-	if err != nil {
-		panic(err)
-	}
-
-	// Wait for any alive node
-	node, err = c.WaitForAlive(ctx)
-	if err != nil {
-		panic(err)
-	}
-	// Wait for secondary node if possible, primary otherwise
-	node, err = c.WaitForNode(ctx, hasql.PreferStandby)
-	if err != nil {
-		panic(err)
-	}
-
-	// Retrieve current primary
-	node = c.Primary()
-	if node == nil {
-		panic("no primary")
-	}
-	// Retrieve any alive standby
-	node = c.Standby()
-	if node == nil {
-		panic("no standby")
-	}
-	// Retrieve any alive node
-	node = c.Alive()
-	if node == nil {
-		panic("everything is dead")
-	}
-
-	// Retrieve primary node if possible, secondary otherwise
-	node = c.Node(hasql.PreferPrimary)
-	if node == nil {
-		panic("no primary nor secondary")
-	}
-
-	// Retrieve secondary node if possible, primary otherwise
-	node = c.Node(hasql.PreferStandby)
-	if node == nil {
-		panic("no primary nor secondary")
-	}
-
-	// Do something on retrieved node
-	if err = node.DB().PingContext(ctx); err != nil {
-		panic(err)
-	}
-}
diff --git a/example_test.go b/example_test.go
new file mode 100644
index 0000000..c5b2130
--- /dev/null
+++ b/example_test.go
@@ -0,0 +1,93 @@
+/*
+   Copyright 2020 YANDEX LLC
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package hasql_test
+
+import (
+	"context"
+	"database/sql"
+	"time"
+
+	"golang.yandex/hasql/v2"
+)
+
+// ExampleCluster shows how to set up a basic hasql cluster with some custom settings
+func ExampleCluster() {
+	// open connections to database instances
+	db1, err := sql.Open("pgx", "host1.example.com")
+	if err != nil {
+		panic(err)
+	}
+	db2, err := sql.Open("pgx", "host2.example.com")
+	if err != nil {
+		panic(err)
+	}
+
+	// register connections as nodes with some additional information
+	nodes := []*hasql.Node[*sql.DB]{
+		hasql.NewNode("bear", db1),
+		hasql.NewNode("battlestar galactica", db2),
+	}
+
+	// create NodeDiscoverer instance
+	// here we use built-in StaticNodeDiscoverer which always returns all registered nodes
+	discoverer := hasql.NewStaticNodeDiscoverer(nodes...)
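One thing this example does not show: the error returned by `Cluster.Err` is the `NodeCheckErrors` type introduced in error.go above, so per-node failures can be unpacked with the standard `errors` helpers instead of string matching. A hedged sketch against the v2 API shown in this diff; the function name, logging, and surrounding file are this example's assumptions:

```go
package main

import (
	"database/sql"
	"errors"
	"log"

	"golang.yandex/hasql/v2"
)

// inspectClusterErr reports which nodes failed their most recent check.
// cl is assumed to be a running *hasql.Cluster[*sql.DB] built as in ExampleCluster.
func inspectClusterErr(cl *hasql.Cluster[*sql.DB]) {
	err := cl.Err()
	if err == nil {
		return
	}

	// NodeCheckErrors implements error and Unwrap() []error (see error.go above),
	// so errors.As can recover the typed collection from the returned error.
	var checkErrs hasql.NodeCheckErrors[*sql.DB]
	if errors.As(err, &checkErrs) {
		for _, ce := range checkErrs {
			// ce.Node() is the failed node, ce.Unwrap() the underlying cause
			log.Printf("node %s is dead: %v", ce.Node(), ce.Unwrap())
		}
	}
}
```

Because `NodeCheckErrors.Unwrap` returns the full slice of per-node errors, `errors.Is(cl.Err(), someSentinel)` also works directly, as the e2e tests above demonstrate with `io.EOF`.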
+	// use checker suitable for your database
+	checker := hasql.PostgreSQLChecker
+	// change default RandomNodePicker to RoundRobinNodePicker here
+	picker := new(hasql.RoundRobinNodePicker[*sql.DB])
+
+	// create cluster instance using previously created discoverer, checker
+	// and some additional options
+	cl, err := hasql.NewCluster(discoverer, checker,
+		// set custom picker via funcopt
+		hasql.WithNodePicker(picker),
+		// change interval of cluster state check
+		hasql.WithUpdateInterval[*sql.DB](500*time.Millisecond),
+		// change cluster check timeout value
+		hasql.WithUpdateTimeout[*sql.DB](time.Second),
+	)
+	if err != nil {
+		panic(err)
+	}
+
+	// create context with timeout to wait for at least one alive node in cluster
+	// note that context timeout value must be greater than cluster update interval + update timeout
+	// otherwise you will always receive `context.DeadlineExceeded` error if one of the cluster nodes is dead on startup
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	// wait for any alive node to guarantee that your application
+	// does not start without its database dependency
+	// this step is usually performed on application startup
+	node, err := cl.WaitForNode(ctx, hasql.Alive)
+	if err != nil {
+		panic(err)
+	}
+
+	// pick standby node to perform query
+	// always check node for nilness to avoid nil pointer dereference error
+	// node object can be nil if no alive node for the given criterion has been found
+	node = cl.Node(hasql.Standby)
+	if node == nil {
+		panic("no alive standby available")
+	}
+
+	// get connection from node and perform desired action
+	if err := node.DB().PingContext(ctx); err != nil {
+		panic(err)
+	}
+}
diff --git a/example_trace_test.go b/example_trace_test.go
deleted file mode 100644
index 05e042f..0000000
--- a/example_trace_test.go
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
-   Copyright 2020 YANDEX LLC
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
-*/ - -package hasql_test - -import ( - "context" - "database/sql" - "fmt" - "time" - - // This example assumes you use pgx driver, but you can use anything that supports database/sql - // _ "github.com/jackc/pgx/v4/stdlib" - - "golang.yandex/hasql" - "golang.yandex/hasql/checkers" -) - -func ExampleTracer() { - const hostname = "host=host1.example.com" - db, err := sql.Open("pgx", "host="+hostname) - if err != nil { - panic(err) - } - - nodes := []hasql.Node{hasql.NewNode(hostname, db)} - - tracer := hasql.Tracer{ - UpdateNodes: func() { - fmt.Println("Started updating nodes") - }, - UpdatedNodes: func(nodes hasql.AliveNodes) { - fmt.Printf("Finished updating nodes: %+v\n", nodes) - }, - NodeDead: func(node hasql.Node, err error) { - fmt.Printf("Node %q is dead: %s", node, err) - }, - NodeAlive: func(node hasql.Node) { - fmt.Printf("Node %q is alive", node) - }, - NotifiedWaiters: func() { - fmt.Println("Notified all waiters") - }, - } - - c, err := hasql.NewCluster(nodes, checkers.PostgreSQL, hasql.WithTracer(tracer)) - if err != nil { - panic(err) - } - defer func() { _ = c.Close() }() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - _, err = c.WaitForPrimary(ctx) - if err != nil { - panic(err) - } -} diff --git a/go.mod b/go.mod index 94a627d..f48f66e 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,9 @@ -module golang.yandex/hasql +module golang.yandex/hasql/v2 -go 1.18 +go 1.22 require ( - github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/gofrs/uuid v4.2.0+incompatible - github.com/jmoiron/sqlx v1.3.5 + github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/stretchr/testify v1.7.2 ) diff --git a/go.sum b/go.sum index 2d99413..54f288b 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,8 @@ -github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= -github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/gofrs/uuid v4.2.0+incompatible h1:yyYWMnhkhrKwwr8gAOcOCYxOOscHgDS9yZgBrnJfGa0= -github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= -github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= -github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= -github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git 
a/mocks_test.go b/mocks_test.go new file mode 100644 index 0000000..dd5e41a --- /dev/null +++ b/mocks_test.go @@ -0,0 +1,68 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package hasql + +import ( + "context" + "database/sql" + "io" + "slices" +) + +var _ NodeDiscoverer[*mockQuerier] = (*mockNodeDiscoverer[*mockQuerier])(nil) + +// mockNodeDiscoverer returns stored results to tests +type mockNodeDiscoverer[T Querier] struct { + nodes []*Node[T] + err error +} + +func (e mockNodeDiscoverer[T]) DiscoverNodes(_ context.Context) ([]*Node[T], error) { + return slices.Clone(e.nodes), e.err +} + +var _ Querier = (*mockQuerier)(nil) +var _ io.Closer = (*mockQuerier)(nil) + +// mockQuerier returns fake SQL results to tests +type mockQuerier struct { + name string + queryFn func(ctx context.Context, query string, args ...any) (*sql.Rows, error) + queryRowFn func(ctx context.Context, query string, args ...any) *sql.Row + closeFn func() error +} + +func (m *mockQuerier) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + if m.queryFn != nil { + return m.queryFn(ctx, query, args...) + } + return nil, nil +} + +func (m *mockQuerier) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + if m.queryRowFn != nil { + return m.queryRowFn(ctx, query, args...) 
+ } + return nil +} + +func (m *mockQuerier) Close() error { + if m.closeFn != nil { + return m.closeFn() + } + return nil +} diff --git a/node.go b/node.go index c763934..2bbc2de 100644 --- a/node.go +++ b/node.go @@ -16,69 +16,44 @@ package hasql -import ( - "context" - "database/sql" - "fmt" -) - -// Node of single cluster -type Node interface { - fmt.Stringer - - Addr() string - DB() *sql.DB -} +// NodeStateCriterion represents a node selection criterion +type NodeStateCriterion uint8 -type sqlNode struct { - addr string - db *sql.DB -} +const ( + // Alive is a criterion to choose any alive node + Alive NodeStateCriterion = iota + 1 + // Primary is a criterion to choose primary node + Primary + // Standby is a criterion to choose standby node + Standby + // PreferPrimary is a criterion to choose primary or any alive node + PreferPrimary + // PreferStandby is a criterion to choose standby or any alive node + PreferStandby -var _ Node = &sqlNode{} + // maxNodeCriterion is for testing purposes only + // all new criteria must be added above this constant + maxNodeCriterion +) -// NewNode constructs node from database/sql DB -func NewNode(addr string, db *sql.DB) Node { - return &sqlNode{addr: addr, db: db} +// Node holds reference to database connection pool with some additional data +type Node[T Querier] struct { + name string + db T } -// Addr returns node's address -func (n *sqlNode) Addr() string { - return n.addr +// NewNode constructs node with given SQL querier +func NewNode[T Querier](name string, db T) *Node[T] { + return &Node[T]{name: name, db: db} } -// DB returns node's database/sql DB -func (n *sqlNode) DB() *sql.DB { +// DB returns node's database connection +func (n *Node[T]) DB() T { return n.db } -// String implements Stringer -func (n *sqlNode) String() string { - return n.addr +// String implements Stringer. +// It uses name provided at construction to uniquely identify a single node +func (n *Node[T]) String() string { + return n.name } - -// NodeStateCriteria for choosing a node -type NodeStateCriteria int - -const ( - // Alive for choosing any alive node - Alive NodeStateCriteria = iota + 1 - // Primary for choosing primary node - Primary - // Standby for choosing standby node - Standby - // PreferPrimary for choosing primary or any alive node - PreferPrimary - // PreferStandby for choosing standby or any alive node - PreferStandby -) - -// NodeChecker is a signature for functions that check if specific node is alive and is primary. -// Returns true for primary and false if not. If error is returned, node is considered dead. -// Check function can be used to perform a query returning single boolean value that signals -// if node is primary or not. -type NodeChecker func(ctx context.Context, db *sql.DB) (bool, error) - -// NodePicker is a signature for functions that determine how to pick single node from set of nodes. -// Nodes passed to the picker function are sorted according to latency (from lowest to greatest). -type NodePicker func(nodes []Node) Node diff --git a/node_checker.go b/node_checker.go new file mode 100644 index 0000000..f536784 --- /dev/null +++ b/node_checker.go @@ -0,0 +1,189 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package hasql
+
+import (
+	"context"
+	"math"
+	"time"
+)
+
+// NodeRole represents role of node in SQL cluster (usually primary/standby)
+type NodeRole uint8
+
+const (
+	// NodeRoleUnknown used to report node with unconventional role in cluster
+	NodeRoleUnknown NodeRole = iota
+	// NodeRolePrimary used to report node with primary role in cluster
+	NodeRolePrimary
+	// NodeRoleStandby used to report node with standby role in cluster
+	NodeRoleStandby
+)
+
+// NodeInfoProvider provides information about a single cluster node
+type NodeInfoProvider interface {
+	// Role reports role of node in cluster.
+	// For SQL servers it is usually either primary or standby
+	Role() NodeRole
+}
+
+// NodeInfo implements NodeInfoProvider with additional useful information
+var _ NodeInfoProvider = NodeInfo{}
+
+// NodeInfo contains various information about single cluster node
+type NodeInfo struct {
+	// ClusterRole contains the node's determined role in the cluster
+	ClusterRole NodeRole
+	// NetworkLatency stores the time spent sending the check request
+	// and receiving the response from the server
+	NetworkLatency time.Duration
+	// ReplicaLag represents how far the data on a standby is behind the
+	// primary. As determining real replication lag is a tricky task and the
+	// value type varies from one DBMS to another (e.g. bytes count lag,
+	// time delta lag etc.), this field contains an abstract value for
+	// sorting purposes only
+	ReplicaLag int
+}
+
+// Role reports determined role of node in cluster
+func (n NodeInfo) Role() NodeRole {
+	return n.ClusterRole
+}
+
+// Latency reports time spent on query execution from client's point of view.
+// It can be used in LatencyNodePicker to determine node with fastest response time
+func (n NodeInfo) Latency() time.Duration {
+	return n.NetworkLatency
+}
+
+// ReplicationLag reports data replication delta on standby.
+// It can be used in ReplicationNodePicker to determine node with most up-to-date data
+func (n NodeInfo) ReplicationLag() int {
+	return n.ReplicaLag
+}
+
+// NodeChecker is a function that can perform a request to a SQL node and retrieve various information
+type NodeChecker func(context.Context, Querier) (NodeInfoProvider, error)
+
+// PostgreSQLChecker checks the state of a PostgreSQL node.
+// It reports appropriate information for PostgreSQL version 10 and higher
+func PostgreSQLChecker(ctx context.Context, db Querier) (NodeInfoProvider, error) {
+	start := time.Now()
+
+	var role NodeRole
+	var replicationLag *int
+	err := db.
+		QueryRowContext(ctx, `
+			SELECT
+				((pg_is_in_recovery())::int + 1) AS role,
+				pg_last_wal_receive_lsn() - pg_last_wal_replay_lsn() AS replication_lag
+			;
+		`).
+ Scan(&role, &replicationLag) + if err != nil { + return nil, err + } + + latency := time.Since(start) + + // determine proper replication lag value + // by default we assume that replication is not started - hence maximum int value + // see: https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-RECOVERY-CONTROL + lag := math.MaxInt + if replicationLag != nil { + // use reported non-null replication lag + lag = *replicationLag + } + if role == NodeRolePrimary { + // primary node has no replication lag + lag = 0 + } + + return NodeInfo{ + ClusterRole: role, + NetworkLatency: latency, + ReplicaLag: lag, + }, nil +} + +// MySQLChecker checks state of MySQL node. +// ATTENTION: database user must have REPLICATION CLIENT privilege to perform underlying query. +func MySQLChecker(ctx context.Context, db Querier) (NodeInfoProvider, error) { + start := time.Now() + + rows, err := db.QueryContext(ctx, "SHOW SLAVE STATUS") + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + latency := time.Since(start) + + // only standby MySQL server will return rows for `SHOW SLAVE STATUS` query + isStandby := rows.Next() + // TODO: check SECONDS_BEHIND_MASTER row for "replication lag" + + if err := rows.Err(); err != nil { + return nil, err + } + + role := NodeRoleStandby + lag := math.MaxInt + if !isStandby { + role = NodeRolePrimary + lag = 0 + } + + return NodeInfo{ + ClusterRole: role, + NetworkLatency: latency, + ReplicaLag: lag, + }, nil +} + +// MSSQLChecker checks state of MSSQL node +func MSSQLChecker(ctx context.Context, db Querier) (NodeInfoProvider, error) { + start := time.Now() + + var isPrimary bool + err := db. + QueryRowContext(ctx, ` + SELECT + IIF(count(database_guid) = 0, 'TRUE', 'FALSE') AS STATUS + FROM sys.database_recovery_status + WHERE database_guid IS NULL + `). + Scan(&isPrimary) + if err != nil { + return nil, err + } + + latency := time.Since(start) + role := NodeRoleStandby + // TODO: proper replication lag calculation + lag := math.MaxInt + if isPrimary { + role = NodeRolePrimary + lag = 0 + } + + return NodeInfo{ + ClusterRole: role, + NetworkLatency: latency, + ReplicaLag: lag, + }, nil +} diff --git a/node_discoverer.go b/node_discoverer.go new file mode 100644 index 0000000..60c259d --- /dev/null +++ b/node_discoverer.go @@ -0,0 +1,47 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package hasql + +import ( + "context" + "database/sql" +) + +// NodeDiscoverer represents a provider of cluster nodes list. 
+// NodeDiscoverer must node check nodes liveness or role, just return all nodes registered in cluster +type NodeDiscoverer[T Querier] interface { + // DiscoverNodes returns list of nodes registered in cluster + DiscoverNodes(context.Context) ([]*Node[T], error) +} + +// StaticNodeDiscoverer implements NodeDiscoverer +var _ NodeDiscoverer[*sql.DB] = (*StaticNodeDiscoverer[*sql.DB])(nil) + +// StaticNodeDiscoverer returns always returns list of provided nodes +type StaticNodeDiscoverer[T Querier] struct { + nodes []*Node[T] +} + +// NewStaticNodeDiscoverer returns new staticNodeDiscoverer instance +func NewStaticNodeDiscoverer[T Querier](nodes ...*Node[T]) StaticNodeDiscoverer[T] { + return StaticNodeDiscoverer[T]{nodes: nodes} +} + +// DiscoverNodes returns static list of nodes from StaticNodeDiscoverer +func (s StaticNodeDiscoverer[T]) DiscoverNodes(_ context.Context) ([]*Node[T], error) { + return s.nodes, nil +} diff --git a/sqlx/node_test.go b/node_discoverer_test.go similarity index 50% rename from sqlx/node_test.go rename to node_discoverer_test.go index afecbb1..78373a5 100644 --- a/sqlx/node_test.go +++ b/node_discoverer_test.go @@ -17,35 +17,32 @@ package hasql import ( - "errors" + "context" + "database/sql" "testing" - "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" ) -func Test_uncheckedSQLxNode(t *testing.T) { - assert.Nil(t, uncheckedSQLxNode(nil)) +func TestNewStaticNodeDiscoverer(t *testing.T) { + node1 := NewNode("shimba", new(sql.DB)) + node2 := NewNode("boomba", new(sql.DB)) - expected := NewNode("foo", &sqlx.DB{}) - assert.Equal(t, expected, uncheckedSQLxNode(expected)) -} + d := NewStaticNodeDiscoverer(node1, node2) + expected := StaticNodeDiscoverer[*sql.DB]{ + nodes: []*Node[*sql.DB]{node1, node2}, + } -func Test_checkedSQLxNode(t *testing.T) { - node, err := checkedSQLxNode(nil, errors.New("err")) - assert.Error(t, err) - assert.Nil(t, node) + assert.Equal(t, expected, d) +} - node, err = checkedSQLxNode(NewNode("foo", &sqlx.DB{}), errors.New("err")) - assert.Error(t, err) - assert.Nil(t, node) +func TestStaticNodeDiscoverer_DiscoverNodes(t *testing.T) { + node1 := NewNode("shimba", new(sql.DB)) + node2 := NewNode("boomba", new(sql.DB)) - node, err = checkedSQLxNode(nil, nil) - assert.NoError(t, err) - assert.Nil(t, node) + d := NewStaticNodeDiscoverer(node1, node2) - expected := NewNode("foo", &sqlx.DB{}) - node, err = checkedSQLxNode(expected, nil) + discovered, err := d.DiscoverNodes(context.Background()) assert.NoError(t, err) - assert.Equal(t, expected, node) + assert.Equal(t, []*Node[*sql.DB]{node1, node2}, discovered) } diff --git a/node_picker.go b/node_picker.go new file mode 100644 index 0000000..5c920f9 --- /dev/null +++ b/node_picker.go @@ -0,0 +1,128 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package hasql + +import ( + "database/sql" + "math/rand/v2" + "sync/atomic" + "time" +) + +// NodePicker decides which node must be used from given set. 
+// It also provides a comparer to be used to pre-sort nodes for better performance +type NodePicker[T Querier] interface { + // PickNode returns a single node from given set + PickNode(nodes []CheckedNode[T]) CheckedNode[T] + // CompareNodes is a comparison function to be used to sort checked nodes + CompareNodes(a, b CheckedNode[T]) int +} + +// RandomNodePicker implements NodePicker +var _ NodePicker[*sql.DB] = (*RandomNodePicker[*sql.DB])(nil) + +// RandomNodePicker returns random node on each call and does not sort checked nodes +type RandomNodePicker[T Querier] struct{} + +// PickNode returns random node from picker +func (*RandomNodePicker[T]) PickNode(nodes []CheckedNode[T]) CheckedNode[T] { + return nodes[rand.IntN(len(nodes))] +} + +// CompareNodes always treats nodes as equal, effectively not changing nodes order +func (*RandomNodePicker[T]) CompareNodes(_, _ CheckedNode[T]) int { + return 0 +} + +// RoundRobinNodePicker implements NodePicker +var _ NodePicker[*sql.DB] = (*RoundRobinNodePicker[*sql.DB])(nil) + +// RoundRobinNodePicker returns next node based on Round Robin algorithm and tries to preserve nodes order across checks +type RoundRobinNodePicker[T Querier] struct { + idx uint32 +} + +// PickNode returns next node in Round-Robin sequence +func (r *RoundRobinNodePicker[T]) PickNode(nodes []CheckedNode[T]) CheckedNode[T] { + n := atomic.AddUint32(&r.idx, 1) + return nodes[(int(n)-1)%len(nodes)] +} + +// CompareNodes performs lexicographical comparison of two nodes +func (r *RoundRobinNodePicker[T]) CompareNodes(a, b CheckedNode[T]) int { + aName, bName := a.Node.String(), b.Node.String() + if aName < bName { + return -1 + } + if aName > bName { + return 1 + } + return 0 +} + +// LatencyNodePicker implements NodePicker +var _ NodePicker[*sql.DB] = (*LatencyNodePicker[*sql.DB])(nil) + +// LatencyNodePicker returns node with least latency and sorts checked nodes by reported latency ascending. +// WARNING: This picker requires that NodeInfoProvider can report node's network latency otherwise code will panic! +type LatencyNodePicker[T Querier] struct{} + +// PickNode returns node with least network latency +func (*LatencyNodePicker[T]) PickNode(nodes []CheckedNode[T]) CheckedNode[T] { + return nodes[0] +} + +// CompareNodes performs nodes comparison based on reported network latency +func (*LatencyNodePicker[T]) CompareNodes(a, b CheckedNode[T]) int { + aLatency := a.Info.(interface{ Latency() time.Duration }).Latency() + bLatency := b.Info.(interface{ Latency() time.Duration }).Latency() + + if aLatency < bLatency { + return -1 + } + if aLatency > bLatency { + return 1 + } + return 0 +} + +// ReplicationNodePicker implements NodePicker +var _ NodePicker[*sql.DB] = (*ReplicationNodePicker[*sql.DB])(nil) + +// ReplicationNodePicker returns node with smallest replication lag and sorts checked nodes by reported replication lag ascending. +// Note that replication lag reported by checkers can vastly differ from the real situation on standby server. +// WARNING: This picker requires that NodeInfoProvider can report node's replication lag otherwise code will panic! 
+type ReplicationNodePicker[T Querier] struct{} + +// PickNode returns node with lowest replication lag value +func (*ReplicationNodePicker[T]) PickNode(nodes []CheckedNode[T]) CheckedNode[T] { + return nodes[0] +} + +// CompareNodes performs nodes comparison based on reported replication lag +func (*ReplicationNodePicker[T]) CompareNodes(a, b CheckedNode[T]) int { + aLag := a.Info.(interface{ ReplicationLag() int }).ReplicationLag() + bLag := b.Info.(interface{ ReplicationLag() int }).ReplicationLag() + + if aLag < bLag { + return -1 + } + if aLag > bLag { + return 1 + } + return 0 +} diff --git a/node_picker_test.go b/node_picker_test.go new file mode 100644 index 0000000..44e7907 --- /dev/null +++ b/node_picker_test.go @@ -0,0 +1,273 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package hasql + +import ( + "database/sql" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRandomNodePicker(t *testing.T) { + t.Run("pick_node", func(t *testing.T) { + nodes := []CheckedNode[*sql.DB]{ + { + Node: NewNode("shimba", (*sql.DB)(nil)), + }, + { + Node: NewNode("boomba", (*sql.DB)(nil)), + }, + { + Node: NewNode("looken", (*sql.DB)(nil)), + }, + } + + np := new(RandomNodePicker[*sql.DB]) + + pickedNodes := make(map[string]struct{}) + for range 100 { + pickedNodes[np.PickNode(nodes).Node.String()] = struct{}{} + } + expectedNodes := map[string]struct{}{"boomba": {}, "looken": {}, "shimba": {}} + + assert.Equal(t, expectedNodes, pickedNodes) + }) + + t.Run("compare_nodes", func(t *testing.T) { + a := CheckedNode[*sql.DB]{ + Node: NewNode("shimba", (*sql.DB)(nil)), + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + NetworkLatency: 10 * time.Millisecond, + ReplicaLag: 1, + }, + } + + b := CheckedNode[*sql.DB]{ + Node: NewNode("boomba", (*sql.DB)(nil)), + Info: NodeInfo{ + ClusterRole: NodeRoleStandby, + NetworkLatency: 20 * time.Millisecond, + ReplicaLag: 2, + }, + } + + np := new(RandomNodePicker[*sql.DB]) + + for _, nodeA := range []CheckedNode[*sql.DB]{a, b} { + for _, nodeB := range []CheckedNode[*sql.DB]{a, b} { + assert.Equal(t, 0, np.CompareNodes(nodeA, nodeB)) + } + } + }) +} + +func TestRoundRobinNodePicker(t *testing.T) { + t.Run("pick_node", func(t *testing.T) { + nodes := []CheckedNode[*sql.DB]{ + { + Node: NewNode("shimba", (*sql.DB)(nil)), + }, + { + Node: NewNode("boomba", (*sql.DB)(nil)), + }, + { + Node: NewNode("looken", (*sql.DB)(nil)), + }, + { + Node: NewNode("tooken", (*sql.DB)(nil)), + }, + { + Node: NewNode("chicken", (*sql.DB)(nil)), + }, + { + Node: NewNode("cooken", (*sql.DB)(nil)), + }, + } + + np := new(RoundRobinNodePicker[*sql.DB]) + + var pickedNodes []string + for range len(nodes) * 3 { + pickedNodes = append(pickedNodes, np.PickNode(nodes).Node.String()) + } + + expectedNodes := []string{ + "shimba", "boomba", "looken", "tooken", "chicken", "cooken", + "shimba", "boomba", "looken", "tooken", "chicken", "cooken", + "shimba", "boomba", "looken", "tooken", "chicken", "cooken", + } + assert.Equal(t, expectedNodes, 
pickedNodes) + }) + + t.Run("compare_nodes", func(t *testing.T) { + a := CheckedNode[*sql.DB]{ + Node: NewNode("shimba", (*sql.DB)(nil)), + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + NetworkLatency: 10 * time.Millisecond, + ReplicaLag: 1, + }, + } + + b := CheckedNode[*sql.DB]{ + Node: NewNode("boomba", (*sql.DB)(nil)), + Info: NodeInfo{ + ClusterRole: NodeRoleStandby, + NetworkLatency: 20 * time.Millisecond, + ReplicaLag: 2, + }, + } + + np := new(RoundRobinNodePicker[*sql.DB]) + + assert.Equal(t, 1, np.CompareNodes(a, b)) + assert.Equal(t, -1, np.CompareNodes(b, a)) + assert.Equal(t, 0, np.CompareNodes(a, a)) + assert.Equal(t, 0, np.CompareNodes(b, b)) + }) +} + +func TestLatencyNodePicker(t *testing.T) { + t.Run("pick_node", func(t *testing.T) { + nodes := []CheckedNode[*sql.DB]{ + { + Node: NewNode("shimba", (*sql.DB)(nil)), + }, + { + Node: NewNode("boomba", (*sql.DB)(nil)), + }, + { + Node: NewNode("looken", (*sql.DB)(nil)), + }, + { + Node: NewNode("tooken", (*sql.DB)(nil)), + }, + { + Node: NewNode("chicken", (*sql.DB)(nil)), + }, + { + Node: NewNode("cooken", (*sql.DB)(nil)), + }, + } + + np := new(LatencyNodePicker[*sql.DB]) + + pickedNodes := make(map[string]struct{}) + for range 100 { + pickedNodes[np.PickNode(nodes).Node.String()] = struct{}{} + } + + expectedNodes := map[string]struct{}{ + "shimba": {}, + } + assert.Equal(t, expectedNodes, pickedNodes) + }) + + t.Run("compare_nodes", func(t *testing.T) { + a := CheckedNode[*sql.DB]{ + Node: NewNode("shimba", (*sql.DB)(nil)), + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + NetworkLatency: 10 * time.Millisecond, + ReplicaLag: 1, + }, + } + + b := CheckedNode[*sql.DB]{ + Node: NewNode("boomba", (*sql.DB)(nil)), + Info: NodeInfo{ + ClusterRole: NodeRoleStandby, + NetworkLatency: 20 * time.Millisecond, + ReplicaLag: 2, + }, + } + + np := new(LatencyNodePicker[*sql.DB]) + + assert.Equal(t, -1, np.CompareNodes(a, b)) + assert.Equal(t, 1, np.CompareNodes(b, a)) + assert.Equal(t, 0, np.CompareNodes(a, a)) + assert.Equal(t, 0, np.CompareNodes(b, b)) + }) +} + +func TestReplicationNodePicker(t *testing.T) { + t.Run("pick_node", func(t *testing.T) { + nodes := []CheckedNode[*sql.DB]{ + { + Node: NewNode("shimba", (*sql.DB)(nil)), + }, + { + Node: NewNode("boomba", (*sql.DB)(nil)), + }, + { + Node: NewNode("looken", (*sql.DB)(nil)), + }, + { + Node: NewNode("tooken", (*sql.DB)(nil)), + }, + { + Node: NewNode("chicken", (*sql.DB)(nil)), + }, + { + Node: NewNode("cooken", (*sql.DB)(nil)), + }, + } + + np := new(ReplicationNodePicker[*sql.DB]) + + pickedNodes := make(map[string]struct{}) + for range 100 { + pickedNodes[np.PickNode(nodes).Node.String()] = struct{}{} + } + + expectedNodes := map[string]struct{}{ + "shimba": {}, + } + assert.Equal(t, expectedNodes, pickedNodes) + }) + + t.Run("compare_nodes", func(t *testing.T) { + a := CheckedNode[*sql.DB]{ + Node: NewNode("shimba", (*sql.DB)(nil)), + Info: NodeInfo{ + ClusterRole: NodeRolePrimary, + NetworkLatency: 10 * time.Millisecond, + ReplicaLag: 1, + }, + } + + b := CheckedNode[*sql.DB]{ + Node: NewNode("boomba", (*sql.DB)(nil)), + Info: NodeInfo{ + ClusterRole: NodeRoleStandby, + NetworkLatency: 20 * time.Millisecond, + ReplicaLag: 2, + }, + } + + np := new(ReplicationNodePicker[*sql.DB]) + + assert.Equal(t, -1, np.CompareNodes(a, b)) + assert.Equal(t, 1, np.CompareNodes(b, a)) + assert.Equal(t, 0, np.CompareNodes(a, a)) + assert.Equal(t, 0, np.CompareNodes(b, b)) + }) +} diff --git a/node_pickers.go b/node_pickers.go deleted file mode 100644 index 3d15689..0000000 --- 
a/node_pickers.go +++ /dev/null @@ -1,45 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package hasql - -import ( - "math/rand" - "sync/atomic" -) - -// PickNodeRandom returns random node from nodes set -func PickNodeRandom() NodePicker { - return func(nodes []Node) Node { - return nodes[rand.Intn(len(nodes))] - } -} - -// PickNodeRoundRobin returns next node based on Round Robin algorithm -func PickNodeRoundRobin() NodePicker { - var nodeIdx uint32 - return func(nodes []Node) Node { - n := atomic.AddUint32(&nodeIdx, 1) - return nodes[(int(n)-1)%len(nodes)] - } -} - -// PickNodeClosest returns node with least latency -func PickNodeClosest() NodePicker { - return func(nodes []Node) Node { - return nodes[0] - } -} diff --git a/node_pickers_test.go b/node_pickers_test.go deleted file mode 100644 index af99aa5..0000000 --- a/node_pickers_test.go +++ /dev/null @@ -1,78 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ - -package hasql - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestRandom(t *testing.T) { - nodes := []Node{ - NewNode("shimba", nil), - NewNode("boomba", nil), - NewNode("looken", nil), - } - - rr := PickNodeRandom() - - pickedNodes := make(map[string]struct{}) - for i := 0; i < 100; i++ { - pickedNodes[rr(nodes).Addr()] = struct{}{} - } - expectedNodes := map[string]struct{}{"boomba": {}, "looken": {}, "shimba": {}} - - assert.Equal(t, expectedNodes, pickedNodes) -} - -func TestPickNodeRoundRobin(t *testing.T) { - nodes := []Node{ - NewNode("shimba", nil), - NewNode("boomba", nil), - NewNode("looken", nil), - NewNode("tooken", nil), - NewNode("chicken", nil), - NewNode("cooken", nil), - } - iterCount := len(nodes) * 3 - - rr := PickNodeRoundRobin() - - var pickedNodes []string - for i := 0; i < iterCount; i++ { - pickedNodes = append(pickedNodes, rr(nodes).Addr()) - } - - expectedNodes := []string{ - "shimba", "boomba", "looken", "tooken", "chicken", "cooken", - "shimba", "boomba", "looken", "tooken", "chicken", "cooken", - "shimba", "boomba", "looken", "tooken", "chicken", "cooken", - } - assert.Equal(t, expectedNodes, pickedNodes) -} - -func TestClosest(t *testing.T) { - nodes := []Node{ - NewNode("shimba", nil), - NewNode("boomba", nil), - NewNode("looken", nil), - } - - rr := PickNodeClosest() - assert.Equal(t, nodes[0], rr(nodes)) -} diff --git a/notify.go b/notify.go new file mode 100644 index 0000000..d913951 --- /dev/null +++ b/notify.go @@ -0,0 +1,62 @@ +/* + Copyright 2020 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package hasql + +// updateSubscriber represents a waiter for newly checked node event +type updateSubscriber[T Querier] struct { + ch chan *Node[T] + criterion NodeStateCriterion +} + +// addUpdateSubscriber adds new dubscriber to notification pool +func (cl *Cluster[T]) addUpdateSubscriber(criterion NodeStateCriterion) <-chan *Node[T] { + // buffered channel is essential + // read WaitForNode function for more information + ch := make(chan *Node[T], 1) + cl.subscribersMu.Lock() + defer cl.subscribersMu.Unlock() + cl.subscribers = append(cl.subscribers, updateSubscriber[T]{ch: ch, criterion: criterion}) + return ch +} + +// notifyUpdateSubscribers sends appropriate nodes to registered subsribers. 
+// This function uses newly checked nodes to avoid race conditions +func (cl *Cluster[T]) notifyUpdateSubscribers(nodes CheckedNodes[T]) { + cl.subscribersMu.Lock() + defer cl.subscribersMu.Unlock() + + if len(cl.subscribers) == 0 { + return + } + + var nodelessWaiters []updateSubscriber[T] + // Notify all waiters + for _, subscriber := range cl.subscribers { + node := pickNodeByCriterion(nodes, cl.picker, subscriber.criterion) + if node == nil { + // Put waiter back + nodelessWaiters = append(nodelessWaiters, subscriber) + continue + } + + // We won't block here, read addUpdateWaiter function for more information + subscriber.ch <- node + // No need to close channel since we write only once and forget it so does the 'client' + } + + cl.subscribers = nodelessWaiters +} diff --git a/checkers/mysql.go b/sql.go similarity index 56% rename from checkers/mysql.go rename to sql.go index ed3c7bd..c74a3ad 100644 --- a/checkers/mysql.go +++ b/sql.go @@ -14,21 +14,18 @@ limitations under the License. */ -package checkers +package hasql import ( "context" "database/sql" ) -// MySQL checks whether MySQL server is primary or not. -// ATTENTION: database user must have REPLICATION CLIENT privilege to perform underlying query. -func MySQL(ctx context.Context, db *sql.DB) (bool, error) { - rows, err := db.QueryContext(ctx, "SHOW SLAVE STATUS") - if err != nil { - return false, err - } - defer func() { _ = rows.Close() }() - hasRows := rows.Next() - return !hasRows, rows.Err() +// Querier describes abstract base SQL client such as database/sql.DB. +// Most of database/sql compatible third-party libraries already implement it +type Querier interface { + // QueryRowContext executes a query that is expected to return at most one row + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row + // QueryContext executes a query that returns rows, typically a SELECT + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } diff --git a/sqlx/cluster.go b/sqlx/cluster.go deleted file mode 100644 index 7862e30..0000000 --- a/sqlx/cluster.go +++ /dev/null @@ -1,105 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package hasql - -import ( - "context" - - "golang.yandex/hasql" -) - -// Cluster consists of number of 'nodes' of a single SQL database. -// Background goroutine periodically checks nodes and updates their status. -type Cluster struct { - *hasql.Cluster -} - -// NewCluster constructs cluster object representing a single 'cluster' of SQL database. -// Close function must be called when cluster is not needed anymore. -func NewCluster(nodes []Node, checker NodeChecker, opts ...ClusterOption) (*Cluster, error) { - sqlNodes := make([]hasql.Node, 0, len(nodes)) - for _, n := range nodes { - sqlNodes = append(sqlNodes, n) - } - - cl, err := hasql.NewCluster(sqlNodes, checker, opts...) 
- if err != nil { - return nil, err - } - - return &Cluster{Cluster: cl}, nil -} - -// WaitForAlive node to appear or until context is canceled -func (cl *Cluster) WaitForAlive(ctx context.Context) (Node, error) { - return checkedSQLxNode(cl.Cluster.WaitForAlive(ctx)) -} - -// WaitForPrimary node to appear or until context is canceled -func (cl *Cluster) WaitForPrimary(ctx context.Context) (Node, error) { - return checkedSQLxNode(cl.Cluster.WaitForPrimary(ctx)) -} - -// WaitForStandby node to appear or until context is canceled -func (cl *Cluster) WaitForStandby(ctx context.Context) (Node, error) { - return checkedSQLxNode(cl.Cluster.WaitForStandby(ctx)) -} - -// WaitForPrimaryPreferred node to appear or until context is canceled -func (cl *Cluster) WaitForPrimaryPreferred(ctx context.Context) (Node, error) { - return checkedSQLxNode(cl.Cluster.WaitForPrimaryPreferred(ctx)) -} - -// WaitForStandbyPreferred node to appear or until context is canceled -func (cl *Cluster) WaitForStandbyPreferred(ctx context.Context) (Node, error) { - return checkedSQLxNode(cl.Cluster.WaitForStandbyPreferred(ctx)) -} - -// WaitForNode with specified status to appear or until context is canceled -func (cl *Cluster) WaitForNode(ctx context.Context, criteria NodeStateCriteria) (Node, error) { - return checkedSQLxNode(cl.Cluster.WaitForNode(ctx, criteria)) -} - -// Alive returns node that is considered alive -func (cl *Cluster) Alive() Node { - return uncheckedSQLxNode(cl.Cluster.Alive()) -} - -// Primary returns first available node that is considered alive and is primary (able to execute write operations) -func (cl *Cluster) Primary() Node { - return uncheckedSQLxNode(cl.Cluster.Primary()) -} - -// Standby returns node that is considered alive and is standby (unable to execute write operations) -func (cl *Cluster) Standby() Node { - return uncheckedSQLxNode(cl.Cluster.Standby()) -} - -// PrimaryPreferred returns primary node if possible, standby otherwise -func (cl *Cluster) PrimaryPreferred() Node { - return uncheckedSQLxNode(cl.Cluster.PrimaryPreferred()) -} - -// StandbyPreferred returns standby node if possible, primary otherwise -func (cl *Cluster) StandbyPreferred() Node { - return uncheckedSQLxNode(cl.Cluster.StandbyPreferred()) -} - -// Node returns cluster node with specified status. -func (cl *Cluster) Node(criteria NodeStateCriteria) Node { - return uncheckedSQLxNode(cl.Cluster.Node(criteria)) -} diff --git a/sqlx/cluster_test.go b/sqlx/cluster_test.go deleted file mode 100644 index 5aaac7c..0000000 --- a/sqlx/cluster_test.go +++ /dev/null @@ -1,56 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ - -package hasql - -import ( - "context" - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/jmoiron/sqlx" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewCluster(t *testing.T) { - checker := func(_ context.Context, _ *sql.DB) (bool, error) { return false, nil } - - t.Run("Works", func(t *testing.T) { - db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) - require.NotNil(t, db) - - mock.ExpectClose() - defer func() { assert.NoError(t, mock.ExpectationsWereMet()) }() - - node := NewNode("fake.addr", sqlx.NewDb(db, "sqlmock")) - cl, err := NewCluster([]Node{node}, checker) - require.NoError(t, err) - require.NotNil(t, cl) - defer func() { require.NoError(t, cl.Close()) }() - - require.Len(t, cl.Nodes(), 1) - require.Equal(t, node, cl.Nodes()[0]) - }) - - t.Run("Fails", func(t *testing.T) { - cl, err := NewCluster(nil, checker) - require.Error(t, err) - require.Nil(t, cl) - }) -} diff --git a/sqlx/forward.go b/sqlx/forward.go deleted file mode 100644 index cdb7663..0000000 --- a/sqlx/forward.go +++ /dev/null @@ -1,68 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package hasql - -import "golang.yandex/hasql" - -type ( - // ClusterOption is a functional option type for Cluster constructor - ClusterOption = hasql.ClusterOption - // NodeStateCriteria for choosing a node - NodeStateCriteria = hasql.NodeStateCriteria - // NodeChecker is a signature for functions that check if specific node is alive and is primary. - // Returns true for primary and false if not. If error is returned, node is considered dead. - // Check function can be used to perform a query returning single boolean value that signals - // if node is primary or not. - NodeChecker = hasql.NodeChecker - // NodePicker is a signature for functions that determine how to pick single node from set of nodes. - // Nodes passed to the picker function are sorted according to latency (from lowest to greatest). - NodePicker = hasql.NodePicker - // AliveNodes of Cluster - AliveNodes = hasql.AliveNodes - // Tracer is a set of hooks to run at various stages of background nodes status update. - // Any particular hook may be nil. Functions may be called concurrently from different goroutines. - Tracer = hasql.Tracer -) - -var ( - // Alive for choosing any alive node - Alive = hasql.Alive - // Primary for choosing primary node - Primary = hasql.Primary - // Standby for choosing standby node - Standby = hasql.Standby - // PreferPrimary for choosing primary or any alive node - PreferPrimary = hasql.PreferPrimary - // PreferStandby for choosing standby or any alive node - PreferStandby = hasql.PreferStandby - - // WithUpdateInterval sets interval between cluster node updates - WithUpdateInterval = hasql.WithUpdateInterval - // WithUpdateTimeout sets ping timeout for update of each node in cluster - WithUpdateTimeout = hasql.WithUpdateTimeout - // WithNodePicker sets algorithm for node selection (e.g. 
random, round robin etc) - WithNodePicker = hasql.WithNodePicker - // WithTracer sets tracer for actions happening in the background - WithTracer = hasql.WithTracer - - // PickNodeRandom returns random node from nodes set - PickNodeRandom = hasql.PickNodeRandom - // PickNodeRoundRobin returns next node based on Round Robin algorithm - PickNodeRoundRobin = hasql.PickNodeRoundRobin - // PickNodeClosest returns node with least latency - PickNodeClosest = hasql.PickNodeClosest -) diff --git a/sqlx/node.go b/sqlx/node.go deleted file mode 100644 index eeaacd2..0000000 --- a/sqlx/node.go +++ /dev/null @@ -1,83 +0,0 @@ -/* - Copyright 2020 YANDEX LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package hasql - -import ( - "database/sql" - - "github.com/jmoiron/sqlx" - - "golang.yandex/hasql" -) - -// Node of single cluster -type Node interface { - hasql.Node - - DBx() *sqlx.DB -} - -type sqlxNode struct { - addr string - dbx *sqlx.DB -} - -var _ Node = &sqlxNode{} - -// NewNode constructs node from sqlx.DB -func NewNode(addr string, db *sqlx.DB) Node { - return &sqlxNode{ - addr: addr, - dbx: db, - } -} - -// Addr returns node's address -func (n *sqlxNode) Addr() string { - return n.addr -} - -// DB returns node's database/sql DB -func (n *sqlxNode) DB() *sql.DB { - return n.dbx.DB -} - -// DBx returns node's sqlx.DB -func (n *sqlxNode) DBx() *sqlx.DB { - return n.dbx -} - -// String implements Stringer -func (n *sqlxNode) String() string { - return n.addr -} - -func uncheckedSQLxNode(node hasql.Node) Node { - if node == nil { - return nil - } - - return node.(*sqlxNode) -} - -func checkedSQLxNode(node hasql.Node, err error) (Node, error) { - if err != nil { - return nil, err - } - - return uncheckedSQLxNode(node), nil -} diff --git a/trace.go b/trace.go index 23ab70f..c2555c0 100644 --- a/trace.go +++ b/trace.go @@ -16,17 +16,17 @@ package hasql -// Tracer is a set of hooks to run at various stages of background nodes status update. +// Tracer is a set of hooks to be called at various stages of background nodes status update. // Any particular hook may be nil. Functions may be called concurrently from different goroutines. -type Tracer struct { +type Tracer[T Querier] struct { // UpdateNodes is called when before updating nodes status. UpdateNodes func() - // UpdatedNodes is called after all nodes are updated. The nodes is a list of currently alive nodes. - UpdatedNodes func(nodes AliveNodes) + // NodesUpdated is called after all nodes are updated. The nodes is a list of currently alive nodes. + NodesUpdated func(nodes CheckedNodes[T]) // NodeDead is called when it is determined that specified node is dead. - NodeDead func(node Node, err error) + NodeDead func(err error) // NodeAlive is called when it is determined that specified node is alive. - NodeAlive func(node Node) - // NotifiedWaiters is called when all callers of 'WaitFor*' functions have been notified. 
- NotifiedWaiters func() + NodeAlive func(node CheckedNode[T]) + // WaitersNotified is called when callers of 'WaitForNode' function have been notified. + WaitersNotified func() }