diff --git a/docs/features/wait/tls.md b/docs/features/wait/tls.md new file mode 100644 index 0000000000..29a6027d8d --- /dev/null +++ b/docs/features/wait/tls.md @@ -0,0 +1,19 @@ +# TLS Strategy + +TLS Strategy waits for one or more files to exist in the container and uses them +and other details to construct a `tls.Config` which can be used to create secure +connections. + +It supports: + +- x509 PEM Certificate loaded from a certificate / key file pair. +- Root Certificate Authorities aka RootCAs loaded from PEM encoded files. +- Server name. +- Startup timeout to be used in seconds, default is 60 seconds. +- Poll interval to be used in milliseconds, default is 100 milliseconds. + +## Waiting for certificate pair to exist and construct a tls.Config + + +[Waiting for certificate pair to exist and construct a tls.Config](../../../wait/tls_test.go) inside_block:waitForTLSCert + diff --git a/modules/cockroachdb/cockroachdb.go b/modules/cockroachdb/cockroachdb.go index c9d53e1bb5..e3948fe95d 100644 --- a/modules/cockroachdb/cockroachdb.go +++ b/modules/cockroachdb/cockroachdb.go @@ -4,11 +4,9 @@ import ( "bytes" "context" "crypto/tls" - "crypto/x509" _ "embed" "errors" "fmt" - "io" "net" "net/url" @@ -79,70 +77,10 @@ type CockroachDBContainer struct { // options represents the options for the CockroachDBContainer type. type options struct { - // Settings. - database string - user string - password string - insecure bool - - // Client certificate. - clientCert []byte - clientKey []byte - certPool *x509.CertPool - tlsConfig *tls.Config -} - -// WaitUntilReady implements the [wait.Strategy] interface. -// If TLS is enabled, it waits for the CA, client cert and key for the configured user to be -// available in the container and uses them to setup the TLS config, otherwise it does nothing. -// -// This is defined on the options as it needs to know the customised values to operate correctly. -func (o *options) WaitUntilReady(ctx context.Context, target wait.StrategyTarget) error { - if o.insecure { - return nil - } - - return wait.ForAll( - wait.ForFile(fileCACert).WithMatcher(func(r io.Reader) error { - buf, err := io.ReadAll(r) - if err != nil { - return fmt.Errorf("read CA cert: %w", err) - } - - if !o.certPool.AppendCertsFromPEM(buf) { - return errors.New("invalid CA cert") - } - - return nil - }), - wait.ForFile(certsDir+"/client."+o.user+".crt").WithMatcher(func(r io.Reader) error { - var err error - if o.clientCert, err = io.ReadAll(r); err != nil { - return fmt.Errorf("read client cert: %w", err) - } - - return nil - }), - wait.ForFile(certsDir+"/client."+o.user+".key").WithMatcher(func(r io.Reader) error { - var err error - if o.clientKey, err = io.ReadAll(r); err != nil { - return fmt.Errorf("read client key: %w", err) - } - - cert, err := tls.X509KeyPair(o.clientCert, o.clientKey) - if err != nil { - return fmt.Errorf("x509 key pair: %w", err) - } - - o.tlsConfig = &tls.Config{ - RootCAs: o.certPool, - Certificates: []tls.Certificate{cert}, - ServerName: "127.0.0.1", - } - - return nil - }), - ).WaitUntilReady(ctx, target) + database string + user string + password string + tlsStrategy *wait.TLSStrategy } // MustConnectionString returns a connection string to open a new connection to CockroachDB @@ -191,11 +129,12 @@ func (c *CockroachDBContainer) ConnectionConfig(ctx context.Context) (*pgx.ConnC // Deprecated: use [CockroachDBContainer.ConnectionConfig] or // [CockroachDBContainer.ConnectionConfig] instead. func (c *CockroachDBContainer) TLSConfig() (*tls.Config, error) { - if c.tlsConfig == nil { - return nil, ErrTLSNotEnabled + if cfg := c.tlsStrategy.TLSConfig(); cfg != nil { + return cfg, nil + } - return c.tlsConfig, nil + return nil, ErrTLSNotEnabled } // connString returns a connection string for the given host, port and options. @@ -218,7 +157,8 @@ func (c *CockroachDBContainer) connConfig(host string, port nat.Port) (*pgx.Conn } sslMode := "disable" - if c.tlsConfig != nil { + tlsConfig := c.tlsStrategy.TLSConfig() + if tlsConfig != nil { sslMode = "verify-full" } params := url.Values{ @@ -238,22 +178,46 @@ func (c *CockroachDBContainer) connConfig(host string, port nat.Port) (*pgx.Conn return nil, fmt.Errorf("parse config: %w", err) } - cfg.TLSConfig = c.tlsConfig + cfg.TLSConfig = tlsConfig return cfg, nil } // setOptions sets the CockroachDBContainer options from a request. -func (c *CockroachDBContainer) setOptions(req *testcontainers.GenericContainerRequest) { +func (c *CockroachDBContainer) setOptions(req *testcontainers.GenericContainerRequest) error { c.database = req.Env[envDatabase] c.user = req.Env[envUser] c.password = req.Env[envPassword] + + var insecure bool for _, arg := range req.Cmd { if arg == insecureFlag { - c.insecure = true + insecure = true break } } + + if err := wait.Walk(&req.WaitingFor, func(strategy wait.Strategy) error { + if cert, ok := strategy.(*wait.TLSStrategy); ok { + if insecure { + // If insecure mode is enabled, the certificate strategy is removed. + return errors.Join(wait.VisitRemove, wait.VisitStop) + } + + // Update the client certificate files to match the user which may have changed. + cert.WithCert(certsDir+"/client."+c.user+".crt", certsDir+"/client."+c.user+".key") + + c.tlsStrategy = cert + + // Stop the walk as the certificate strategy has been found. + return wait.VisitStop + } + return nil + }); err != nil { + return fmt.Errorf("walk strategies: %w", err) + } + + return nil } // Deprecated: use Run instead. @@ -281,7 +245,6 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom database: defaultDatabase, user: defaultUser, password: defaultPassword, - certPool: x509.NewCertPool(), }, } req := testcontainers.GenericContainerRequest{ @@ -308,7 +271,10 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom WaitingFor: wait.ForAll( wait.ForFile(cockroachDir+"/init_success"), wait.ForHTTP("/health").WithPort(defaultAdminPort), - ctr, // Wait for the TLS files to be available if needed. + wait.ForTLSCert( + certsDir+"/client."+defaultUser+".crt", + certsDir+"/client."+defaultUser+".key", + ).WithRootCAs(fileCACert).WithServerName("127.0.0.1"), wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port nat.Port) string { connStr, err := ctr.connString(host, port) if err != nil { diff --git a/modules/cockroachdb/cockroachdb_test.go b/modules/cockroachdb/cockroachdb_test.go index 6f8e6725a7..e3a7bb1f12 100644 --- a/modules/cockroachdb/cockroachdb_test.go +++ b/modules/cockroachdb/cockroachdb_test.go @@ -30,19 +30,26 @@ func TestRun_WithAllOptions(t *testing.T) { ) } -func TestRun_WithInsecureAndPassword(t *testing.T) { - _, err := cockroachdb.Run(context.Background(), testImage, - cockroachdb.WithPassword("testPassword"), - cockroachdb.WithInsecure(), - ) - require.Error(t, err) +func TestRun_WithInsecure(t *testing.T) { + t.Run("valid", func(t *testing.T) { + testContainer(t, cockroachdb.WithInsecure()) + }) - // Check order does not matter. - _, err = cockroachdb.Run(context.Background(), testImage, - cockroachdb.WithInsecure(), - cockroachdb.WithPassword("testPassword"), - ) - require.Error(t, err) + t.Run("invalid-password-insecure", func(t *testing.T) { + _, err := cockroachdb.Run(context.Background(), testImage, + cockroachdb.WithPassword("testPassword"), + cockroachdb.WithInsecure(), + ) + require.Error(t, err) + }) + + t.Run("invalid-insecure-password", func(t *testing.T) { + _, err := cockroachdb.Run(context.Background(), testImage, + cockroachdb.WithInsecure(), + cockroachdb.WithPassword("testPassword"), + ) + require.Error(t, err) + }) } // testContainer runs a CockroachDB container and validates its functionality. diff --git a/modules/cockroachdb/options.go b/modules/cockroachdb/options.go index 5d3cc14c01..09bc4b52b3 100644 --- a/modules/cockroachdb/options.go +++ b/modules/cockroachdb/options.go @@ -113,7 +113,7 @@ func WithInitScripts(scripts ...string) testcontainers.CustomizeRequestOption { } } -// WithInsecure enables insecure mode and disables TLS. +// WithInsecure enables insecure mode which disables TLS. func WithInsecure() testcontainers.CustomizeRequestOption { return func(req *testcontainers.GenericContainerRequest) error { if req.Env[envPassword] != "" { diff --git a/wait/file_test.go b/wait/file_test.go index 22133ba349..20bcc13a01 100644 --- a/wait/file_test.go +++ b/wait/file_test.go @@ -20,7 +20,7 @@ import ( const testFilename = "/tmp/file" -var anyContext = mock.AnythingOfType("*context.timerCtx") +var anyContext = mock.MatchedBy(func(_ context.Context) bool { return true }) // newRunningTarget creates a new mockStrategyTarget that is running. func newRunningTarget() *mockStrategyTarget { diff --git a/wait/http_test.go b/wait/http_test.go index 32479bddd4..73e32d44d7 100644 --- a/wait/http_test.go +++ b/wait/http_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "crypto/x509" + _ "embed" "fmt" "io" "log" @@ -23,6 +24,9 @@ import ( "github.com/testcontainers/testcontainers-go/wait" ) +//go:embed testdata/root.pem +var caBytes []byte + // https://github.com/testcontainers/testcontainers-go/issues/183 func ExampleHTTPStrategy() { // waitForHTTPWithDefaultPort { @@ -80,7 +84,7 @@ func ExampleHTTPStrategy_WithHeaders() { tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} req := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: "testdata", + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, WaitingFor: wait.ForHTTP("/headers"). @@ -227,20 +231,13 @@ func ExampleHTTPStrategy_WithBasicAuth() { } func TestHTTPStrategyWaitUntilReady(t *testing.T) { - workdir, err := os.Getwd() - require.NoError(t, err) - - capath := filepath.Join(workdir, "testdata", "root.pem") - cafile, err := os.ReadFile(capath) - require.NoError(t, err) - certpool := x509.NewCertPool() - require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid") + require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid") tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} dockerReq := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: filepath.Join(workdir, "testdata"), + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, WaitingFor: wait.NewHTTPStrategy("/auth-ping").WithTLS(true, tlsconfig). @@ -288,20 +285,13 @@ func TestHTTPStrategyWaitUntilReady(t *testing.T) { } func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) { - workdir, err := os.Getwd() - require.NoError(t, err) - - capath := filepath.Join(workdir, "testdata", "root.pem") - cafile, err := os.ReadFile(capath) - require.NoError(t, err) - certpool := x509.NewCertPool() - require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid") + require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid") tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} dockerReq := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: filepath.Join(workdir, "testdata"), + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, @@ -348,22 +338,15 @@ func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) { } func TestHTTPStrategyWaitUntilReadyNoBasicAuth(t *testing.T) { - workdir, err := os.Getwd() - require.NoError(t, err) - - capath := filepath.Join(workdir, "testdata", "root.pem") - cafile, err := os.ReadFile(capath) - require.NoError(t, err) - certpool := x509.NewCertPool() - require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid") + require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid") // waitForHTTPStatusCode { tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} var i int dockerReq := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: filepath.Join(workdir, "testdata"), + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, WaitingFor: wait.NewHTTPStrategy("/ping").WithTLS(true, tlsconfig). diff --git a/wait/testdata/tls.pem b/wait/testdata/cert.crt similarity index 100% rename from wait/testdata/tls.pem rename to wait/testdata/cert.crt diff --git a/wait/testdata/tls-key.pem b/wait/testdata/cert.key similarity index 100% rename from wait/testdata/tls-key.pem rename to wait/testdata/cert.key diff --git a/wait/testdata/Dockerfile b/wait/testdata/http/Dockerfile similarity index 100% rename from wait/testdata/Dockerfile rename to wait/testdata/http/Dockerfile diff --git a/wait/testdata/go.mod b/wait/testdata/http/go.mod similarity index 100% rename from wait/testdata/go.mod rename to wait/testdata/http/go.mod diff --git a/wait/testdata/main.go b/wait/testdata/http/main.go similarity index 100% rename from wait/testdata/main.go rename to wait/testdata/http/main.go diff --git a/wait/testdata/http/tls-key.pem b/wait/testdata/http/tls-key.pem new file mode 100644 index 0000000000..00789d2371 --- /dev/null +++ b/wait/testdata/http/tls-key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIM8HuDwcZyVqBBy2C6db6zNb/dAJ69bq5ejAEz7qGOIQoAoGCCqGSM49 +AwEHoUQDQgAEBL2ioRmfTc70WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG +0735iy9Fz16PX4vqnLMiM/ZupugAhB//yA== +-----END EC PRIVATE KEY----- diff --git a/wait/testdata/http/tls.pem b/wait/testdata/http/tls.pem new file mode 100644 index 0000000000..46348b7900 --- /dev/null +++ b/wait/testdata/http/tls.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBxTCCAWugAwIBAgIUWBLNpiF1o4r+5ZXwawzPOfBM1F8wCgYIKoZIzj0EAwIw +ADAeFw0yMDA4MTkxMzM4MDBaFw0zMDA4MTcxMzM4MDBaMBkxFzAVBgNVBAMTDnRl +c3Rjb250YWluZXJzMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEBL2ioRmfTc70 +WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG0735iy9Fz16PX4vqnLMiM/Zu +pugAhB//yKOBqTCBpjAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH +AwEwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUTMdz5PIZ+Gix4jYUzRIHfByrW+Yw +HwYDVR0jBBgwFoAUFdfV6PSYUlHs+lSQNouRwSfR2ZgwMQYDVR0RBCowKIIVdGVz +dGNvbnRhaW5lci5nby50ZXN0gglsb2NhbGhvc3SHBH8AAAEwCgYIKoZIzj0EAwID +SAAwRQIhAJznPNumi2Plf0GsP9DpC+8WukT/jUhnhcDWCfZ6Ini2AiBLhnhFebZX +XWfSsdSNxIo20OWvy6z3wqdybZtRUfdU+g== +-----END CERTIFICATE----- diff --git a/wait/tls.go b/wait/tls.go new file mode 100644 index 0000000000..ab904b271e --- /dev/null +++ b/wait/tls.go @@ -0,0 +1,167 @@ +package wait + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "time" +) + +// Validate we implement interface. +var _ Strategy = (*TLSStrategy)(nil) + +// TLSStrategy is a strategy for handling TLS. +type TLSStrategy struct { + // General Settings. + timeout *time.Duration + pollInterval time.Duration + + // Custom Settings. + certFiles *x509KeyPair + rootFiles []string + + // State. + tlsConfig *tls.Config +} + +// x509KeyPair is a pair of certificate and key files. +type x509KeyPair struct { + certPEMFile string + keyPEMFile string +} + +// ForTLSCert returns a CertStrategy that will add a Certificate to the [tls.Config] +// constructed from PEM formatted certificate key file pair in the container. +func ForTLSCert(certPEMFile, keyPEMFile string) *TLSStrategy { + return &TLSStrategy{ + certFiles: &x509KeyPair{ + certPEMFile: certPEMFile, + keyPEMFile: keyPEMFile, + }, + tlsConfig: &tls.Config{}, + pollInterval: defaultPollInterval(), + } +} + +// ForTLSRootCAs returns a CertStrategy that sets the root CAs for the [tls.Config] +// using the given PEM formatted files from the container. +func ForTLSRootCAs(pemFiles ...string) *TLSStrategy { + return &TLSStrategy{ + rootFiles: pemFiles, + tlsConfig: &tls.Config{}, + pollInterval: defaultPollInterval(), + } +} + +// WithRootCAs sets the root CAs for the [tls.Config] using the given files from +// the container. +func (ws *TLSStrategy) WithRootCAs(files ...string) *TLSStrategy { + ws.rootFiles = files + return ws +} + +// WithCert sets the [tls.Config] Certificates using the given files from the container. +func (ws *TLSStrategy) WithCert(certPEMFile, keyPEMFile string) *TLSStrategy { + ws.certFiles = &x509KeyPair{ + certPEMFile: certPEMFile, + keyPEMFile: keyPEMFile, + } + return ws +} + +// WithServerName sets the server for the [tls.Config]. +func (ws *TLSStrategy) WithServerName(serverName string) *TLSStrategy { + ws.tlsConfig.ServerName = serverName + return ws +} + +// WithStartupTimeout can be used to change the default startup timeout. +func (ws *TLSStrategy) WithStartupTimeout(startupTimeout time.Duration) *TLSStrategy { + ws.timeout = &startupTimeout + return ws +} + +// WithPollInterval can be used to override the default polling interval of 100 milliseconds. +func (ws *TLSStrategy) WithPollInterval(pollInterval time.Duration) *TLSStrategy { + ws.pollInterval = pollInterval + return ws +} + +// TLSConfig returns the TLS config once the strategy is ready. +// If the strategy is nil, it returns nil. +func (ws *TLSStrategy) TLSConfig() *tls.Config { + if ws == nil { + return nil + } + + return ws.tlsConfig +} + +// WaitUntilReady implements the [Strategy] interface. +// It waits for the CA, client cert and key files to be available in the container and +// uses them to setup the TLS config. +func (ws *TLSStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) error { + size := len(ws.rootFiles) + if ws.certFiles != nil { + size += 2 + } + strategies := make([]Strategy, 0, size) + for _, file := range ws.rootFiles { + strategies = append(strategies, + ForFile(file).WithMatcher(func(r io.Reader) error { + buf, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("read CA cert file %q: %w", file, err) + } + + if ws.tlsConfig.RootCAs == nil { + ws.tlsConfig.RootCAs = x509.NewCertPool() + } + + if !ws.tlsConfig.RootCAs.AppendCertsFromPEM(buf) { + return fmt.Errorf("invalid CA cert file %q", file) + } + + return nil + }).WithPollInterval(ws.pollInterval), + ) + } + + if ws.certFiles != nil { + var certPEMBlock []byte + strategies = append(strategies, + ForFile(ws.certFiles.certPEMFile).WithMatcher(func(r io.Reader) error { + var err error + if certPEMBlock, err = io.ReadAll(r); err != nil { + return fmt.Errorf("read certificate cert %q: %w", ws.certFiles.certPEMFile, err) + } + + return nil + }).WithPollInterval(ws.pollInterval), + ForFile(ws.certFiles.keyPEMFile).WithMatcher(func(r io.Reader) error { + keyPEMBlock, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("read certificate key %q: %w", ws.certFiles.keyPEMFile, err) + } + + cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + return fmt.Errorf("x509 key pair %q %q: %w", ws.certFiles.certPEMFile, ws.certFiles.keyPEMFile, err) + } + + ws.tlsConfig.Certificates = []tls.Certificate{cert} + + return nil + }).WithPollInterval(ws.pollInterval), + ) + } + + strategy := ForAll(strategies...) + if ws.timeout != nil { + strategy.WithStartupTimeout(*ws.timeout) + } + + return strategy.WaitUntilReady(ctx, target) +} diff --git a/wait/tls_test.go b/wait/tls_test.go new file mode 100644 index 0000000000..28d3b4af03 --- /dev/null +++ b/wait/tls_test.go @@ -0,0 +1,146 @@ +package wait_test + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + _ "embed" + "errors" + "fmt" + "io" + "log" + "testing" + "time" + + "github.com/docker/docker/errdefs" + "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + serverName = "127.0.0.1" + caFilename = "/tmp/ca.pem" + clientCertFilename = "/tmp/cert.crt" + clientKeyFilename = "/tmp/cert.key" +) + +var ( + //go:embed testdata/cert.crt + certBytes []byte + + //go:embed testdata/cert.key + keyBytes []byte +) + +// testForTLSCert creates a new CertStrategy for testing. +func testForTLSCert() *wait.TLSStrategy { + return wait.ForTLSCert(clientCertFilename, clientKeyFilename). + WithRootCAs(caFilename). + WithServerName(serverName). + WithStartupTimeout(time.Millisecond * 50). + WithPollInterval(time.Millisecond) +} + +func TestForCert(t *testing.T) { + errNotFound := errdefs.NotFound(errors.New("file not found")) + ctx := context.Background() + + t.Run("ca-not-found", func(t *testing.T) { + target := newRunningTarget() + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(nil, errNotFound) + err := testForTLSCert().WaitUntilReady(ctx, target) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + }) + + t.Run("cert-not-found", func(t *testing.T) { + target := newRunningTarget() + caFile := io.NopCloser(bytes.NewBuffer(caBytes)) + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(caFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientCertFilename).Return(nil, errNotFound) + err := testForTLSCert().WaitUntilReady(ctx, target) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + }) + + t.Run("key-not-found", func(t *testing.T) { + target := newRunningTarget() + caFile := io.NopCloser(bytes.NewBuffer(caBytes)) + certFile := io.NopCloser(bytes.NewBuffer(certBytes)) + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(caFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientCertFilename).Return(certFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientKeyFilename).Return(nil, errNotFound) + err := testForTLSCert().WaitUntilReady(ctx, target) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + }) + + t.Run("valid", func(t *testing.T) { + target := newRunningTarget() + caFile := io.NopCloser(bytes.NewBuffer(caBytes)) + certFile := io.NopCloser(bytes.NewBuffer(certBytes)) + keyFile := io.NopCloser(bytes.NewBuffer(keyBytes)) + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(caFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientCertFilename).Return(certFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientKeyFilename).Return(keyFile, nil) + + certStrategy := testForTLSCert() + err := certStrategy.WaitUntilReady(ctx, target) + require.NoError(t, err) + + pool := x509.NewCertPool() + require.True(t, pool.AppendCertsFromPEM(caBytes)) + cert, err := tls.X509KeyPair(certBytes, keyBytes) + require.NoError(t, err) + got := certStrategy.TLSConfig() + require.Equal(t, serverName, got.ServerName) + require.Equal(t, []tls.Certificate{cert}, got.Certificates) + require.True(t, pool.Equal(got.RootCAs)) + }) +} + +func ExampleForTLSCert() { + ctx := context.Background() + + // waitForTLSCert { + forCert := wait.ForTLSCert("/app/tls.pem", "/app/tls-key.pem"). + WithServerName("testcontainer.go.test") + req := testcontainers.ContainerRequest{ + FromDockerfile: testcontainers.FromDockerfile{ + Context: "testdata/http", + }, + WaitingFor: forCert, + } + // } + + c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + defer func() { + if err := testcontainers.TerminateContainer(c); err != nil { + log.Printf("failed to terminate container: %s", err) + } + }() + if err != nil { + log.Printf("failed to start container: %s", err) + return + } + + state, err := c.State(ctx) + if err != nil { + log.Printf("failed to get container state: %s", err) + return + } + + fmt.Println(state.Running) + + config := forCert.TLSConfig() + fmt.Println(config.ServerName) + fmt.Println(len(config.Certificates)) + + // Output: + // true + // testcontainer.go.test + // 1 +} diff --git a/wait/walk.go b/wait/walk.go new file mode 100644 index 0000000000..2571edcff5 --- /dev/null +++ b/wait/walk.go @@ -0,0 +1,73 @@ +package wait + +import ( + "errors" +) + +var ( + // VisitStop is used as a return value from [VisitFunc] to stop the walk. + // It is not returned as an error by any function. + VisitStop = errors.New("stop the walk") + + // VisitRemove is used as a return value from [VisitFunc] to have the current node removed. + // It is not returned as an error by any function. + VisitRemove = errors.New("remove this strategy") +) + +// VisitFunc is a function that visits a strategy node. +// If it returns [VisitStop], the walk stops. +// If it returns [VisitRemove], the current node is removed. +type VisitFunc func(root Strategy) error + +// Walk walks the strategies tree and calls the visit function for each node. +func Walk(root *Strategy, visit VisitFunc) error { + if root == nil { + return errors.New("root strategy is nil") + } + + if err := walk(root, visit); err != nil { + if errors.Is(err, VisitRemove) || errors.Is(err, VisitStop) { + return nil + } + return err + } + + return nil +} + +// walk walks the strategies tree and calls the visit function for each node. +// It returns an error if the visit function returns an error. +func walk(root *Strategy, visit VisitFunc) error { + if *root == nil { + // No strategy. + return nil + } + + if err := visit(*root); err != nil { + if errors.Is(err, VisitRemove) { + *root = nil + } + + return err + } + + if s, ok := (*root).(*MultiStrategy); ok { + var i int + for range s.Strategies { + if err := walk(&s.Strategies[i], visit); err != nil { + if errors.Is(err, VisitRemove) { + s.Strategies = append(s.Strategies[:i], s.Strategies[i+1:]...) + if errors.Is(err, VisitStop) { + return VisitStop + } + continue + } + + return err + } + i++ + } + } + + return nil +} diff --git a/wait/walk_test.go b/wait/walk_test.go new file mode 100644 index 0000000000..d81cbedbed --- /dev/null +++ b/wait/walk_test.go @@ -0,0 +1,87 @@ +package wait_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +func TestWalk(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForAll( + wait.ForFile("/tmp/file"), + wait.ForHTTP("/health"), + wait.ForAll( + wait.ForFile("/tmp/other"), + ), + ), + } + + t.Run("walk", func(t *testing.T) { + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return nil + }) + require.NoError(t, err) + require.Equal(t, 5, count) + }) + + t.Run("stop", func(t *testing.T) { + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return wait.VisitStop + }) + require.NoError(t, err) + require.Equal(t, 1, count) + }) + + t.Run("remove", func(t *testing.T) { + var count, matched int + err := wait.Walk(&req.WaitingFor, func(s wait.Strategy) error { + count++ + if _, ok := s.(*wait.FileStrategy); ok { + matched++ + return wait.VisitRemove + } + + return nil + }) + require.NoError(t, err) + require.Equal(t, 5, count) + require.Equal(t, 2, matched) + + count = 0 + err = wait.Walk(&req.WaitingFor, func(s wait.Strategy) error { + count++ + if _, ok := s.(*wait.FileStrategy); ok { + matched++ + } + return nil + }) + require.NoError(t, err) + require.Equal(t, 3, count) + }) + + t.Run("remove-stop", func(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForAll( + wait.ForFile("/tmp/file"), + wait.ForHTTP("/health"), + ), + } + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return errors.Join(wait.VisitRemove, wait.VisitStop) + }) + require.NoError(t, err) + require.Equal(t, 1, count) + require.Nil(t, req.WaitingFor) + }) +}