Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BED-4273 azurehound backoff retry #74

Merged
merged 18 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions client/rest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -219,23 +218,9 @@ func (s *restClient) Send(req *http.Request) (*http.Response, error) {
return s.send(req)
}

func copyBody(req *http.Request) ([]byte, error) {
var (
body []byte
err error
)
if req.Body != nil {
body, err = io.ReadAll(req.Body)
if body != nil {
req.Body = io.NopCloser(bytes.NewBuffer(body))
}
}
return body, err
}

func (s *restClient) send(req *http.Request) (*http.Response, error) {
// copy the bytes in case we need to retry the request
if body, err := copyBody(req); err != nil {
if body, err := CopyBody(req); err != nil {
return nil, err
} else {
var (
Expand All @@ -254,7 +239,11 @@ func (s *restClient) send(req *http.Request) (*http.Response, error) {

// Try the request
if res, err = s.http.Do(req); err != nil {
// client error
if IsClosedConnectionErr(err) {
fmt.Printf("remote host force closed connection while requesting %s; attempt %d/%d; trying again\n", req.URL, retry+1, maxRetries)
ExponentialBackoff(retry)
continue
}
return nil, err
} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
// Error response code handling
Expand All @@ -270,8 +259,7 @@ func (s *restClient) send(req *http.Request) (*http.Response, error) {
}
} else if res.StatusCode >= http.StatusInternalServerError {
// Wait the time calculated by the 5 second exponential backoff
backoff := math.Pow(5, float64(retry+1))
time.Sleep(time.Second * time.Duration(backoff))
ExponentialBackoff(retry)
continue
} else {
// Not a status code that warrants a retry
Expand Down
69 changes: 69 additions & 0 deletions client/rest/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (C) 2024 Specter Ops, Inc.
//
// This file is part of AzureHound.
//
// AzureHound is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// AzureHound is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package rest

import (
"net/http"
"net/http/httptest"

"testing"

"github.com/bloodhoundad/azurehound/v2/client/config"
)

func TestClosedConnection(t *testing.T) {
var testServer *httptest.Server
attempt := 0
var mockHandler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
attempt++
testServer.CloseClientConnections()
}

testServer = httptest.NewServer(mockHandler)
defer testServer.Close()

defaultConfig := config.Config{
Username: "azurehound",
Password: "we_collect",
Authority: testServer.URL,
}

if client, err := NewRestClient(testServer.URL, defaultConfig); err != nil {
t.Fatalf("error initializing rest client %v", err)
} else {
requestCompleted := false

// make request in separate goroutine so its not blocking after we validated the retry
go func() {
client.Authenticate() // Authenticate() because it uses the internal client.send method.
// CloseClientConnections should block the request from completing, however if it completes then the test fails.
requestCompleted = true
}()

// block until attempt is > 2 or request succeeds
for attempt <= 2 {
if attempt > 1 || requestCompleted {
break
}
}

if requestCompleted {
t.Fatalf("expected an attempted retry but the request completed")
}
}
}
30 changes: 30 additions & 0 deletions client/rest/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
package rest

import (
"bytes"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -120,3 +123,30 @@ func x5t(certificate string) (string, error) {
return base64.StdEncoding.EncodeToString(checksum[:]), nil
}
}

func IsClosedConnectionErr(err error) bool {
var closedConnectionMsg = "An existing connection was forcibly closed by the remote host."
closedFromClient := strings.Contains(err.Error(), closedConnectionMsg)
// Mocking http.Do would require a larger refactor, so closedFromTestCase is used to cover testing only.
closedFromTestCase := strings.HasSuffix(err.Error(), ": EOF")
return closedFromClient || closedFromTestCase
}

func ExponentialBackoff(retry int) {
backoff := math.Pow(5, float64(retry+1))
time.Sleep(time.Second * time.Duration(backoff))
}

func CopyBody(req *http.Request) ([]byte, error) {
var (
body []byte
err error
)
if req.Body != nil {
body, err = io.ReadAll(req.Body)
if body != nil {
req.Body = io.NopCloser(bytes.NewBuffer(body))
}
}
return body, err
}
81 changes: 65 additions & 16 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"errors"
"fmt"
"io"
"math"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -215,16 +214,30 @@ func ingest(ctx context.Context, bheUrl url.URL, bheClient *http.Client, in <-ch
} else {
req.Header.Set("User-Agent", constants.UserAgent())
req.Header.Set("Accept", "application/json")
req.Header.Set("Prefer", "wait=60")
req.Header.Set("Content-Encoding", "gzip")
for retry := 0; retry < maxRetries; retry++ {
// No retries on regular err cases, only on HTTP 504 Gateway Timeout and HTTP 503 Service Unavailable
if response, err := bheClient.Do(req); err != nil {
if rest.IsClosedConnectionErr(err) {
// try again on force closed connection
log.Error(err, fmt.Sprintf("remote host force closed connection while requesting %s; attempt %d/%d; trying again", req.URL, retry+1, maxRetries))
rest.ExponentialBackoff(retry)

if retry == maxRetries-1 {
log.Error(ErrExceededRetryLimit, "")
hasErrors = true
}

continue
}
log.Error(err, unrecoverableErrMsg)
return true
} else if response.StatusCode == http.StatusGatewayTimeout || response.StatusCode == http.StatusServiceUnavailable {
backoff := math.Pow(5, float64(retry+1))
time.Sleep(time.Second * time.Duration(backoff))
} else if response.StatusCode == http.StatusGatewayTimeout || response.StatusCode == http.StatusServiceUnavailable || response.StatusCode == http.StatusBadGateway {
serverError := fmt.Errorf("received server error %d while requesting %v; attempt %d/%d; trying again", response.StatusCode, endpoint, retry+1, maxRetries)
log.Error(serverError, "")

rest.ExponentialBackoff(retry)

if retry == maxRetries-1 {
log.Error(ErrExceededRetryLimit, "")
hasErrors = true
Expand Down Expand Up @@ -256,19 +269,55 @@ func ingest(ctx context.Context, bheUrl url.URL, bheClient *http.Client, in <-ch

// TODO: create/use a proper bloodhound client
func do(bheClient *http.Client, req *http.Request) (*http.Response, error) {
if res, err := bheClient.Do(req); err != nil {
return nil, fmt.Errorf("failed to request %v: %w", req.URL, err)
} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
var body json.RawMessage
defer res.Body.Close()
if err := json.NewDecoder(res.Body).Decode(&body); err != nil {
return nil, fmt.Errorf("received unexpected response code from %v: %s; failure reading response body", req.URL, res.Status)
} else {
return nil, fmt.Errorf("received unexpected response code from %v: %s %s", req.URL, res.Status, body)
}
var (
res *http.Response
maxRetries = 3
)

// copy the bytes in case we need to retry the request
if body, err := rest.CopyBody(req); err != nil {
return nil, err
} else {
return res, nil
for retry := 0; retry < maxRetries; retry++ {
// Reusing http.Request requires rewinding the request body
// back to a working state
if body != nil && retry > 0 {
req.Body = io.NopCloser(bytes.NewBuffer(body))
}

if res, err = bheClient.Do(req); err != nil {
if rest.IsClosedConnectionErr(err) {
// try again on force closed connections
log.Error(err, fmt.Sprintf("remote host force closed connection while requesting %s; attempt %d/%d; trying again", req.URL, retry+1, maxRetries))
rest.ExponentialBackoff(retry)
continue
}
// normal client error, dont attempt again
return nil, err
} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
if res.StatusCode >= http.StatusInternalServerError {
// Internal server error, backoff and try again.
serverError := fmt.Errorf("received server error %d while requesting %v", res.StatusCode, req.URL)
log.Error(serverError, fmt.Sprintf("attempt %d/%d; trying again", retry+1, maxRetries))

rest.ExponentialBackoff(retry)
continue
}
// bad request we do not need to retry
var body json.RawMessage
defer res.Body.Close()
if err := json.NewDecoder(res.Body).Decode(&body); err != nil {
return nil, fmt.Errorf("received unexpected response code from %v: %s; failure reading response body", req.URL, res.Status)
} else {
return nil, fmt.Errorf("received unexpected response code from %v: %s %s", req.URL, res.Status, body)
}
} else {
return res, nil
}
}
}

return nil, fmt.Errorf("unable to complete request to url=%s; attempts=%d;", req.URL, maxRetries)
}

type basicResponse[T any] struct {
Expand Down
Loading