diff --git a/database/db.go b/database/db.go index 312b16ec..a33cd56a 100644 --- a/database/db.go +++ b/database/db.go @@ -18,6 +18,7 @@ import ( "github.com/lib/pq" "github.com/pkg/errors" "go.uber.org/zap" + "go.uber.org/zap/zapcore" "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" "net" @@ -108,6 +109,7 @@ func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks Retry if utils.IsUnixAddr(c.Host) { config.Net = "unix" config.Addr = c.Host + addr = "(" + config.Addr + ")" } else { config.Net = "tcp" port := c.Port @@ -115,6 +117,7 @@ func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks Retry port = 3306 } config.Addr = net.JoinHostPort(c.Host, fmt.Sprint(port)) + addr = config.Addr } config.DBName = c.Database @@ -150,7 +153,6 @@ func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks Retry return unsafeSetSessionVariableIfExists(ctx, conn, "wsrep_sync_wait", fmt.Sprint(c.Options.WsrepSyncWait)) } - addr = config.Addr db = sqlx.NewDb(sql.OpenDB(NewConnector(connector, logger, connectorCallbacks)), MySQL) case "pgsql": uri := &url.URL{ @@ -208,12 +210,23 @@ func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks Retry return nil, errors.Wrap(err, "can't open pgsql database") } - addr = utils.JoinHostPort(c.Host, port) + if utils.IsUnixAddr(c.Host) { + // https://www.postgresql.org/docs/17/runtime-config-connection.html#GUC-UNIX-SOCKET-DIRECTORIES + addr = fmt.Sprintf("(%s/.s.PGSQL.%d)", strings.TrimRight(c.Host, "/"), port) + } else { + addr = utils.JoinHostPort(c.Host, port) + } db = sqlx.NewDb(sql.OpenDB(NewConnector(connector, logger, connectorCallbacks)), PostgreSQL) default: return nil, unknownDbType(c.Type) } + if c.TlsOptions.Enable { + addr = fmt.Sprintf("%s+tls://%s@%s/%s", c.Type, c.User, addr, c.Database) + } else { + addr = fmt.Sprintf("%s://%s@%s/%s", c.Type, c.User, addr, c.Database) + } + db.SetMaxIdleConns(c.Options.MaxConnections / 3) db.SetMaxOpenConns(c.Options.MaxConnections) @@ -229,11 +242,22 @@ func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks Retry }, nil } -// GetAddr returns the database host:port or Unix socket address. +// GetAddr returns a URI-like database connection string. +// +// It has the following syntax: +// +// type[+tls]://user@host[:port]/database func (db *DB) GetAddr() string { return db.addr } +// MarshalLogObject implements [zapcore.ObjectMarshaler], adding the database address [DB.GetAddr] to each log message. +func (db *DB) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + encoder.AddString("database_address", db.GetAddr()) + + return nil +} + // BuildColumns returns all columns of the given struct. func (db *DB) BuildColumns(subject interface{}) []string { return slices.Clone(db.columnMap.Columns(subject)) diff --git a/database/db_test.go b/database/db_test.go new file mode 100644 index 00000000..1ab48a86 --- /dev/null +++ b/database/db_test.go @@ -0,0 +1,124 @@ +package database + +import ( + "github.com/icinga/icinga-go-library/config" + "github.com/icinga/icinga-go-library/logging" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + "testing" +) + +func TestNewDbFromConfig_GetAddr(t *testing.T) { + tests := []struct { + name string + conf *Config + addr string + }{ + { + name: "mysql-simple", + conf: &Config{ + Type: "mysql", + Host: "example.com", + Database: "db", + User: "user", + }, + addr: "mysql://user@example.com:3306/db", + }, + { + name: "mysql-custom-port", + conf: &Config{ + Type: "mysql", + Host: "example.com", + Port: 1234, + Database: "db", + User: "user", + }, + addr: "mysql://user@example.com:1234/db", + }, + { + name: "mysql-tls", + conf: &Config{ + Type: "mysql", + Host: "example.com", + Database: "db", + User: "user", + TlsOptions: config.TLS{Enable: true}, + }, + addr: "mysql+tls://user@example.com:3306/db", + }, + { + name: "mysql-unix-domain-socket", + conf: &Config{ + Type: "mysql", + Host: "/var/empty/mysql.sock", + Database: "db", + User: "user", + }, + addr: "mysql://user@(/var/empty/mysql.sock)/db", + }, + { + name: "pgsql-simple", + conf: &Config{ + Type: "pgsql", + Host: "example.com", + Database: "db", + User: "user", + }, + addr: "pgsql://user@example.com:5432/db", + }, + { + name: "pgsql-custom-port", + conf: &Config{ + Type: "pgsql", + Host: "example.com", + Port: 1234, + Database: "db", + User: "user", + }, + addr: "pgsql://user@example.com:1234/db", + }, + { + name: "pgsql-tls", + conf: &Config{ + Type: "pgsql", + Host: "example.com", + Database: "db", + User: "user", + TlsOptions: config.TLS{Enable: true}, + }, + addr: "pgsql+tls://user@example.com:5432/db", + }, + { + name: "pgsql-unix-domain-socket", + conf: &Config{ + Type: "pgsql", + Host: "/var/empty/pgsql", + Database: "db", + User: "user", + }, + addr: "pgsql://user@(/var/empty/pgsql/.s.PGSQL.5432)/db", + }, + { + name: "pgsql-unix-domain-socket-custom-port", + conf: &Config{ + Type: "pgsql", + Host: "/var/empty/pgsql", + Port: 1234, + Database: "db", + User: "user", + }, + addr: "pgsql://user@(/var/empty/pgsql/.s.PGSQL.1234)/db", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + db, err := NewDbFromConfig( + test.conf, + logging.NewLogger(zaptest.NewLogger(t).Sugar(), 0), + RetryConnectorCallbacks{}) + require.NoError(t, err) + require.Equal(t, test.addr, db.GetAddr()) + }) + } +} diff --git a/redis/client.go b/redis/client.go index 8884fde0..f1f5b263 100644 --- a/redis/client.go +++ b/redis/client.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "github.com/redis/go-redis/v9" "go.uber.org/zap" + "go.uber.org/zap/zapcore" "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" "net" @@ -79,9 +80,37 @@ func NewClientFromConfig(c *Config, logger *logging.Logger) (*Client, error) { return NewClient(redis.NewClient(options), logger, &c.Options), nil } -// GetAddr returns the Redis host:port or Unix socket address. +// GetAddr returns a URI-like Redis connection string. +// +// It has the following syntax: +// +// redis[+tls]://user@host[:port]/database func (c *Client) GetAddr() string { - return c.Client.Options().Addr + description := "redis" + if c.Client.Options().TLSConfig != nil { + description += "+tls" + } + description += "://" + if username := c.Client.Options().Username; username != "" { + description += username + "@" + } + if utils.IsUnixAddr(c.Client.Options().Addr) { + description += "(" + c.Client.Options().Addr + ")" + } else { + description += c.Client.Options().Addr + } + if db := c.Client.Options().DB; db != 0 { + description += fmt.Sprintf("/%d", db) + } + + return description +} + +// MarshalLogObject implements [zapcore.ObjectMarshaler], adding the redis address [Client.GetAddr] to each log message. +func (c *Client) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + encoder.AddString("redis_address", c.GetAddr()) + + return nil } // HPair defines Redis hashes field-value pairs. diff --git a/redis/client_test.go b/redis/client_test.go new file mode 100644 index 00000000..48ff9af9 --- /dev/null +++ b/redis/client_test.go @@ -0,0 +1,87 @@ +package redis + +import ( + "github.com/icinga/icinga-go-library/config" + "github.com/icinga/icinga-go-library/logging" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + "testing" +) + +func TestNewClientFromConfig_GetAddr(t *testing.T) { + tests := []struct { + name string + conf *Config + addr string + }{ + { + name: "redis-simple", + conf: &Config{ + Host: "example.com", + }, + addr: "redis://example.com:6379", + }, + { + name: "redis-custom-port", + conf: &Config{ + Host: "example.com", + Port: 6380, + }, + addr: "redis://example.com:6380", + }, + { + name: "redis-acl", + conf: &Config{ + Host: "example.com", + Username: "user", + Password: "pass", + }, + addr: "redis://user@example.com:6379", + }, + { + name: "redis-custom-database", + conf: &Config{ + Host: "example.com", + Database: 23, + }, + addr: "redis://example.com:6379/23", + }, + { + name: "redis-tls", + conf: &Config{ + Host: "example.com", + TlsOptions: config.TLS{Enable: true}, + }, + addr: "redis+tls://example.com:6379", + }, + { + name: "redis-with-everything", + conf: &Config{ + Host: "example.com", + Port: 6380, + Username: "user", + Password: "pass", + Database: 23, + TlsOptions: config.TLS{Enable: true}, + }, + addr: "redis+tls://user@example.com:6380/23", + }, + { + name: "redis-unix-domain-socket", + conf: &Config{ + Host: "/var/empty/redis.sock", + }, + addr: "redis://(/var/empty/redis.sock)", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + redis, err := NewClientFromConfig( + test.conf, + logging.NewLogger(zaptest.NewLogger(t).Sugar(), 0)) + require.NoError(t, err) + require.Equal(t, test.addr, redis.GetAddr()) + }) + } +}