Skip to content

Commit

Permalink
fix: Improve Retry Logic to Only Retry on Server-Side HTTP Errors (#1390
Browse files Browse the repository at this point in the history
)

**Changes Implemented**

Fixes #861 

Selective Retrying in osv.go:

**Before:** The retry logic did not differentiate between server-side
and client-side HTTP errors, potentially leading to unnecessary retries
on HTTP 4xx responses.
**After:** Updated the retry mechanism to only retry when the response
status code is in the 500 range (HTTP 5xx). This prevents the system
from retrying requests that are likely to fail due to client-side
issues, thereby optimizing performance and reducing redundant network
calls.

**osv_test.go:**
Verified that the updated retry logic correctly differentiates between
HTTP 5xx and HTTP 4xx responses.
Ensured that retries are only attempted for HTTP 5xx errors by running
and passing the TestRetryOn5xx test case.


![image](https://github.com/user-attachments/assets/925de25c-b3fc-4daf-9571-b0d9da535f41)

---------

Co-authored-by: Rex P <[email protected]>
  • Loading branch information
VishalGawade1 and another-rex authored Dec 23, 2024
1 parent 9d28c7f commit 98f4319
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 26 deletions.
55 changes: 29 additions & 26 deletions pkg/osv/osv.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,21 +146,6 @@ func chunkBy[T any](items []T, chunkSize int) [][]T {
return append(chunks, items)
}

// checkResponseError checks if the response has an error.
func checkResponseError(resp *http.Response) error {
if resp.StatusCode == http.StatusOK {
return nil
}

respBuf, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read error response from server: %w", err)
}
defer resp.Body.Close()

return fmt.Errorf("server response error: %s", string(respBuf))
}

// MakeRequest sends a batched query to osv.dev
func MakeRequest(request BatchedQuery) (*BatchedResponse, error) {
return MakeRequestWithClient(request, http.DefaultClient)
Expand Down Expand Up @@ -319,10 +304,9 @@ func HydrateWithClient(resp *BatchedResponse, client *http.Client) (*HydratedBat
return &hydrated, nil
}

// makeRetryRequest will return an error on both network errors, and if the response is not 200
// makeRetryRequest executes HTTP requests with exponential backoff retry logic
func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, error) {
var resp *http.Response
var err error
var lastErr error

for i := range maxRetryAttempts {
// rand is initialized with a random number (since go1.20), and is also safe to use concurrently
Expand All @@ -331,17 +315,36 @@ func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, er
jitterAmount := (rand.Float64() * float64(jitterMultiplier) * float64(i))
time.Sleep(time.Duration(i*i)*time.Second + time.Duration(jitterAmount*1000)*time.Millisecond)

resp, err = action()
if err == nil {
// Check the response for HTTP errors
err = checkResponseError(resp)
if err == nil {
break
}
resp, err := action()
if err != nil {
lastErr = fmt.Errorf("attempt %d: request failed: %w", i+1, err)
continue
}

if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return resp, nil
}

body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("attempt %d: failed to read response: %w", i+1, err)
continue
}

if resp.StatusCode == http.StatusTooManyRequests {
lastErr = fmt.Errorf("attempt %d: too many requests: status=%d body=%s", i+1, resp.StatusCode, body)
continue
}

if resp.StatusCode >= 400 && resp.StatusCode < 500 {
return nil, fmt.Errorf("client error: status=%d body=%s", resp.StatusCode, body)
}

lastErr = fmt.Errorf("server error: status=%d body=%s", resp.StatusCode, body)
}

return resp, err
return nil, fmt.Errorf("max retries exceeded: %w", lastErr)
}

func MakeDetermineVersionRequest(name string, hashes []DetermineVersionHash) (*DetermineVersionResponse, error) {
Expand Down
111 changes: 111 additions & 0 deletions pkg/osv/osv_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package osv

import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/google/osv-scanner/internal/testutility"
)

func TestMakeRetryRequest(t *testing.T) {
t.Parallel()
testutility.Skip(t, "This test takes a long time (14+ seconds)")

tests := []struct {
name string
statusCodes []int
expectedError string
wantAttempts int
}{
{
name: "success on first attempt",
statusCodes: []int{http.StatusOK},
wantAttempts: 1,
},
{
name: "client error no retry",
statusCodes: []int{http.StatusBadRequest},
expectedError: "client error: status=400",
wantAttempts: 1,
},
{
name: "server error then success",
statusCodes: []int{http.StatusInternalServerError, http.StatusOK},
wantAttempts: 2,
},
{
name: "max retries on server error",
statusCodes: []int{http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError},
expectedError: "max retries exceeded",
wantAttempts: 4,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

attempts := 0
idx := 0

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
attempts++
status := tt.statusCodes[idx]
if idx < len(tt.statusCodes)-1 {
idx++
}

w.WriteHeader(status)
message := fmt.Sprintf("response-%d", attempts)
_, _ = w.Write([]byte(message))
}))
defer server.Close()

client := &http.Client{Timeout: time.Second}

resp, err := makeRetryRequest(func() (*http.Response, error) {
//nolint:noctx
return client.Get(server.URL)
})

if attempts != tt.wantAttempts {
t.Errorf("got %d attempts, want %d", attempts, tt.wantAttempts)
}

if tt.expectedError != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.expectedError)
}
if !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %q", tt.expectedError, err)
}

return
}

if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if resp == nil {
t.Fatal("expected non-nil response")
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}

expectedBody := fmt.Sprintf("response-%d", attempts)
if string(body) != expectedBody {
t.Errorf("got body %q, want %q", string(body), expectedBody)
}
})
}
}

0 comments on commit 98f4319

Please sign in to comment.