Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow to specify read-only replica for SELECTs #1085

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 75 additions & 8 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"strings"
"sync/atomic"
"time"

"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
Expand All @@ -32,32 +33,55 @@ func WithDiscardUnknownColumns() DBOption {
}
}

type DB struct {
*sql.DB
func WithReadOnlyReplica(replica *sql.DB) DBOption {
return func(db *DB) {
db.replicas = append(db.replicas, replica)
}
}

dialect schema.Dialect
type DB struct {
// Must be a pointer so we copy the state, not the state fields.
*noCopyState

queryHooks []QueryHook

fmter schema.Formatter
flags internal.Flag

stats DBStats
}

// noCopyState contains DB fields that must not be copied on clone(),
// for example, it is forbidden to copy atomic.Pointer.
type noCopyState struct {
*sql.DB
dialect schema.Dialect

replicas []*sql.DB
healthyReplicas atomic.Pointer[[]*sql.DB]
nextReplica atomic.Int64

flags internal.Flag
closed atomic.Bool
}

func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
dialect.Init(sqldb)

db := &DB{
DB: sqldb,
dialect: dialect,
fmter: schema.NewFormatter(dialect),
noCopyState: &noCopyState{
DB: sqldb,
dialect: dialect,
},
fmter: schema.NewFormatter(dialect),
}

for _, opt := range opts {
opt(db)
}

if len(db.replicas) > 0 {
go db.monitorReplicas()
}

return db
}

Expand All @@ -69,6 +93,11 @@ func (db *DB) String() string {
return b.String()
}

func (db *DB) Close() error {
db.closed.Store(true)
return db.DB.Close()
}

func (db *DB) DBStats() DBStats {
return DBStats{
Queries: atomic.LoadUint32(&db.stats.Queries),
Expand Down Expand Up @@ -232,6 +261,44 @@ func (db *DB) HasFeature(feat feature.Feature) bool {
return db.dialect.Features().Has(feat)
}

// healthyReplica returns a random healthy replica.
func (db *DB) healthyReplica() *sql.DB {
replicas := db.loadHealthyReplicas()
if len(replicas) == 0 {
return db.DB
}
if len(replicas) == 1 {
return replicas[0]
}
i := db.nextReplica.Add(1)
return replicas[int(i)%len(replicas)]
}

func (db *DB) loadHealthyReplicas() []*sql.DB {
if ptr := db.healthyReplicas.Load(); ptr != nil {
return *ptr
}
return nil
}

func (db *DB) monitorReplicas() {
for !db.closed.Load() {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

healthy := make([]*sql.DB, 0, len(db.replicas))

for _, replica := range db.replicas {
if err := replica.PingContext(ctx); err == nil {
healthy = append(healthy, replica)
}
}

db.healthyReplicas.Store(&healthy)
time.Sleep(5 * time.Second)
}
}

//------------------------------------------------------------------------------

func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
Expand Down
2 changes: 0 additions & 2 deletions internal/dbtest/docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
version: '3.9'

services:
mysql8:
image: mysql:8.0
Expand Down
58 changes: 40 additions & 18 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (

type withQuery struct {
name string
query schema.QueryAppender
query Query
recursive bool
}

Expand Down Expand Up @@ -114,8 +114,27 @@ func (q *baseQuery) DB() *DB {
return q.db
}

func (q *baseQuery) GetConn() IConn {
return q.conn
func (q *baseQuery) resolveConn(query Query) IConn {
if q.conn != nil {
return q.conn
}
if len(q.db.replicas) == 0 || !isReadOnlyQuery(query) {
return q.db.DB
}
return q.db.healthyReplica()
}

func isReadOnlyQuery(query Query) bool {
sel, ok := query.(*SelectQuery)
if !ok {
return false
}
for _, el := range sel.with {
if !isReadOnlyQuery(el.query) {
return false
}
}
return true
}

func (q *baseQuery) GetModel() Model {
Expand All @@ -128,10 +147,8 @@ func (q *baseQuery) GetTableName() string {
}

for _, wq := range q.with {
if v, ok := wq.query.(Query); ok {
if model := v.GetModel(); model != nil {
return v.GetTableName()
}
if model := wq.query.GetModel(); model != nil {
return wq.query.GetTableName()
}
}

Expand Down Expand Up @@ -249,7 +266,7 @@ func (q *baseQuery) isSoftDelete() bool {

//------------------------------------------------------------------------------

func (q *baseQuery) addWith(name string, query schema.QueryAppender, recursive bool) {
func (q *baseQuery) addWith(name string, query Query, recursive bool) {
q.with = append(q.with, withQuery{
name: name,
query: query,
Expand Down Expand Up @@ -565,28 +582,33 @@ func (q *baseQuery) scan(
hasDest bool,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)
res, err := q._scan(ctx, iquery, query, model, hasDest)
q.db.afterQuery(ctx, event, res, err)
return res, err
}

rows, err := q.conn.QueryContext(ctx, query)
func (q *baseQuery) _scan(
ctx context.Context,
iquery Query,
query string,
model Model,
hasDest bool,
) (sql.Result, error) {
rows, err := q.resolveConn(iquery).QueryContext(ctx, query)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return nil, err
}
defer rows.Close()

numRow, err := model.ScanRows(ctx, rows)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return nil, err
}

if numRow == 0 && hasDest && isSingleRowModel(model) {
err = sql.ErrNoRows
return nil, sql.ErrNoRows
}

res := driver.RowsAffected(numRow)
q.db.afterQuery(ctx, event, res, err)

return res, err
return driver.RowsAffected(numRow), nil
}

func (q *baseQuery) exec(
Expand All @@ -595,7 +617,7 @@ func (q *baseQuery) exec(
query string,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)
res, err := q.conn.ExecContext(ctx, query)
res, err := q.resolveConn(iquery).ExecContext(ctx, query)
q.db.afterQuery(ctx, event, res, err)
return res, err
}
Expand Down
3 changes: 1 addition & 2 deletions query_column_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ var _ Query = (*AddColumnQuery)(nil)
func NewAddColumnQuery(db *DB) *AddColumnQuery {
q := &AddColumnQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
}
return q
Expand Down
3 changes: 1 addition & 2 deletions query_column_drop.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ var _ Query = (*DropColumnQuery)(nil)
func NewDropColumnQuery(db *DB) *DropColumnQuery {
q := &DropColumnQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
}
return q
Expand Down
7 changes: 3 additions & 4 deletions query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ func NewDeleteQuery(db *DB) *DeleteQuery {
q := &DeleteQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
},
}
Expand Down Expand Up @@ -56,12 +55,12 @@ func (q *DeleteQuery) Apply(fns ...func(*DeleteQuery) *DeleteQuery) *DeleteQuery
return q
}

func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery {
func (q *DeleteQuery) With(name string, query Query) *DeleteQuery {
q.addWith(name, query, false)
return q
}

func (q *DeleteQuery) WithRecursive(name string, query schema.QueryAppender) *DeleteQuery {
func (q *DeleteQuery) WithRecursive(name string, query Query) *DeleteQuery {
q.addWith(name, query, true)
return q
}
Expand Down
3 changes: 1 addition & 2 deletions query_index_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ func NewCreateIndexQuery(db *DB) *CreateIndexQuery {
q := &CreateIndexQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
},
}
Expand Down
3 changes: 1 addition & 2 deletions query_index_drop.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ var _ Query = (*DropIndexQuery)(nil)
func NewDropIndexQuery(db *DB) *DropIndexQuery {
q := &DropIndexQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
}
return q
Expand Down
7 changes: 3 additions & 4 deletions query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ func NewInsertQuery(db *DB) *InsertQuery {
q := &InsertQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
},
}
Expand Down Expand Up @@ -63,12 +62,12 @@ func (q *InsertQuery) Apply(fns ...func(*InsertQuery) *InsertQuery) *InsertQuery
return q
}

func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery {
func (q *InsertQuery) With(name string, query Query) *InsertQuery {
q.addWith(name, query, false)
return q
}

func (q *InsertQuery) WithRecursive(name string, query schema.QueryAppender) *InsertQuery {
func (q *InsertQuery) WithRecursive(name string, query Query) *InsertQuery {
q.addWith(name, query, true)
return q
}
Expand Down
7 changes: 3 additions & 4 deletions query_merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ var _ Query = (*MergeQuery)(nil)
func NewMergeQuery(db *DB) *MergeQuery {
q := &MergeQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
}
if q.db.dialect.Name() != dialect.MSSQL && q.db.dialect.Name() != dialect.PG {
Expand Down Expand Up @@ -60,12 +59,12 @@ func (q *MergeQuery) Apply(fns ...func(*MergeQuery) *MergeQuery) *MergeQuery {
return q
}

func (q *MergeQuery) With(name string, query schema.QueryAppender) *MergeQuery {
func (q *MergeQuery) With(name string, query Query) *MergeQuery {
q.addWith(name, query, false)
return q
}

func (q *MergeQuery) WithRecursive(name string, query schema.QueryAppender) *MergeQuery {
func (q *MergeQuery) WithRecursive(name string, query Query) *MergeQuery {
q.addWith(name, query, true)
return q
}
Expand Down
15 changes: 1 addition & 14 deletions query_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,10 @@ type RawQuery struct {
args []interface{}
}

// Deprecated: Use NewRaw instead. When add it to IDB, it conflicts with the sql.Conn#Raw
func (db *DB) Raw(query string, args ...interface{}) *RawQuery {
return &RawQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
query: query,
args: args,
}
}

func NewRawQuery(db *DB, query string, args ...interface{}) *RawQuery {
return &RawQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
query: query,
args: args,
Expand Down
Loading
Loading