diff --git a/internal/mysql/node.go b/internal/mysql/node.go index 3149276..e74c8cd 100644 --- a/internal/mysql/node.go +++ b/internal/mysql/node.go @@ -14,6 +14,8 @@ import ( "regexp" "strconv" "strings" + "sync" + "sync/atomic" "syscall" "time" @@ -35,6 +37,9 @@ type Node struct { version *Version host string uuid uuid.UUID + + done atomic.Uint32 + mu sync.Mutex } var ( @@ -50,30 +55,49 @@ const ( // NewNode returns new Node func NewNode(config *config.Config, logger *log.Logger, host string) (*Node, error) { - addr := util.JoinHostPort(host, config.MySQL.Port) - dsn := fmt.Sprintf("%s:%s@tcp(%s)/mysql?autocommit=1", config.MySQL.User, config.MySQL.Password, addr) - if config.MySQL.SslCA != "" { - dsn += "&tls=custom" - } - db, err := sqlx.Open("mysql", dsn) - if err != nil { - return nil, err - } - // Unsafe option allow us to use queries containing fields missing in structs - // eg. when we running "SHOW SLAVE STATUS", but need only few columns - db = db.Unsafe() - db.SetMaxIdleConns(1) - db.SetMaxOpenConns(3) - db.SetConnMaxLifetime(3 * config.TickInterval) return &Node{ config: config, logger: logger, - db: db, + db: nil, host: host, version: nil, + + done: atomic.Uint32{}, + mu: sync.Mutex{}, }, nil } +// Lazy initialization of db connection +func (n *Node) GetDB() (*sqlx.DB, error) { + n.mu.Lock() + defer n.mu.Unlock() + var err error + + // First initialization + if n.done.Load() == 0 { + defer n.done.Store(1) + addr := util.JoinHostPort(n.host, n.config.MySQL.Port) + dsn := fmt.Sprintf("%s:%s@tcp(%s)/mysql?autocommit=1", n.config.MySQL.User, n.config.MySQL.Password, addr) + if n.config.MySQL.SslCA != "" { + dsn += "&tls=custom" + } + n.db, err = sqlx.Open("mysql", dsn) + if err != nil { + return nil, err + } + + // Unsafe option allow us to use queries containing fields missing in structs + // eg. when we running "SHOW SLAVE STATUS", but need only few columns + n.db = n.db.Unsafe() + n.db.SetMaxIdleConns(1) + n.db.SetMaxOpenConns(3) + n.db.SetConnMaxLifetime(3 * n.config.TickInterval) + } + + // Return old value + return n.db, nil +} + // RegisterTLSConfig loads and register CA file for TLS encryption func RegisterTLSConfig(config *config.Config) error { if config.MySQL.SslCA != "" { @@ -107,7 +131,13 @@ func (n *Node) String() string { // Close closes underlying SQL connection func (n *Node) Close() error { - return n.db.Close() + n.mu.Lock() + defer n.mu.Unlock() + if n.done.Load() != 0 { + return n.db.Close() + } else { + return nil + } } func (n *Node) getCommand(name string) string { @@ -185,7 +215,11 @@ func (n *Node) queryRowWithTimeout(queryName string, arg interface{}, result int query := n.getQuery(queryName) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - rows, err := n.db.NamedQueryContext(ctx, query, arg) + db, err := n.GetDB() + if err != nil { + return err + } + rows, err := db.NamedQueryContext(ctx, query, arg) if err == nil { defer func() { _ = rows.Close() }() if rows.Next() { @@ -228,7 +262,11 @@ func (n *Node) processQuery(queryName string, arg interface{}, rowsProcessor fun defer cancel() query := n.getQuery(queryName) - rows, err := n.db.NamedQueryContext(ctx, query, arg) + db, err := n.GetDB() + if err != nil { + return err + } + rows, err := db.NamedQueryContext(ctx, query, arg) n.traceQuery(query, arg, rows, err) if err != nil { return err @@ -249,12 +287,15 @@ func (n *Node) execWithTimeout(queryName string, arg map[string]interface{}, tim defer cancel() // avoid connection leak on long lock timeouts lockTimeout := int64(math.Floor(0.8 * float64(timeout/time.Second))) - if _, err := n.db.ExecContext(ctx, n.getQuery(querySetLockTimeout), lockTimeout); err != nil { + db, err := n.GetDB() + if err != nil { + return err + } + if _, err := db.ExecContext(ctx, n.getQuery(querySetLockTimeout), lockTimeout); err != nil { n.traceQuery(query, arg, nil, err) return err } - - _, err := n.db.NamedExecContext(ctx, query, arg) + _, err = db.NamedExecContext(ctx, query, arg) n.traceQuery(query, arg, nil, err) return err } @@ -275,7 +316,11 @@ func (n *Node) getRunningQueryIDs(excludeUsers []string, timeout time.Duration) n.traceQuery(bquery, args, nil, err) return nil, err } - rows, err := n.db.QueryxContext(ctx, bquery, args...) + db, err := n.GetDB() + if err != nil { + return nil, err + } + rows, err := db.QueryxContext(ctx, bquery, args...) if err != nil { n.traceQuery(bquery, args, nil, err) return nil, err @@ -334,7 +379,11 @@ func (n *Node) execMogrifyWithTimeout(queryName string, arg map[string]interface query = Mogrify(query, arg) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - _, err := n.db.ExecContext(ctx, query) + db, err := n.GetDB() + if err != nil { + return err + } + _, err = db.ExecContext(ctx, query) n.traceQuery(query, nil, nil, err) return err } @@ -348,7 +397,11 @@ func (n *Node) queryRowMogrifyWithTimeout(queryName string, arg map[string]inter query = Mogrify(query, arg) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - rows, err := n.db.NamedQueryContext(ctx, query, arg) + db, err := n.GetDB() + if err != nil { + return err + } + rows, err := db.NamedQueryContext(ctx, query, arg) if err == nil { defer func() { _ = rows.Close() }() if rows.Next() {