diff --git a/docs/features/wait/introduction.md b/docs/features/wait/introduction.md index b9d2c160ac..feef9dc939 100644 --- a/docs/features/wait/introduction.md +++ b/docs/features/wait/introduction.md @@ -15,6 +15,7 @@ Below you can find a list of the available wait strategies that you can use: - [Log](./log.md) - [Multi](./multi.md) - [SQL](./sql.md) +- [TLS](./tls.md) ## Startup timeout and Poll interval diff --git a/docs/features/wait/tls.md b/docs/features/wait/tls.md new file mode 100644 index 0000000000..a98f78d84c --- /dev/null +++ b/docs/features/wait/tls.md @@ -0,0 +1,31 @@ +# TLS Strategy + +TLS Strategy waits for one or more files to exist in the container and uses them +and other details to construct a `tls.Config` which can be used to create secure +connections. + +It supports: + +- x509 PEM Certificate loaded from a certificate / key file pair. +- Root Certificate Authorities aka RootCAs loaded from PEM encoded files. +- Server name. +- Startup timeout to be used in seconds, default is 60 seconds. +- Poll interval to be used in milliseconds, default is 100 milliseconds. + +## Waiting for certificate pair + +The following snippets show how to configure a request to wait for certificate +pair to exist once started and then read the +[tls.Config](https://pkg.go.dev/crypto/tls#Config), alongside how to copy a test +certificate pair into a container image using a `Dockerfile`. + +It should be noted that copying certificate pairs into an images is only an +example which might be useful for testing with testcontainers-go and should not +be done with production images as that could expose your certificates if your +images become public. + + +[Wait for certificate](../../../wait/tls_test.go) inside_block:waitForTLSCert +[Read TLS Config](../../../wait/tls_test.go) inside_block:waitTLSConfig +[Dockerfile with certificate](../../../wait/testdata/http/Dockerfile) + diff --git a/mkdocs.yml b/mkdocs.yml index e4b57c4b75..47044423dc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,6 +64,7 @@ nav: - Log: features/wait/log.md - Multi: features/wait/multi.md - SQL: features/wait/sql.md + - TLS: features/wait/tls.md - Walk: features/wait/walk.md - Modules: - modules/index.md diff --git a/wait/file_test.go b/wait/file_test.go index 22133ba349..20bcc13a01 100644 --- a/wait/file_test.go +++ b/wait/file_test.go @@ -20,7 +20,7 @@ import ( const testFilename = "/tmp/file" -var anyContext = mock.AnythingOfType("*context.timerCtx") +var anyContext = mock.MatchedBy(func(_ context.Context) bool { return true }) // newRunningTarget creates a new mockStrategyTarget that is running. func newRunningTarget() *mockStrategyTarget { diff --git a/wait/http_test.go b/wait/http_test.go index 32479bddd4..73e32d44d7 100644 --- a/wait/http_test.go +++ b/wait/http_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "crypto/x509" + _ "embed" "fmt" "io" "log" @@ -23,6 +24,9 @@ import ( "github.com/testcontainers/testcontainers-go/wait" ) +//go:embed testdata/root.pem +var caBytes []byte + // https://github.com/testcontainers/testcontainers-go/issues/183 func ExampleHTTPStrategy() { // waitForHTTPWithDefaultPort { @@ -80,7 +84,7 @@ func ExampleHTTPStrategy_WithHeaders() { tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} req := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: "testdata", + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, WaitingFor: wait.ForHTTP("/headers"). @@ -227,20 +231,13 @@ func ExampleHTTPStrategy_WithBasicAuth() { } func TestHTTPStrategyWaitUntilReady(t *testing.T) { - workdir, err := os.Getwd() - require.NoError(t, err) - - capath := filepath.Join(workdir, "testdata", "root.pem") - cafile, err := os.ReadFile(capath) - require.NoError(t, err) - certpool := x509.NewCertPool() - require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid") + require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid") tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} dockerReq := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: filepath.Join(workdir, "testdata"), + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, WaitingFor: wait.NewHTTPStrategy("/auth-ping").WithTLS(true, tlsconfig). @@ -288,20 +285,13 @@ func TestHTTPStrategyWaitUntilReady(t *testing.T) { } func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) { - workdir, err := os.Getwd() - require.NoError(t, err) - - capath := filepath.Join(workdir, "testdata", "root.pem") - cafile, err := os.ReadFile(capath) - require.NoError(t, err) - certpool := x509.NewCertPool() - require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid") + require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid") tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} dockerReq := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: filepath.Join(workdir, "testdata"), + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, @@ -348,22 +338,15 @@ func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) { } func TestHTTPStrategyWaitUntilReadyNoBasicAuth(t *testing.T) { - workdir, err := os.Getwd() - require.NoError(t, err) - - capath := filepath.Join(workdir, "testdata", "root.pem") - cafile, err := os.ReadFile(capath) - require.NoError(t, err) - certpool := x509.NewCertPool() - require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid") + require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid") // waitForHTTPStatusCode { tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} var i int dockerReq := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: filepath.Join(workdir, "testdata"), + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, WaitingFor: wait.NewHTTPStrategy("/ping").WithTLS(true, tlsconfig). diff --git a/wait/testdata/tls.pem b/wait/testdata/cert.crt similarity index 100% rename from wait/testdata/tls.pem rename to wait/testdata/cert.crt diff --git a/wait/testdata/tls-key.pem b/wait/testdata/cert.key similarity index 100% rename from wait/testdata/tls-key.pem rename to wait/testdata/cert.key diff --git a/wait/testdata/Dockerfile b/wait/testdata/http/Dockerfile similarity index 100% rename from wait/testdata/Dockerfile rename to wait/testdata/http/Dockerfile diff --git a/wait/testdata/go.mod b/wait/testdata/http/go.mod similarity index 100% rename from wait/testdata/go.mod rename to wait/testdata/http/go.mod diff --git a/wait/testdata/main.go b/wait/testdata/http/main.go similarity index 100% rename from wait/testdata/main.go rename to wait/testdata/http/main.go diff --git a/wait/testdata/http/tls-key.pem b/wait/testdata/http/tls-key.pem new file mode 100644 index 0000000000..00789d2371 --- /dev/null +++ b/wait/testdata/http/tls-key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIM8HuDwcZyVqBBy2C6db6zNb/dAJ69bq5ejAEz7qGOIQoAoGCCqGSM49 +AwEHoUQDQgAEBL2ioRmfTc70WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG +0735iy9Fz16PX4vqnLMiM/ZupugAhB//yA== +-----END EC PRIVATE KEY----- diff --git a/wait/testdata/http/tls.pem b/wait/testdata/http/tls.pem new file mode 100644 index 0000000000..46348b7900 --- /dev/null +++ b/wait/testdata/http/tls.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBxTCCAWugAwIBAgIUWBLNpiF1o4r+5ZXwawzPOfBM1F8wCgYIKoZIzj0EAwIw +ADAeFw0yMDA4MTkxMzM4MDBaFw0zMDA4MTcxMzM4MDBaMBkxFzAVBgNVBAMTDnRl +c3Rjb250YWluZXJzMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEBL2ioRmfTc70 +WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG0735iy9Fz16PX4vqnLMiM/Zu +pugAhB//yKOBqTCBpjAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH +AwEwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUTMdz5PIZ+Gix4jYUzRIHfByrW+Yw +HwYDVR0jBBgwFoAUFdfV6PSYUlHs+lSQNouRwSfR2ZgwMQYDVR0RBCowKIIVdGVz +dGNvbnRhaW5lci5nby50ZXN0gglsb2NhbGhvc3SHBH8AAAEwCgYIKoZIzj0EAwID +SAAwRQIhAJznPNumi2Plf0GsP9DpC+8WukT/jUhnhcDWCfZ6Ini2AiBLhnhFebZX +XWfSsdSNxIo20OWvy6z3wqdybZtRUfdU+g== +-----END CERTIFICATE----- diff --git a/wait/tls.go b/wait/tls.go new file mode 100644 index 0000000000..ab904b271e --- /dev/null +++ b/wait/tls.go @@ -0,0 +1,167 @@ +package wait + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "time" +) + +// Validate we implement interface. +var _ Strategy = (*TLSStrategy)(nil) + +// TLSStrategy is a strategy for handling TLS. +type TLSStrategy struct { + // General Settings. + timeout *time.Duration + pollInterval time.Duration + + // Custom Settings. + certFiles *x509KeyPair + rootFiles []string + + // State. + tlsConfig *tls.Config +} + +// x509KeyPair is a pair of certificate and key files. +type x509KeyPair struct { + certPEMFile string + keyPEMFile string +} + +// ForTLSCert returns a CertStrategy that will add a Certificate to the [tls.Config] +// constructed from PEM formatted certificate key file pair in the container. +func ForTLSCert(certPEMFile, keyPEMFile string) *TLSStrategy { + return &TLSStrategy{ + certFiles: &x509KeyPair{ + certPEMFile: certPEMFile, + keyPEMFile: keyPEMFile, + }, + tlsConfig: &tls.Config{}, + pollInterval: defaultPollInterval(), + } +} + +// ForTLSRootCAs returns a CertStrategy that sets the root CAs for the [tls.Config] +// using the given PEM formatted files from the container. +func ForTLSRootCAs(pemFiles ...string) *TLSStrategy { + return &TLSStrategy{ + rootFiles: pemFiles, + tlsConfig: &tls.Config{}, + pollInterval: defaultPollInterval(), + } +} + +// WithRootCAs sets the root CAs for the [tls.Config] using the given files from +// the container. +func (ws *TLSStrategy) WithRootCAs(files ...string) *TLSStrategy { + ws.rootFiles = files + return ws +} + +// WithCert sets the [tls.Config] Certificates using the given files from the container. +func (ws *TLSStrategy) WithCert(certPEMFile, keyPEMFile string) *TLSStrategy { + ws.certFiles = &x509KeyPair{ + certPEMFile: certPEMFile, + keyPEMFile: keyPEMFile, + } + return ws +} + +// WithServerName sets the server for the [tls.Config]. +func (ws *TLSStrategy) WithServerName(serverName string) *TLSStrategy { + ws.tlsConfig.ServerName = serverName + return ws +} + +// WithStartupTimeout can be used to change the default startup timeout. +func (ws *TLSStrategy) WithStartupTimeout(startupTimeout time.Duration) *TLSStrategy { + ws.timeout = &startupTimeout + return ws +} + +// WithPollInterval can be used to override the default polling interval of 100 milliseconds. +func (ws *TLSStrategy) WithPollInterval(pollInterval time.Duration) *TLSStrategy { + ws.pollInterval = pollInterval + return ws +} + +// TLSConfig returns the TLS config once the strategy is ready. +// If the strategy is nil, it returns nil. +func (ws *TLSStrategy) TLSConfig() *tls.Config { + if ws == nil { + return nil + } + + return ws.tlsConfig +} + +// WaitUntilReady implements the [Strategy] interface. +// It waits for the CA, client cert and key files to be available in the container and +// uses them to setup the TLS config. +func (ws *TLSStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) error { + size := len(ws.rootFiles) + if ws.certFiles != nil { + size += 2 + } + strategies := make([]Strategy, 0, size) + for _, file := range ws.rootFiles { + strategies = append(strategies, + ForFile(file).WithMatcher(func(r io.Reader) error { + buf, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("read CA cert file %q: %w", file, err) + } + + if ws.tlsConfig.RootCAs == nil { + ws.tlsConfig.RootCAs = x509.NewCertPool() + } + + if !ws.tlsConfig.RootCAs.AppendCertsFromPEM(buf) { + return fmt.Errorf("invalid CA cert file %q", file) + } + + return nil + }).WithPollInterval(ws.pollInterval), + ) + } + + if ws.certFiles != nil { + var certPEMBlock []byte + strategies = append(strategies, + ForFile(ws.certFiles.certPEMFile).WithMatcher(func(r io.Reader) error { + var err error + if certPEMBlock, err = io.ReadAll(r); err != nil { + return fmt.Errorf("read certificate cert %q: %w", ws.certFiles.certPEMFile, err) + } + + return nil + }).WithPollInterval(ws.pollInterval), + ForFile(ws.certFiles.keyPEMFile).WithMatcher(func(r io.Reader) error { + keyPEMBlock, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("read certificate key %q: %w", ws.certFiles.keyPEMFile, err) + } + + cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + return fmt.Errorf("x509 key pair %q %q: %w", ws.certFiles.certPEMFile, ws.certFiles.keyPEMFile, err) + } + + ws.tlsConfig.Certificates = []tls.Certificate{cert} + + return nil + }).WithPollInterval(ws.pollInterval), + ) + } + + strategy := ForAll(strategies...) + if ws.timeout != nil { + strategy.WithStartupTimeout(*ws.timeout) + } + + return strategy.WaitUntilReady(ctx, target) +} diff --git a/wait/tls_test.go b/wait/tls_test.go new file mode 100644 index 0000000000..babc17b3d0 --- /dev/null +++ b/wait/tls_test.go @@ -0,0 +1,150 @@ +package wait_test + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + _ "embed" + "errors" + "fmt" + "io" + "log" + "testing" + "time" + + "github.com/docker/docker/errdefs" + "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + serverName = "127.0.0.1" + caFilename = "/tmp/ca.pem" + clientCertFilename = "/tmp/cert.crt" + clientKeyFilename = "/tmp/cert.key" +) + +var ( + //go:embed testdata/cert.crt + certBytes []byte + + //go:embed testdata/cert.key + keyBytes []byte +) + +// testForTLSCert creates a new CertStrategy for testing. +func testForTLSCert() *wait.TLSStrategy { + return wait.ForTLSCert(clientCertFilename, clientKeyFilename). + WithRootCAs(caFilename). + WithServerName(serverName). + WithStartupTimeout(time.Millisecond * 50). + WithPollInterval(time.Millisecond) +} + +func TestForCert(t *testing.T) { + errNotFound := errdefs.NotFound(errors.New("file not found")) + ctx := context.Background() + + t.Run("ca-not-found", func(t *testing.T) { + target := newRunningTarget() + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(nil, errNotFound) + err := testForTLSCert().WaitUntilReady(ctx, target) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + }) + + t.Run("cert-not-found", func(t *testing.T) { + target := newRunningTarget() + caFile := io.NopCloser(bytes.NewBuffer(caBytes)) + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(caFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientCertFilename).Return(nil, errNotFound) + err := testForTLSCert().WaitUntilReady(ctx, target) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + }) + + t.Run("key-not-found", func(t *testing.T) { + target := newRunningTarget() + caFile := io.NopCloser(bytes.NewBuffer(caBytes)) + certFile := io.NopCloser(bytes.NewBuffer(certBytes)) + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(caFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientCertFilename).Return(certFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientKeyFilename).Return(nil, errNotFound) + err := testForTLSCert().WaitUntilReady(ctx, target) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + }) + + t.Run("valid", func(t *testing.T) { + target := newRunningTarget() + caFile := io.NopCloser(bytes.NewBuffer(caBytes)) + certFile := io.NopCloser(bytes.NewBuffer(certBytes)) + keyFile := io.NopCloser(bytes.NewBuffer(keyBytes)) + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(caFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientCertFilename).Return(certFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientKeyFilename).Return(keyFile, nil) + + certStrategy := testForTLSCert() + err := certStrategy.WaitUntilReady(ctx, target) + require.NoError(t, err) + + pool := x509.NewCertPool() + require.True(t, pool.AppendCertsFromPEM(caBytes)) + cert, err := tls.X509KeyPair(certBytes, keyBytes) + require.NoError(t, err) + got := certStrategy.TLSConfig() + require.Equal(t, serverName, got.ServerName) + require.Equal(t, []tls.Certificate{cert}, got.Certificates) + require.True(t, pool.Equal(got.RootCAs)) + }) +} + +func ExampleForTLSCert() { + ctx := context.Background() + + // waitForTLSCert { + // The file names passed to ForTLSCert are the paths where the files will + // be copied to in the container as detailed by the Dockerfile. + forCert := wait.ForTLSCert("/app/tls.pem", "/app/tls-key.pem"). + WithServerName("testcontainer.go.test") + req := testcontainers.ContainerRequest{ + FromDockerfile: testcontainers.FromDockerfile{ + Context: "testdata/http", + }, + WaitingFor: forCert, + } + // } + + c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + defer func() { + if err := testcontainers.TerminateContainer(c); err != nil { + log.Printf("failed to terminate container: %s", err) + } + }() + if err != nil { + log.Printf("failed to start container: %s", err) + return + } + + state, err := c.State(ctx) + if err != nil { + log.Printf("failed to get container state: %s", err) + return + } + + fmt.Println(state.Running) + + // waitTLSConfig { + config := forCert.TLSConfig() + // } + fmt.Println(config.ServerName) + fmt.Println(len(config.Certificates)) + + // Output: + // true + // testcontainer.go.test + // 1 +}