Skip to content

Commit

Permalink
Adding retries to httpPost
Browse files Browse the repository at this point in the history
  • Loading branch information
David Wertenteil committed Feb 26, 2024
1 parent 25556bc commit cef973e
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:

uses: actions/setup-go@v4
with:
go-version: "1.20"
go-version: "1.21"

- name: Get dependencies
run: go mod download
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: 1.19
go-version: 1.21

- name: Test & coverage
id: unit-test
Expand Down
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
module github.com/armosec/utils-go

go 1.19
go 1.21

require github.com/stretchr/testify v1.8.4
require (
github.com/cenkalti/backoff v2.2.1+incompatible
github.com/stretchr/testify v1.8.4
)

require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
64 changes: 56 additions & 8 deletions httputils/httphelpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ import (
"net/http"
"strings"
"sync"
"time"

"github.com/cenkalti/backoff"
)

type IHttpClient interface {
Do(req *http.Request) (*http.Response, error)
}

type ShouldNotRetryFunc func(resp *http.Response) bool

// JSONDecoder returns JSON decoder for given string
func JSONDecoder(origin string) *json.Decoder {
dec := json.NewDecoder(strings.NewReader(origin))
Expand Down Expand Up @@ -65,21 +70,64 @@ func HttpGetWithContext(ctx context.Context, httpClient IHttpClient, fullURL str
}

func HttpPost(httpClient IHttpClient, fullURL string, headers map[string]string, body []byte) (*http.Response, error) {
return HttpPostWithContext(context.Background(), httpClient, fullURL, headers, body)
return HttpPostWithContext(context.Background(), httpClient, fullURL, headers, body, -1, func(resp *http.Response) bool {
return true
})
}

func HttpPostWithContext(ctx context.Context, httpClient IHttpClient, fullURL string, headers map[string]string, body []byte) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(body))
if err != nil {
return nil, err
func HttpPostWithRetry(httpClient IHttpClient, fullURL string, headers map[string]string, body []byte, maxElapsedTime time.Duration) (*http.Response, error) {
return HttpPostWithContext(context.Background(), httpClient, fullURL, headers, body, maxElapsedTime, defaultShouldRetry)
}

func HttpPostWithContext(ctx context.Context, httpClient IHttpClient, fullURL string, headers map[string]string, body []byte, maxElapsedTime time.Duration, shouldRetry func(resp *http.Response) bool) (*http.Response, error) {
var resp *http.Response
var err error

operation := func() error {
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(body))
if err != nil {
return backoff.Permanent(err)
}
setHeaders(req, headers)

resp, err = httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()

// If the status code is not 200, we will retry
if resp.StatusCode != http.StatusOK {
if shouldRetry(resp) {
return fmt.Errorf("received status code: %d", resp.StatusCode)
}
return backoff.Permanent(err)
}

return nil
}
setHeaders(req, headers)

return httpClient.Do(req)
// Create a new exponential backoff policy
expBackOff := backoff.NewExponentialBackOff()
expBackOff.MaxElapsedTime = maxElapsedTime // Set the maximum elapsed time

// Run the operation with the exponential backoff policy
if err = backoff.Retry(operation, expBackOff); err != nil {
return resp, err
}

return resp, nil
}
func defaultShouldRetry(resp *http.Response) bool {
// If received codes 401/403/404/500 should return false
return resp.StatusCode != http.StatusUnauthorized &&
resp.StatusCode != http.StatusForbidden &&
resp.StatusCode != http.StatusNotFound &&
resp.StatusCode != http.StatusInternalServerError
}

func setHeaders(req *http.Request, headers map[string]string) {
if headers != nil && len(headers) > 0 {
if len(headers) > 0 {
for k, v := range headers {
req.Header.Set(k, v)
}
Expand Down
195 changes: 195 additions & 0 deletions httputils/httphelpers_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package httputils

import (
"bytes"
"fmt"
"net/http"
"reflect"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -296,3 +298,196 @@ func TestSplit2Chunks(t *testing.T) {
}

}

type mockHttpClient struct {
doFunc func(req *http.Request) (*http.Response, error)
}

func (m *mockHttpClient) Do(req *http.Request) (*http.Response, error) {
return m.doFunc(req)
}

func TestHttpPostWithContext(t *testing.T) {
defaultMaxTime := 5 * time.Second

t.Run("Successful request", func(t *testing.T) {
expectedURL := "http://example.com"
expectedBody := []byte("test body")
headers := map[string]string{
"Content-Type": "application/json",
}

expectedResponse := &http.Response{
StatusCode: http.StatusOK,
Body: http.NoBody,
}

httpClient := &mockHttpClient{
doFunc: func(req *http.Request) (*http.Response, error) {
assert.Equal(t, expectedURL, req.URL.String())
assert.Equal(t, "POST", req.Method)
assert.Equal(t, expectedBody, readRequestBody(req))

return expectedResponse, nil
},
}

resp, err := HttpPost(httpClient, expectedURL, headers, expectedBody)

assert.NoError(t, err)
assert.Equal(t, expectedResponse.StatusCode, resp.StatusCode)
})

t.Run("Permanent error", func(t *testing.T) {
expectedURL := "http://example.com"
expectedBody := []byte("test body")
expectedHeaders := map[string]string{
"Content-Type": "application/json",
}
expectedError := fmt.Errorf("permanent error")

httpClient := &mockHttpClient{
doFunc: func(req *http.Request) (*http.Response, error) {
return nil, expectedError
},
}

resp, err := HttpPost(httpClient, expectedURL, expectedHeaders, expectedBody)

assert.Equal(t, expectedError, err)
assert.Nil(t, resp)
})

t.Run("Non-retryable error", func(t *testing.T) {
expectedURL := "http://example.com"
expectedBody := []byte("test body")
expectedHeaders := map[string]string{
"Content-Type": "application/json",
}
expectedResponse := &http.Response{
StatusCode: http.StatusInternalServerError,
Body: http.NoBody,
}

httpClient := &mockHttpClient{
doFunc: func(req *http.Request) (*http.Response, error) {
return expectedResponse, nil
},
}

resp, err := HttpPostWithRetry(httpClient, expectedURL, expectedHeaders, expectedBody, defaultMaxTime)

assert.Equal(t, nil, err)
assert.Equal(t, expectedResponse.StatusCode, resp.StatusCode)
})

t.Run("Retryable error", func(t *testing.T) {
expectedURL := "http://example.com"
expectedBody := []byte("test body")
expectedHeaders := map[string]string{
"Content-Type": "application/json",
}
expectedResponse := &http.Response{
StatusCode: http.StatusBadGateway,
Body: http.NoBody,
}
expectedError := fmt.Errorf("received status code: %d", expectedResponse.StatusCode)

httpClient := &mockHttpClient{
doFunc: func(req *http.Request) (*http.Response, error) {
return expectedResponse, nil
},
}

resp, err := HttpPostWithRetry(httpClient, expectedURL, expectedHeaders, expectedBody, defaultMaxTime)

assert.Equal(t, expectedError, err)
assert.Equal(t, expectedResponse.StatusCode, resp.StatusCode)
})

t.Run("Retryable error with successful retry", func(t *testing.T) {
expectedURL := "http://example.com"
expectedBody := []byte("test body")
expectedHeaders := map[string]string{
"Content-Type": "application/json",
}
expectedResponse := &http.Response{
StatusCode: http.StatusBadGateway,
Body: http.NoBody,
}

retryCount := 0
httpClient := &mockHttpClient{
doFunc: func(req *http.Request) (*http.Response, error) {
retryCount++
if retryCount == 1 {
return expectedResponse, nil
}
expectedResponse.StatusCode = http.StatusOK
return expectedResponse, nil
},
}

resp, err := HttpPostWithRetry(httpClient, expectedURL, expectedHeaders, expectedBody, defaultMaxTime)

assert.NoError(t, err)
assert.Equal(t, expectedResponse.StatusCode, resp.StatusCode)
assert.Equal(t, 2, retryCount)
})
}

func readRequestBody(req *http.Request) []byte {
buf := new(bytes.Buffer)
buf.ReadFrom(req.Body)
return buf.Bytes()
}
func TestDefaultShouldRetry(t *testing.T) {
tests := []struct {
name string
response *http.Response
want bool
}{
{
name: "StatusUnauthorized",
response: &http.Response{
StatusCode: http.StatusUnauthorized,
},
want: false,
},
{
name: "StatusForbidden",
response: &http.Response{
StatusCode: http.StatusForbidden,
},
want: false,
},
{
name: "StatusNotFound",
response: &http.Response{
StatusCode: http.StatusNotFound,
},
want: false,
},
{
name: "StatusInternalServerError",
response: &http.Response{
StatusCode: http.StatusInternalServerError,
},
want: false,
},
{
name: "OtherStatusCodes",
response: &http.Response{
StatusCode: http.StatusOK,
},
want: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := defaultShouldRetry(tt.response)
assert.Equal(t, tt.want, got)
})
}
}

0 comments on commit cef973e

Please sign in to comment.