diff --git a/client/rest/client.go b/client/rest/client.go index 77b637f..919edc4 100644 --- a/client/rest/client.go +++ b/client/rest/client.go @@ -25,7 +25,6 @@ import ( "encoding/json" "fmt" "io" - "math" "net/http" "net/url" "strconv" @@ -219,23 +218,9 @@ func (s *restClient) Send(req *http.Request) (*http.Response, error) { return s.send(req) } -func copyBody(req *http.Request) ([]byte, error) { - var ( - body []byte - err error - ) - if req.Body != nil { - body, err = io.ReadAll(req.Body) - if body != nil { - req.Body = io.NopCloser(bytes.NewBuffer(body)) - } - } - return body, err -} - func (s *restClient) send(req *http.Request) (*http.Response, error) { // copy the bytes in case we need to retry the request - if body, err := copyBody(req); err != nil { + if body, err := CopyBody(req); err != nil { return nil, err } else { var ( @@ -254,7 +239,11 @@ func (s *restClient) send(req *http.Request) (*http.Response, error) { // Try the request if res, err = s.http.Do(req); err != nil { - // client error + if IsClosedConnectionErr(err) { + fmt.Printf("remote host force closed connection while requesting %s; attempt %d/%d; trying again\n", req.URL, retry+1, maxRetries) + ExponentialBackoff(retry) + continue + } return nil, err } else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { // Error response code handling @@ -270,8 +259,7 @@ func (s *restClient) send(req *http.Request) (*http.Response, error) { } } else if res.StatusCode >= http.StatusInternalServerError { // Wait the time calculated by the 5 second exponential backoff - backoff := math.Pow(5, float64(retry+1)) - time.Sleep(time.Second * time.Duration(backoff)) + ExponentialBackoff(retry) continue } else { // Not a status code that warrants a retry diff --git a/client/rest/client_test.go b/client/rest/client_test.go new file mode 100644 index 0000000..4be3273 --- /dev/null +++ b/client/rest/client_test.go @@ -0,0 +1,69 @@ +// Copyright (C) 2024 Specter Ops, Inc. +// +// This file is part of AzureHound. +// +// AzureHound is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// AzureHound is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package rest + +import ( + "net/http" + "net/http/httptest" + + "testing" + + "github.com/bloodhoundad/azurehound/v2/client/config" +) + +func TestClosedConnection(t *testing.T) { + var testServer *httptest.Server + attempt := 0 + var mockHandler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + attempt++ + testServer.CloseClientConnections() + } + + testServer = httptest.NewServer(mockHandler) + defer testServer.Close() + + defaultConfig := config.Config{ + Username: "azurehound", + Password: "we_collect", + Authority: testServer.URL, + } + + if client, err := NewRestClient(testServer.URL, defaultConfig); err != nil { + t.Fatalf("error initializing rest client %v", err) + } else { + requestCompleted := false + + // make request in separate goroutine so its not blocking after we validated the retry + go func() { + client.Authenticate() // Authenticate() because it uses the internal client.send method. + // CloseClientConnections should block the request from completing, however if it completes then the test fails. + requestCompleted = true + }() + + // block until attempt is > 2 or request succeeds + for attempt <= 2 { + if attempt > 1 || requestCompleted { + break + } + } + + if requestCompleted { + t.Fatalf("expected an attempted retry but the request completed") + } + } +} diff --git a/client/rest/utils.go b/client/rest/utils.go index b048e7b..ed3b85e 100644 --- a/client/rest/utils.go +++ b/client/rest/utils.go @@ -18,6 +18,7 @@ package rest import ( + "bytes" "crypto/sha1" "crypto/x509" "encoding/base64" @@ -25,6 +26,8 @@ import ( "encoding/pem" "fmt" "io" + "math" + "net/http" "strings" "time" @@ -120,3 +123,30 @@ func x5t(certificate string) (string, error) { return base64.StdEncoding.EncodeToString(checksum[:]), nil } } + +func IsClosedConnectionErr(err error) bool { + var closedConnectionMsg = "An existing connection was forcibly closed by the remote host." + closedFromClient := strings.Contains(err.Error(), closedConnectionMsg) + // Mocking http.Do would require a larger refactor, so closedFromTestCase is used to cover testing only. + closedFromTestCase := strings.HasSuffix(err.Error(), ": EOF") + return closedFromClient || closedFromTestCase +} + +func ExponentialBackoff(retry int) { + backoff := math.Pow(5, float64(retry+1)) + time.Sleep(time.Second * time.Duration(backoff)) +} + +func CopyBody(req *http.Request) ([]byte, error) { + var ( + body []byte + err error + ) + if req.Body != nil { + body, err = io.ReadAll(req.Body) + if body != nil { + req.Body = io.NopCloser(bytes.NewBuffer(body)) + } + } + return body, err +} diff --git a/cmd/start.go b/cmd/start.go index 3ea2afd..e817d42 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -25,7 +25,6 @@ import ( "errors" "fmt" "io" - "math" "net/http" "net/url" "os" @@ -215,16 +214,30 @@ func ingest(ctx context.Context, bheUrl url.URL, bheClient *http.Client, in <-ch } else { req.Header.Set("User-Agent", constants.UserAgent()) req.Header.Set("Accept", "application/json") - req.Header.Set("Prefer", "wait=60") req.Header.Set("Content-Encoding", "gzip") for retry := 0; retry < maxRetries; retry++ { // No retries on regular err cases, only on HTTP 504 Gateway Timeout and HTTP 503 Service Unavailable if response, err := bheClient.Do(req); err != nil { + if rest.IsClosedConnectionErr(err) { + // try again on force closed connection + log.Error(err, fmt.Sprintf("remote host force closed connection while requesting %s; attempt %d/%d; trying again", req.URL, retry+1, maxRetries)) + rest.ExponentialBackoff(retry) + + if retry == maxRetries-1 { + log.Error(ErrExceededRetryLimit, "") + hasErrors = true + } + + continue + } log.Error(err, unrecoverableErrMsg) return true - } else if response.StatusCode == http.StatusGatewayTimeout || response.StatusCode == http.StatusServiceUnavailable { - backoff := math.Pow(5, float64(retry+1)) - time.Sleep(time.Second * time.Duration(backoff)) + } else if response.StatusCode == http.StatusGatewayTimeout || response.StatusCode == http.StatusServiceUnavailable || response.StatusCode == http.StatusBadGateway { + serverError := fmt.Errorf("received server error %d while requesting %v; attempt %d/%d; trying again", response.StatusCode, endpoint, retry+1, maxRetries) + log.Error(serverError, "") + + rest.ExponentialBackoff(retry) + if retry == maxRetries-1 { log.Error(ErrExceededRetryLimit, "") hasErrors = true @@ -256,19 +269,55 @@ func ingest(ctx context.Context, bheUrl url.URL, bheClient *http.Client, in <-ch // TODO: create/use a proper bloodhound client func do(bheClient *http.Client, req *http.Request) (*http.Response, error) { - if res, err := bheClient.Do(req); err != nil { - return nil, fmt.Errorf("failed to request %v: %w", req.URL, err) - } else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { - var body json.RawMessage - defer res.Body.Close() - if err := json.NewDecoder(res.Body).Decode(&body); err != nil { - return nil, fmt.Errorf("received unexpected response code from %v: %s; failure reading response body", req.URL, res.Status) - } else { - return nil, fmt.Errorf("received unexpected response code from %v: %s %s", req.URL, res.Status, body) - } + var ( + res *http.Response + maxRetries = 3 + ) + + // copy the bytes in case we need to retry the request + if body, err := rest.CopyBody(req); err != nil { + return nil, err } else { - return res, nil + for retry := 0; retry < maxRetries; retry++ { + // Reusing http.Request requires rewinding the request body + // back to a working state + if body != nil && retry > 0 { + req.Body = io.NopCloser(bytes.NewBuffer(body)) + } + + if res, err = bheClient.Do(req); err != nil { + if rest.IsClosedConnectionErr(err) { + // try again on force closed connections + log.Error(err, fmt.Sprintf("remote host force closed connection while requesting %s; attempt %d/%d; trying again", req.URL, retry+1, maxRetries)) + rest.ExponentialBackoff(retry) + continue + } + // normal client error, dont attempt again + return nil, err + } else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { + if res.StatusCode >= http.StatusInternalServerError { + // Internal server error, backoff and try again. + serverError := fmt.Errorf("received server error %d while requesting %v", res.StatusCode, req.URL) + log.Error(serverError, fmt.Sprintf("attempt %d/%d; trying again", retry+1, maxRetries)) + + rest.ExponentialBackoff(retry) + continue + } + // bad request we do not need to retry + var body json.RawMessage + defer res.Body.Close() + if err := json.NewDecoder(res.Body).Decode(&body); err != nil { + return nil, fmt.Errorf("received unexpected response code from %v: %s; failure reading response body", req.URL, res.Status) + } else { + return nil, fmt.Errorf("received unexpected response code from %v: %s %s", req.URL, res.Status, body) + } + } else { + return res, nil + } + } } + + return nil, fmt.Errorf("unable to complete request to url=%s; attempts=%d;", req.URL, maxRetries) } type basicResponse[T any] struct {