From cef973ec164ac4800e7e7e6566dec5509ed638d2 Mon Sep 17 00:00:00 2001 From: David Wertenteil Date: Mon, 26 Feb 2024 09:22:51 +0200 Subject: [PATCH] Adding retries to httpPost --- .github/workflows/pr.yaml | 2 +- .github/workflows/release.yaml | 2 +- go.mod | 7 +- go.sum | 2 + httputils/httphelpers.go | 64 +++++++++-- httputils/httphelpers_test.go | 195 +++++++++++++++++++++++++++++++++ 6 files changed, 260 insertions(+), 12 deletions(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 9a41ea6..b920182 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/setup-go@v4 with: - go-version: "1.20" + go-version: "1.21" - name: Get dependencies run: go mod download diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index b2d2660..284240e 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -20,7 +20,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: 1.19 + go-version: 1.21 - name: Test & coverage id: unit-test diff --git a/go.mod b/go.mod index ba3dcfb..3158704 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,11 @@ module github.com/armosec/utils-go -go 1.19 +go 1.21 -require github.com/stretchr/testify v1.8.4 +require ( + github.com/cenkalti/backoff v2.2.1+incompatible + github.com/stretchr/testify v1.8.4 +) require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/go.sum b/go.sum index 82b748f..2059d19 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/httputils/httphelpers.go b/httputils/httphelpers.go index 818f120..f5dcc61 100644 --- a/httputils/httphelpers.go +++ b/httputils/httphelpers.go @@ -9,12 +9,17 @@ import ( "net/http" "strings" "sync" + "time" + + "github.com/cenkalti/backoff" ) type IHttpClient interface { Do(req *http.Request) (*http.Response, error) } +type ShouldNotRetryFunc func(resp *http.Response) bool + // JSONDecoder returns JSON decoder for given string func JSONDecoder(origin string) *json.Decoder { dec := json.NewDecoder(strings.NewReader(origin)) @@ -65,21 +70,64 @@ func HttpGetWithContext(ctx context.Context, httpClient IHttpClient, fullURL str } func HttpPost(httpClient IHttpClient, fullURL string, headers map[string]string, body []byte) (*http.Response, error) { - return HttpPostWithContext(context.Background(), httpClient, fullURL, headers, body) + return HttpPostWithContext(context.Background(), httpClient, fullURL, headers, body, -1, func(resp *http.Response) bool { + return true + }) } -func HttpPostWithContext(ctx context.Context, httpClient IHttpClient, fullURL string, headers map[string]string, body []byte) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(body)) - if err != nil { - return nil, err +func HttpPostWithRetry(httpClient IHttpClient, fullURL string, headers map[string]string, body []byte, maxElapsedTime time.Duration) (*http.Response, error) { + return HttpPostWithContext(context.Background(), httpClient, fullURL, headers, body, maxElapsedTime, defaultShouldRetry) +} + +func HttpPostWithContext(ctx context.Context, httpClient IHttpClient, fullURL string, headers map[string]string, body []byte, maxElapsedTime time.Duration, shouldRetry func(resp *http.Response) bool) (*http.Response, error) { + var resp *http.Response + var err error + + operation := func() error { + req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(body)) + if err != nil { + return backoff.Permanent(err) + } + setHeaders(req, headers) + + resp, err = httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + // If the status code is not 200, we will retry + if resp.StatusCode != http.StatusOK { + if shouldRetry(resp) { + return fmt.Errorf("received status code: %d", resp.StatusCode) + } + return backoff.Permanent(err) + } + + return nil } - setHeaders(req, headers) - return httpClient.Do(req) + // Create a new exponential backoff policy + expBackOff := backoff.NewExponentialBackOff() + expBackOff.MaxElapsedTime = maxElapsedTime // Set the maximum elapsed time + + // Run the operation with the exponential backoff policy + if err = backoff.Retry(operation, expBackOff); err != nil { + return resp, err + } + + return resp, nil +} +func defaultShouldRetry(resp *http.Response) bool { + // If received codes 401/403/404/500 should return false + return resp.StatusCode != http.StatusUnauthorized && + resp.StatusCode != http.StatusForbidden && + resp.StatusCode != http.StatusNotFound && + resp.StatusCode != http.StatusInternalServerError } func setHeaders(req *http.Request, headers map[string]string) { - if headers != nil && len(headers) > 0 { + if len(headers) > 0 { for k, v := range headers { req.Header.Set(k, v) } diff --git a/httputils/httphelpers_test.go b/httputils/httphelpers_test.go index 7141b73..e703dda 100644 --- a/httputils/httphelpers_test.go +++ b/httputils/httphelpers_test.go @@ -1,11 +1,13 @@ package httputils import ( + "bytes" "fmt" "net/http" "reflect" "sync" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -296,3 +298,196 @@ func TestSplit2Chunks(t *testing.T) { } } + +type mockHttpClient struct { + doFunc func(req *http.Request) (*http.Response, error) +} + +func (m *mockHttpClient) Do(req *http.Request) (*http.Response, error) { + return m.doFunc(req) +} + +func TestHttpPostWithContext(t *testing.T) { + defaultMaxTime := 5 * time.Second + + t.Run("Successful request", func(t *testing.T) { + expectedURL := "http://example.com" + expectedBody := []byte("test body") + headers := map[string]string{ + "Content-Type": "application/json", + } + + expectedResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: http.NoBody, + } + + httpClient := &mockHttpClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + assert.Equal(t, expectedURL, req.URL.String()) + assert.Equal(t, "POST", req.Method) + assert.Equal(t, expectedBody, readRequestBody(req)) + + return expectedResponse, nil + }, + } + + resp, err := HttpPost(httpClient, expectedURL, headers, expectedBody) + + assert.NoError(t, err) + assert.Equal(t, expectedResponse.StatusCode, resp.StatusCode) + }) + + t.Run("Permanent error", func(t *testing.T) { + expectedURL := "http://example.com" + expectedBody := []byte("test body") + expectedHeaders := map[string]string{ + "Content-Type": "application/json", + } + expectedError := fmt.Errorf("permanent error") + + httpClient := &mockHttpClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + return nil, expectedError + }, + } + + resp, err := HttpPost(httpClient, expectedURL, expectedHeaders, expectedBody) + + assert.Equal(t, expectedError, err) + assert.Nil(t, resp) + }) + + t.Run("Non-retryable error", func(t *testing.T) { + expectedURL := "http://example.com" + expectedBody := []byte("test body") + expectedHeaders := map[string]string{ + "Content-Type": "application/json", + } + expectedResponse := &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: http.NoBody, + } + + httpClient := &mockHttpClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + return expectedResponse, nil + }, + } + + resp, err := HttpPostWithRetry(httpClient, expectedURL, expectedHeaders, expectedBody, defaultMaxTime) + + assert.Equal(t, nil, err) + assert.Equal(t, expectedResponse.StatusCode, resp.StatusCode) + }) + + t.Run("Retryable error", func(t *testing.T) { + expectedURL := "http://example.com" + expectedBody := []byte("test body") + expectedHeaders := map[string]string{ + "Content-Type": "application/json", + } + expectedResponse := &http.Response{ + StatusCode: http.StatusBadGateway, + Body: http.NoBody, + } + expectedError := fmt.Errorf("received status code: %d", expectedResponse.StatusCode) + + httpClient := &mockHttpClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + return expectedResponse, nil + }, + } + + resp, err := HttpPostWithRetry(httpClient, expectedURL, expectedHeaders, expectedBody, defaultMaxTime) + + assert.Equal(t, expectedError, err) + assert.Equal(t, expectedResponse.StatusCode, resp.StatusCode) + }) + + t.Run("Retryable error with successful retry", func(t *testing.T) { + expectedURL := "http://example.com" + expectedBody := []byte("test body") + expectedHeaders := map[string]string{ + "Content-Type": "application/json", + } + expectedResponse := &http.Response{ + StatusCode: http.StatusBadGateway, + Body: http.NoBody, + } + + retryCount := 0 + httpClient := &mockHttpClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + retryCount++ + if retryCount == 1 { + return expectedResponse, nil + } + expectedResponse.StatusCode = http.StatusOK + return expectedResponse, nil + }, + } + + resp, err := HttpPostWithRetry(httpClient, expectedURL, expectedHeaders, expectedBody, defaultMaxTime) + + assert.NoError(t, err) + assert.Equal(t, expectedResponse.StatusCode, resp.StatusCode) + assert.Equal(t, 2, retryCount) + }) +} + +func readRequestBody(req *http.Request) []byte { + buf := new(bytes.Buffer) + buf.ReadFrom(req.Body) + return buf.Bytes() +} +func TestDefaultShouldRetry(t *testing.T) { + tests := []struct { + name string + response *http.Response + want bool + }{ + { + name: "StatusUnauthorized", + response: &http.Response{ + StatusCode: http.StatusUnauthorized, + }, + want: false, + }, + { + name: "StatusForbidden", + response: &http.Response{ + StatusCode: http.StatusForbidden, + }, + want: false, + }, + { + name: "StatusNotFound", + response: &http.Response{ + StatusCode: http.StatusNotFound, + }, + want: false, + }, + { + name: "StatusInternalServerError", + response: &http.Response{ + StatusCode: http.StatusInternalServerError, + }, + want: false, + }, + { + name: "OtherStatusCodes", + response: &http.Response{ + StatusCode: http.StatusOK, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := defaultShouldRetry(tt.response) + assert.Equal(t, tt.want, got) + }) + } +}