-
Notifications
You must be signed in to change notification settings - Fork 371
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: Improve Retry Logic to Only Retry on Server-Side HTTP Errors (#1390
) **Changes Implemented** Fixes #861 Selective Retrying in osv.go: **Before:** The retry logic did not differentiate between server-side and client-side HTTP errors, potentially leading to unnecessary retries on HTTP 4xx responses. **After:** Updated the retry mechanism to only retry when the response status code is in the 500 range (HTTP 5xx). This prevents the system from retrying requests that are likely to fail due to client-side issues, thereby optimizing performance and reducing redundant network calls. **osv_test.go:** Verified that the updated retry logic correctly differentiates between HTTP 5xx and HTTP 4xx responses. Ensured that retries are only attempted for HTTP 5xx errors by running and passing the TestRetryOn5xx test case. ![image](https://github.com/user-attachments/assets/925de25c-b3fc-4daf-9571-b0d9da535f41) --------- Co-authored-by: Rex P <[email protected]>
- Loading branch information
1 parent
9d28c7f
commit 98f4319
Showing
2 changed files
with
140 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
package osv | ||
|
||
import ( | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"net/http/httptest" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/google/osv-scanner/internal/testutility" | ||
) | ||
|
||
func TestMakeRetryRequest(t *testing.T) { | ||
t.Parallel() | ||
testutility.Skip(t, "This test takes a long time (14+ seconds)") | ||
|
||
tests := []struct { | ||
name string | ||
statusCodes []int | ||
expectedError string | ||
wantAttempts int | ||
}{ | ||
{ | ||
name: "success on first attempt", | ||
statusCodes: []int{http.StatusOK}, | ||
wantAttempts: 1, | ||
}, | ||
{ | ||
name: "client error no retry", | ||
statusCodes: []int{http.StatusBadRequest}, | ||
expectedError: "client error: status=400", | ||
wantAttempts: 1, | ||
}, | ||
{ | ||
name: "server error then success", | ||
statusCodes: []int{http.StatusInternalServerError, http.StatusOK}, | ||
wantAttempts: 2, | ||
}, | ||
{ | ||
name: "max retries on server error", | ||
statusCodes: []int{http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError}, | ||
expectedError: "max retries exceeded", | ||
wantAttempts: 4, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
attempts := 0 | ||
idx := 0 | ||
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { | ||
attempts++ | ||
status := tt.statusCodes[idx] | ||
if idx < len(tt.statusCodes)-1 { | ||
idx++ | ||
} | ||
|
||
w.WriteHeader(status) | ||
message := fmt.Sprintf("response-%d", attempts) | ||
_, _ = w.Write([]byte(message)) | ||
})) | ||
defer server.Close() | ||
|
||
client := &http.Client{Timeout: time.Second} | ||
|
||
resp, err := makeRetryRequest(func() (*http.Response, error) { | ||
//nolint:noctx | ||
return client.Get(server.URL) | ||
}) | ||
|
||
if attempts != tt.wantAttempts { | ||
t.Errorf("got %d attempts, want %d", attempts, tt.wantAttempts) | ||
} | ||
|
||
if tt.expectedError != "" { | ||
if err == nil { | ||
t.Fatalf("expected error containing %q, got nil", tt.expectedError) | ||
} | ||
if !strings.Contains(err.Error(), tt.expectedError) { | ||
t.Errorf("expected error containing %q, got %q", tt.expectedError, err) | ||
} | ||
|
||
return | ||
} | ||
|
||
if err != nil { | ||
t.Fatalf("unexpected error: %v", err) | ||
} | ||
|
||
if resp == nil { | ||
t.Fatal("expected non-nil response") | ||
} | ||
defer resp.Body.Close() | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
t.Fatalf("failed to read response body: %v", err) | ||
} | ||
|
||
expectedBody := fmt.Sprintf("response-%d", attempts) | ||
if string(body) != expectedBody { | ||
t.Errorf("got body %q, want %q", string(body), expectedBody) | ||
} | ||
}) | ||
} | ||
} |