diff --git a/wait/http.go b/wait/http.go index e778eb4296..ff1031b041 100644 --- a/wait/http.go +++ b/wait/http.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "net" "net/http" "strconv" @@ -24,8 +25,12 @@ type HTTPStrategy struct { Port nat.Port Path string StatusCodeMatcher func(status int) bool + ResponseMatcher func(body io.Reader) bool UseTLS bool AllowInsecure bool + TLSConfig *tls.Config // TLS config for HTTPS + Method string // http method + Body io.Reader // http request body } // NewHTTPStrategy constructs a HTTP strategy waiting on port 80 and status code 200 @@ -35,9 +40,12 @@ func NewHTTPStrategy(path string) *HTTPStrategy { Port: "80/tcp", Path: path, StatusCodeMatcher: defaultStatusCodeMatcher, + ResponseMatcher: func(body io.Reader) bool { return true }, UseTLS: false, + TLSConfig: nil, + Method: http.MethodGet, + Body: nil, } - } func defaultStatusCodeMatcher(status int) bool { @@ -63,8 +71,16 @@ func (ws *HTTPStrategy) WithStatusCodeMatcher(statusCodeMatcher func(status int) return ws } -func (ws *HTTPStrategy) WithTLS(useTLS bool) *HTTPStrategy { +func (ws *HTTPStrategy) WithResponseMatcher(matcher func(body io.Reader) bool) *HTTPStrategy { + ws.ResponseMatcher = matcher + return ws +} + +func (ws *HTTPStrategy) WithTLS(useTLS bool, tlsconf ...*tls.Config) *HTTPStrategy { ws.UseTLS = useTLS + if useTLS && len(tlsconf) > 0 { + ws.TLSConfig = tlsconf[0] + } return ws } @@ -73,6 +89,16 @@ func (ws *HTTPStrategy) WithAllowInsecure(allowInsecure bool) *HTTPStrategy { return ws } +func (ws *HTTPStrategy) WithMethod(method string) *HTTPStrategy { + ws.Method = method + return ws +} + +func (ws *HTTPStrategy) WithBody(reqdata io.Reader) *HTTPStrategy { + ws.Body = reqdata + return ws +} + // ForHTTP is a convenience method similar to Wait.java // https://github.com/testcontainers/testcontainers-java/blob/1d85a3834bd937f80aad3a4cec249c027f31aeb4/core/src/main/java/org/testcontainers/containers/wait/strategy/Wait.java func ForHTTP(path string) *HTTPStrategy { @@ -99,49 +125,73 @@ func (ws *HTTPStrategy) WaitUntilReady(ctx context.Context, target StrategyTarge return errors.New("Cannot use HTTP client on non-TCP ports") } - portNumber := port.Int() - portString := strconv.Itoa(portNumber) + switch ws.Method { + case http.MethodGet, http.MethodHead, http.MethodPost, + http.MethodPut, http.MethodPatch, http.MethodDelete, + http.MethodConnect, http.MethodOptions, http.MethodTrace: + default: + if ws.Method != "" { + return fmt.Errorf("invalid http method %q", ws.Method) + } + ws.Method = http.MethodGet + } - address := net.JoinHostPort(ipAddress, portString) + tripper := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: ws.TLSConfig, + } var proto string if ws.UseTLS { proto = "https" + if ws.AllowInsecure { + if ws.TLSConfig == nil { + tripper.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } else { + ws.TLSConfig.InsecureSkipVerify = true + } + } } else { proto = "http" } - url := fmt.Sprintf("%s://%s%s", proto, address, ws.Path) + client := http.Client{Transport: tripper, Timeout: time.Second} + address := net.JoinHostPort(ipAddress, strconv.Itoa(port.Int())) + endpoint := fmt.Sprintf("%s://%s%s", proto, address, ws.Path) - tripper := http.DefaultTransport - - if ws.AllowInsecure { - tripper.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - - client := http.Client{Timeout: ws.startupTimeout, Transport: tripper} - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return err - } - - req = req.WithContext(ctx) - -Retry: for { select { case <-ctx.Done(): - break Retry - default: + return ctx.Err() + case <-time.After(time.Second / 10): + req, err := http.NewRequestWithContext(ctx, ws.Method, endpoint, ws.Body) + if err != nil { + return err + } resp, err := client.Do(req) - if err != nil || !ws.StatusCodeMatcher(resp.StatusCode) { - time.Sleep(100 * time.Millisecond) + if err != nil { continue } - - break Retry + if ws.StatusCodeMatcher != nil && !ws.StatusCodeMatcher(resp.StatusCode) { + continue + } + if ws.ResponseMatcher != nil && !ws.ResponseMatcher(resp.Body) { + continue + } + if err := resp.Body.Close(); err != nil { + continue + } + return nil } } - - return nil } diff --git a/wait/http_test.go b/wait/http_test.go index 05a734eca7..a3a34431e1 100644 --- a/wait/http_test.go +++ b/wait/http_test.go @@ -1,7 +1,18 @@ package wait_test import ( + "bytes" "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "os" + "testing" + "time" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" @@ -25,7 +36,88 @@ func ExampleHTTPStrategy() { panic(err) } - defer gogs.Terminate(ctx) - + defer gogs.Terminate(ctx) // nolint: errcheck // Here you have a running container + +} + +func TestHTTPStrategyWaitUntilReady(t *testing.T) { + workdir, err := os.Getwd() + if err != nil { + t.Error(err) + return + } + + capath := workdir + "/testdata/root.pem" + cafile, err := ioutil.ReadFile(capath) + if err != nil { + t.Errorf("can't load ca file: %v", err) + return + } + + certpool := x509.NewCertPool() + if !certpool.AppendCertsFromPEM(cafile) { + t.Errorf("the ca file isn't valid") + return + } + + tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} + dockerReq := testcontainers.ContainerRequest{ + FromDockerfile: testcontainers.FromDockerfile{ + Context: workdir + "/testdata", + }, + ExposedPorts: []string{"6443/tcp"}, + WaitingFor: wait.NewHTTPStrategy("/ping").WithTLS(true, tlsconfig). + WithStartupTimeout(time.Second * 10).WithPort("6443/tcp"). + WithResponseMatcher(func(body io.Reader) bool { + data, _ := ioutil.ReadAll(body) + return bytes.Equal(data, []byte("pong")) + }). + WithMethod(http.MethodPost).WithBody(bytes.NewReader([]byte("ping"))), + } + + container, err := testcontainers.GenericContainer(context.Background(), + testcontainers.GenericContainerRequest{ContainerRequest: dockerReq, Started: true}) + if err != nil { + t.Error(err) + return + } + defer container.Terminate(context.Background()) // nolint: errcheck + + host, err := container.Host(context.Background()) + if err != nil { + t.Error(err) + return + } + port, err := container.MappedPort(context.Background(), "6443/tcp") + if err != nil { + t.Error(err) + return + } + client := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsconfig, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + } + resp, err := client.Get(fmt.Sprintf("https://%s:%s", host, port.Port())) + if err != nil { + t.Error(err) + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("status code isn't ok: %s", resp.Status) + return + } } diff --git a/wait/testdata/Dockerfile b/wait/testdata/Dockerfile new file mode 100644 index 0000000000..d89165a2e0 --- /dev/null +++ b/wait/testdata/Dockerfile @@ -0,0 +1,12 @@ +FROM golang:1.15-alpine as builder +WORKDIR /app +COPY . . +RUN mkdir -p dist +RUN go build -o ./dist/server main.go + +FROM alpine +WORKDIR /app +COPY --from=builder /app/tls.pem /app/tls-key.pem ./ +COPY --from=builder /app/dist/server . +EXPOSE 6443 +CMD ["/app/server"] diff --git a/wait/testdata/go.mod b/wait/testdata/go.mod new file mode 100644 index 0000000000..9557d2288d --- /dev/null +++ b/wait/testdata/go.mod @@ -0,0 +1,3 @@ +module httptest + +go 1.15 diff --git a/wait/testdata/main.go b/wait/testdata/main.go new file mode 100644 index 0000000000..a8bdfc2ee4 --- /dev/null +++ b/wait/testdata/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "bytes" + "context" + "io/ioutil" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) + +func main() { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + mux.HandleFunc("/ping", func(w http.ResponseWriter, req *http.Request) { + data, _ := ioutil.ReadAll(req.Body) + if bytes.Equal(data, []byte("ping")) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("pong")) + } else { + w.WriteHeader(http.StatusBadRequest) + } + }) + + server := http.Server{Addr: ":6443", Handler: mux} + go func() { + log.Println("serving...") + if err := server.ListenAndServeTLS("tls.pem", "tls-key.pem"); err != nil && err != http.ErrServerClosed { + log.Fatal(err) + } + }() + + stopsig := make(chan os.Signal, 1) + signal.Notify(stopsig, syscall.SIGINT, syscall.SIGTERM) + <-stopsig + + log.Println("stopping...") + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _ = server.Shutdown(ctx) +} diff --git a/wait/testdata/root.pem b/wait/testdata/root.pem new file mode 100644 index 0000000000..c31286daaf --- /dev/null +++ b/wait/testdata/root.pem @@ -0,0 +1,10 @@ +-----BEGIN CERTIFICATE----- +MIIBVTCB/aADAgECAghLWuRKnTb4BjAKBggqhkjOPQQDAjAAMB4XDTIwMDgxOTEz +MzUwOFoXDTMwMDgxNzEzNDAwOFowADBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IA +BP39G8oZK7JvdcJzSuEzoqe3KsWS7/4C7UKhdoGHkEuHED+I456v3O8x0gUTjqIv +I9FmW3cq/eMoraPzzk3u7vajYTBfMA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAU +BggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU +FdfV6PSYUlHs+lSQNouRwSfR2ZgwCgYIKoZIzj0EAwIDRwAwRAIgDAFtSEaFGuvP +wJZhQv7zjIhCGzYzsZ8KSKUJ3YvdL/4CIBbgDFzEeQWFWUMFPeMaVVrmBmsflPIg +cnC4yG76skGg +-----END CERTIFICATE----- diff --git a/wait/testdata/tls-key.pem b/wait/testdata/tls-key.pem new file mode 100644 index 0000000000..00789d2371 --- /dev/null +++ b/wait/testdata/tls-key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIM8HuDwcZyVqBBy2C6db6zNb/dAJ69bq5ejAEz7qGOIQoAoGCCqGSM49 +AwEHoUQDQgAEBL2ioRmfTc70WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG +0735iy9Fz16PX4vqnLMiM/ZupugAhB//yA== +-----END EC PRIVATE KEY----- diff --git a/wait/testdata/tls.pem b/wait/testdata/tls.pem new file mode 100644 index 0000000000..46348b7900 --- /dev/null +++ b/wait/testdata/tls.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBxTCCAWugAwIBAgIUWBLNpiF1o4r+5ZXwawzPOfBM1F8wCgYIKoZIzj0EAwIw +ADAeFw0yMDA4MTkxMzM4MDBaFw0zMDA4MTcxMzM4MDBaMBkxFzAVBgNVBAMTDnRl +c3Rjb250YWluZXJzMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEBL2ioRmfTc70 +WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG0735iy9Fz16PX4vqnLMiM/Zu +pugAhB//yKOBqTCBpjAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH +AwEwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUTMdz5PIZ+Gix4jYUzRIHfByrW+Yw +HwYDVR0jBBgwFoAUFdfV6PSYUlHs+lSQNouRwSfR2ZgwMQYDVR0RBCowKIIVdGVz +dGNvbnRhaW5lci5nby50ZXN0gglsb2NhbGhvc3SHBH8AAAEwCgYIKoZIzj0EAwID +SAAwRQIhAJznPNumi2Plf0GsP9DpC+8WukT/jUhnhcDWCfZ6Ini2AiBLhnhFebZX +XWfSsdSNxIo20OWvy6z3wqdybZtRUfdU+g== +-----END CERTIFICATE-----