diff --git a/modules/cockroachdb/certs.go b/modules/cockroachdb/certs.go index afa12fcd1a..61280b4db9 100644 --- a/modules/cockroachdb/certs.go +++ b/modules/cockroachdb/certs.go @@ -1,8 +1,10 @@ package cockroachdb import ( + "crypto/tls" "crypto/x509" "errors" + "fmt" "net" "time" @@ -65,3 +67,20 @@ func NewTLSConfig() (*TLSConfig, error) { ClientKey: clientCert.KeyBytes, }, nil } + +// tlsConfig returns a [tls.Config] for options. +func (c *TLSConfig) tlsConfig() (*tls.Config, error) { + keyPair, err := tls.X509KeyPair(c.ClientCert, c.ClientKey) + if err != nil { + return nil, fmt.Errorf("x509 key pair: %w", err) + } + + certPool := x509.NewCertPool() + certPool.AddCert(c.CACert) + + return &tls.Config{ + RootCAs: certPool, + Certificates: []tls.Certificate{keyPair}, + ServerName: "localhost", + }, nil +} diff --git a/modules/cockroachdb/cockroachdb.go b/modules/cockroachdb/cockroachdb.go index 884d8f076f..98c5ebf149 100644 --- a/modules/cockroachdb/cockroachdb.go +++ b/modules/cockroachdb/cockroachdb.go @@ -3,23 +3,20 @@ package cockroachdb import ( "context" "crypto/tls" - "crypto/x509" "database/sql" "encoding/pem" "errors" "fmt" - "net" - "net/url" "path/filepath" "github.com/docker/go-connections/nat" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/stdlib" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" ) +// ErrTLSNotEnabled is returned when trying to get a TLS config from a container that does not have TLS enabled. var ErrTLSNotEnabled = errors.New("tls not enabled") const ( @@ -40,7 +37,9 @@ type CockroachDBContainer struct { opts options } -// MustConnectionString panics if the address cannot be determined. +// MustConnectionString returns a connection string to open a new connection to CockroachDB +// as described by [CockroachDBContainer.ConnectionString]. +// It panics if an error occurs. func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string { addr, err := c.ConnectionString(ctx) if err != nil { @@ -49,27 +48,33 @@ func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string return addr } -// ConnectionString returns the dial address to open a new connection to CockroachDB. +// ConnectionString returns a connection string to open a new connection to CockroachDB. +// The returned string is suitable for use by [sql.Open] but is not be compatible with +// [pgx.ParseConfig], so if you want to call [pgx.ConnectConfig] use the +// [CockroachDBContainer.ConnectionConfig] method instead. func (c *CockroachDBContainer) ConnectionString(ctx context.Context) (string, error) { - port, err := c.MappedPort(ctx, defaultSQLPort) - if err != nil { - return "", err - } - - host, err := c.Host(ctx) - if err != nil { - return "", err - } + return c.opts.containerConnString(ctx, c.Container) +} - return connString(c.opts, host, port), nil +// ConnectionConfig returns a [pgx.ConnConfig] for the CockroachDB container. +// This can be passed to [pgx.ConnectConfig] to open a new connection. +func (c *CockroachDBContainer) ConnectionConfig(ctx context.Context) (*pgx.ConnConfig, error) { + return c.opts.containerConnConfig(ctx, c.Container) } // TLSConfig returns config necessary to connect to CockroachDB over TLS. +// +// Deprecated: use [CockroachDBContainer.ConnectionConfig] or +// [CockroachDBContainer.ConnectionConfig] instead. func (c *CockroachDBContainer) TLSConfig() (*tls.Config, error) { - return connTLS(c.opts) + if c.opts.TLS == nil { + return nil, ErrTLSNotEnabled + } + + return c.opts.TLS.tlsConfig() } -// Deprecated: use Run instead +// Deprecated: use Run instead. // RunContainer creates an instance of the CockroachDB container type func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomizer) (*CockroachDBContainer, error) { return Run(ctx, "cockroachdb/cockroach:latest-v23.1", opts...) @@ -178,29 +183,12 @@ func addEnvs(req *testcontainers.GenericContainerRequest, opts options) error { } func addWaitingFor(req *testcontainers.GenericContainerRequest, opts options) error { - var tlsConfig *tls.Config - if opts.TLS != nil { - cfg, err := connTLS(opts) - if err != nil { - return err - } - tlsConfig = cfg - } - sqlWait := wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port nat.Port) string { - connStr := connString(opts, host, port) - if tlsConfig == nil { - return connStr - } - - // register TLS config with pgx driver - connCfg, err := pgx.ParseConfig(connStr) + connStr, err := opts.connString(host, port) if err != nil { panic(err) } - connCfg.TLSConfig = tlsConfig - - return stdlib.RegisterConnConfig(connCfg) + return connStr }) defaultStrategy := wait.ForAll( wait.ForHTTP("/health").WithPort(defaultAdminPort), @@ -246,17 +234,12 @@ func runStatements(ctx context.Context, container testcontainers.Container, opts return nil } - port, err := container.MappedPort(ctx, defaultSQLPort) - if err != nil { - return fmt.Errorf("mapped port: %w", err) - } - - host, err := container.Host(ctx) + connStr, err := opts.containerConnString(ctx, container) if err != nil { - return fmt.Errorf("host: %w", err) + return fmt.Errorf("connection string: %w", err) } - db, err := sql.Open("pgx/v5", connString(opts, host, port)) + db, err := sql.Open("pgx/v5", connStr) if err != nil { return fmt.Errorf("sql.Open: %w", err) } @@ -275,48 +258,3 @@ func runStatements(ctx context.Context, container testcontainers.Container, opts return nil } - -func connString(opts options, host string, port nat.Port) string { - user := url.User(opts.User) - if opts.Password != "" { - user = url.UserPassword(opts.User, opts.Password) - } - - sslMode := "disable" - if opts.TLS != nil { - sslMode = "verify-full" - } - params := url.Values{ - "sslmode": []string{sslMode}, - } - - u := url.URL{ - Scheme: "postgres", - User: user, - Host: net.JoinHostPort(host, port.Port()), - Path: opts.Database, - RawQuery: params.Encode(), - } - - return u.String() -} - -func connTLS(opts options) (*tls.Config, error) { - if opts.TLS == nil { - return nil, ErrTLSNotEnabled - } - - keyPair, err := tls.X509KeyPair(opts.TLS.ClientCert, opts.TLS.ClientKey) - if err != nil { - return nil, err - } - - certPool := x509.NewCertPool() - certPool.AddCert(opts.TLS.CACert) - - return &tls.Config{ - RootCAs: certPool, - Certificates: []tls.Certificate{keyPair}, - ServerName: "localhost", - }, nil -} diff --git a/modules/cockroachdb/cockroachdb_test.go b/modules/cockroachdb/cockroachdb_test.go index 45df7909bb..d3d05d1d86 100644 --- a/modules/cockroachdb/cockroachdb_test.go +++ b/modules/cockroachdb/cockroachdb_test.go @@ -2,9 +2,6 @@ package cockroachdb_test import ( "context" - "errors" - "net/url" - "strings" "testing" "time" @@ -18,14 +15,11 @@ import ( ) func TestCockroach_Insecure(t *testing.T) { - suite.Run(t, &AuthNSuite{ - url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=disable", - }) + suite.Run(t, &AuthNSuite{}) } func TestCockroach_NotRoot(t *testing.T) { suite.Run(t, &AuthNSuite{ - url: "postgres://test@localhost:xxxxx/defaultdb?sslmode=disable", opts: []testcontainers.ContainerCustomizer{ cockroachdb.WithUser("test"), // Do not run the default statements as the user used on this test is @@ -37,7 +31,6 @@ func TestCockroach_NotRoot(t *testing.T) { func TestCockroach_Password(t *testing.T) { suite.Run(t, &AuthNSuite{ - url: "postgres://foo:bar@localhost:xxxxx/defaultdb?sslmode=disable", opts: []testcontainers.ContainerCustomizer{ cockroachdb.WithUser("foo"), cockroachdb.WithPassword("bar"), @@ -53,19 +46,26 @@ func TestCockroach_TLS(t *testing.T) { require.NoError(t, err) suite.Run(t, &AuthNSuite{ - url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=verify-full", opts: []testcontainers.ContainerCustomizer{ cockroachdb.WithTLS(tlsCfg), - // Do not run the default statements as the user used on this test is - // lacking the needed MODIFYCLUSTERSETTING privilege to run them. - cockroachdb.WithStatements(), }, }) } +func TestTLS(t *testing.T) { + tlsCfg, err := cockroachdb.NewTLSConfig() + require.NoError(t, err) + + ctx := context.Background() + + ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", cockroachdb.WithTLS(tlsCfg)) + testcontainers.CleanupContainer(t, ctr) + require.NoError(t, err) + require.NotNil(t, ctr) +} + type AuthNSuite struct { suite.Suite - url string opts []testcontainers.ContainerCustomizer } @@ -75,11 +75,6 @@ func (suite *AuthNSuite) TestConnectionString() { ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...) testcontainers.CleanupContainer(suite.T(), ctr) suite.Require().NoError(err) - - connStr, err := removePort(ctr.MustConnectionString(ctx)) - suite.Require().NoError(err) - - suite.Equal(suite.url, connStr) } func (suite *AuthNSuite) TestPing() { @@ -203,29 +198,10 @@ func (suite *AuthNSuite) TestWithWaitStrategyAndDeadline() { } func conn(ctx context.Context, container *cockroachdb.CockroachDBContainer) (*pgx.Conn, error) { - cfg, err := pgx.ParseConfig(container.MustConnectionString(ctx)) + cfg, err := container.ConnectionConfig(ctx) if err != nil { return nil, err } - tlsCfg, err := container.TLSConfig() - switch { - case err != nil: - if !errors.Is(err, cockroachdb.ErrTLSNotEnabled) { - return nil, err - } - default: - // apply TLS config - cfg.TLSConfig = tlsCfg - } - return pgx.ConnectConfig(ctx, cfg) } - -func removePort(s string) (string, error) { - u, err := url.Parse(s) - if err != nil { - return "", err - } - return strings.Replace(s, ":"+u.Port(), ":xxxxx", 1), nil -} diff --git a/modules/cockroachdb/examples_test.go b/modules/cockroachdb/examples_test.go index 9a8fb12881..4cc14e8b7b 100644 --- a/modules/cockroachdb/examples_test.go +++ b/modules/cockroachdb/examples_test.go @@ -5,7 +5,8 @@ import ( "database/sql" "fmt" "log" - "net/url" + + "github.com/jackc/pgx/v5" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/cockroachdb" @@ -34,25 +35,34 @@ func ExampleRun() { } fmt.Println(state.Running) - addr, err := cockroachdbContainer.ConnectionString(ctx) + cfg, err := cockroachdbContainer.ConnectionConfig(ctx) if err != nil { log.Printf("failed to get connection string: %s", err) return } - u, err := url.Parse(addr) + + conn, err := pgx.ConnectConfig(ctx, cfg) if err != nil { - log.Printf("failed to parse connection string: %s", err) + log.Printf("failed to connect: %s", err) + return + } + + defer func() { + if err := conn.Close(ctx); err != nil { + log.Printf("failed to close connection: %s", err) + } + }() + + if err = conn.Ping(ctx); err != nil { + log.Printf("failed to ping: %s", err) return } - u.Host = fmt.Sprintf("%s:%s", u.Hostname(), "xxx") - fmt.Println(u.String()) // Output: // true - // postgres://root@localhost:xxx/defaultdb?sslmode=disable } -func ExampleRun_withRecommendedSettings() { +func ExampleRun_withCustomStatements() { ctx := context.Background() cockroachdbContainer, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", cockroachdb.WithStatements(cockroachdb.DefaultStatements...)) diff --git a/modules/cockroachdb/options.go b/modules/cockroachdb/options.go index eba101834e..81a1843f71 100644 --- a/modules/cockroachdb/options.go +++ b/modules/cockroachdb/options.go @@ -1,6 +1,17 @@ package cockroachdb -import "github.com/testcontainers/testcontainers-go" +import ( + "context" + "fmt" + "net" + "net/url" + + "github.com/docker/go-connections/nat" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + + "github.com/testcontainers/testcontainers-go" +) type options struct { Database string @@ -11,6 +22,81 @@ type options struct { Statements []string } +// containerConnConfig returns the [pgx.ConnConfig] for the given container and options. +func (opts options) containerConnConfig(ctx context.Context, container testcontainers.Container) (*pgx.ConnConfig, error) { + port, err := container.MappedPort(ctx, defaultSQLPort) + if err != nil { + return nil, fmt.Errorf("mapped port: %w", err) + } + + host, err := container.Host(ctx) + if err != nil { + return nil, fmt.Errorf("host: %w", err) + } + + return opts.connConfig(host, port) +} + +// containerConnString returns the connection string for the given container and options. +func (opts options) containerConnString(ctx context.Context, container testcontainers.Container) (string, error) { + cfg, err := opts.containerConnConfig(ctx, container) + if err != nil { + return "", fmt.Errorf("container connection config: %w", err) + } + + return stdlib.RegisterConnConfig(cfg), nil +} + +// connString returns a connection string for the given host, port and options. +func (opts options) connString(host string, port nat.Port) (string, error) { + cfg, err := opts.connConfig(host, port) + if err != nil { + return "", fmt.Errorf("connection config: %w", err) + } + + return stdlib.RegisterConnConfig(cfg), nil +} + +// connConfig returns a [pgx.ConnConfig] for the given host, port and options. +func (opts options) connConfig(host string, port nat.Port) (*pgx.ConnConfig, error) { + user := url.User(opts.User) + if opts.Password != "" { + user = url.UserPassword(opts.User, opts.Password) + } + + sslMode := "disable" + if opts.TLS != nil { + sslMode = "require" // We can't use "verify-full" as it might be a self signed cert. + } + params := url.Values{ + "sslmode": []string{sslMode}, + } + + u := url.URL{ + Scheme: "postgres", + User: user, + Host: net.JoinHostPort(host, port.Port()), + Path: opts.Database, + RawQuery: params.Encode(), + } + + cfg, err := pgx.ParseConfig(u.String()) + if err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + if opts.TLS != nil { + tlsCfg, err := opts.TLS.tlsConfig() + if err != nil { + return nil, fmt.Errorf("tls config: %w", err) + } + + cfg.TLSConfig = tlsCfg + } + + return cfg, nil +} + func defaultOptions() options { return options{ User: defaultUser, @@ -85,8 +171,8 @@ var DefaultStatements = []string{ } // WithStatements sets the statements to run on the CockroachDB cluster once the container is ready. -// This, in combination with DefaultStatements, can be used to configure the cluster with the settings -// recommended by Cockroach Labs. +// By default, the container will run the statements in [DefaultStatements] as recommended by +// Cockroach Labs however that is not always possible due to the user not having the required privileges. func WithStatements(statements ...string) Option { return func(o *options) { o.Statements = statements