Skip to content

Commit

Permalink
[Ready for Review] Lazy initialization for nodes (#164)
Browse files Browse the repository at this point in the history
* make lazy initialization of db connaction for every node and new method Node.GetDB() instead of using n.db

* implement own lazy initialization

* logs: remove unnecessary fmt.Print

* add mutex to node.close function
  • Loading branch information
WithSoull authored Mar 4, 2025
1 parent 4c0cf72 commit dcdb179
Showing 1 changed file with 78 additions and 25 deletions.
103 changes: 78 additions & 25 deletions internal/mysql/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

Expand All @@ -35,6 +37,9 @@ type Node struct {
version *Version
host string
uuid uuid.UUID

done atomic.Uint32
mu sync.Mutex
}

var (
Expand All @@ -50,30 +55,49 @@ const (

// NewNode returns new Node
func NewNode(config *config.Config, logger *log.Logger, host string) (*Node, error) {
addr := util.JoinHostPort(host, config.MySQL.Port)
dsn := fmt.Sprintf("%s:%s@tcp(%s)/mysql?autocommit=1", config.MySQL.User, config.MySQL.Password, addr)
if config.MySQL.SslCA != "" {
dsn += "&tls=custom"
}
db, err := sqlx.Open("mysql", dsn)
if err != nil {
return nil, err
}
// Unsafe option allow us to use queries containing fields missing in structs
// eg. when we running "SHOW SLAVE STATUS", but need only few columns
db = db.Unsafe()
db.SetMaxIdleConns(1)
db.SetMaxOpenConns(3)
db.SetConnMaxLifetime(3 * config.TickInterval)
return &Node{
config: config,
logger: logger,
db: db,
db: nil,
host: host,
version: nil,

done: atomic.Uint32{},
mu: sync.Mutex{},
}, nil
}

// Lazy initialization of db connection
func (n *Node) GetDB() (*sqlx.DB, error) {
n.mu.Lock()
defer n.mu.Unlock()
var err error

// First initialization
if n.done.Load() == 0 {
defer n.done.Store(1)
addr := util.JoinHostPort(n.host, n.config.MySQL.Port)
dsn := fmt.Sprintf("%s:%s@tcp(%s)/mysql?autocommit=1", n.config.MySQL.User, n.config.MySQL.Password, addr)
if n.config.MySQL.SslCA != "" {
dsn += "&tls=custom"
}
n.db, err = sqlx.Open("mysql", dsn)
if err != nil {
return nil, err
}

// Unsafe option allow us to use queries containing fields missing in structs
// eg. when we running "SHOW SLAVE STATUS", but need only few columns
n.db = n.db.Unsafe()
n.db.SetMaxIdleConns(1)
n.db.SetMaxOpenConns(3)
n.db.SetConnMaxLifetime(3 * n.config.TickInterval)
}

// Return old value
return n.db, nil
}

// RegisterTLSConfig loads and register CA file for TLS encryption
func RegisterTLSConfig(config *config.Config) error {
if config.MySQL.SslCA != "" {
Expand Down Expand Up @@ -107,7 +131,13 @@ func (n *Node) String() string {

// Close closes underlying SQL connection
func (n *Node) Close() error {
return n.db.Close()
n.mu.Lock()
defer n.mu.Unlock()
if n.done.Load() != 0 {
return n.db.Close()
} else {
return nil
}
}

func (n *Node) getCommand(name string) string {
Expand Down Expand Up @@ -185,7 +215,11 @@ func (n *Node) queryRowWithTimeout(queryName string, arg interface{}, result int
query := n.getQuery(queryName)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
rows, err := n.db.NamedQueryContext(ctx, query, arg)
db, err := n.GetDB()
if err != nil {
return err
}
rows, err := db.NamedQueryContext(ctx, query, arg)
if err == nil {
defer func() { _ = rows.Close() }()
if rows.Next() {
Expand Down Expand Up @@ -228,7 +262,11 @@ func (n *Node) processQuery(queryName string, arg interface{}, rowsProcessor fun
defer cancel()

query := n.getQuery(queryName)
rows, err := n.db.NamedQueryContext(ctx, query, arg)
db, err := n.GetDB()
if err != nil {
return err
}
rows, err := db.NamedQueryContext(ctx, query, arg)
n.traceQuery(query, arg, rows, err)
if err != nil {
return err
Expand All @@ -249,12 +287,15 @@ func (n *Node) execWithTimeout(queryName string, arg map[string]interface{}, tim
defer cancel()
// avoid connection leak on long lock timeouts
lockTimeout := int64(math.Floor(0.8 * float64(timeout/time.Second)))
if _, err := n.db.ExecContext(ctx, n.getQuery(querySetLockTimeout), lockTimeout); err != nil {
db, err := n.GetDB()
if err != nil {
return err
}
if _, err := db.ExecContext(ctx, n.getQuery(querySetLockTimeout), lockTimeout); err != nil {
n.traceQuery(query, arg, nil, err)
return err
}

_, err := n.db.NamedExecContext(ctx, query, arg)
_, err = db.NamedExecContext(ctx, query, arg)
n.traceQuery(query, arg, nil, err)
return err
}
Expand All @@ -275,7 +316,11 @@ func (n *Node) getRunningQueryIDs(excludeUsers []string, timeout time.Duration)
n.traceQuery(bquery, args, nil, err)
return nil, err
}
rows, err := n.db.QueryxContext(ctx, bquery, args...)
db, err := n.GetDB()
if err != nil {
return nil, err
}
rows, err := db.QueryxContext(ctx, bquery, args...)
if err != nil {
n.traceQuery(bquery, args, nil, err)
return nil, err
Expand Down Expand Up @@ -334,7 +379,11 @@ func (n *Node) execMogrifyWithTimeout(queryName string, arg map[string]interface
query = Mogrify(query, arg)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
_, err := n.db.ExecContext(ctx, query)
db, err := n.GetDB()
if err != nil {
return err
}
_, err = db.ExecContext(ctx, query)
n.traceQuery(query, nil, nil, err)
return err
}
Expand All @@ -348,7 +397,11 @@ func (n *Node) queryRowMogrifyWithTimeout(queryName string, arg map[string]inter
query = Mogrify(query, arg)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
rows, err := n.db.NamedQueryContext(ctx, query, arg)
db, err := n.GetDB()
if err != nil {
return err
}
rows, err := db.NamedQueryContext(ctx, query, arg)
if err == nil {
defer func() { _ = rows.Close() }()
if rows.Next() {
Expand Down

0 comments on commit dcdb179

Please sign in to comment.