diff --git a/pinecone/client.go b/pinecone/client.go index a6573d4..73aefc5 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -114,17 +114,12 @@ func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { } defer res.Body.Close() - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - if res.StatusCode != http.StatusOK { - return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to list indexes: ") + return nil, handleErrorResponseBody(res, "failed to list indexes: ") } var indexList control.IndexList - err = json.Unmarshal(resBodyBytes, &indexList) + err = json.NewDecoder(res.Body).Decode(&indexList) if err != nil { return nil, err } @@ -208,16 +203,11 @@ func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest) } defer res.Body.Close() - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - if res.StatusCode != http.StatusCreated { - return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to create index: ") + return nil, handleErrorResponseBody(res, "failed to create index: ") } - return decodeIndex(resBodyBytes) + return decodeIndex(res.Body) } type CreateServerlessIndexRequest struct { @@ -248,16 +238,11 @@ func (c *Client) CreateServerlessIndex(ctx context.Context, in *CreateServerless } defer res.Body.Close() - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - if res.StatusCode != http.StatusCreated { - return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to create index: ") + return nil, handleErrorResponseBody(res, "failed to create index: ") } - return decodeIndex(resBodyBytes) + return decodeIndex(res.Body) } func (c *Client) DescribeIndex(ctx context.Context, idxName string) (*Index, error) { @@ -267,16 +252,11 @@ func (c *Client) DescribeIndex(ctx context.Context, idxName string) (*Index, err } defer res.Body.Close() - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - if res.StatusCode != http.StatusOK { - return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to describe index: ") + return nil, handleErrorResponseBody(res, "failed to describe index: ") } - return decodeIndex(resBodyBytes) + return decodeIndex(res.Body) } func (c *Client) DeleteIndex(ctx context.Context, idxName string) error { @@ -286,13 +266,8 @@ func (c *Client) DeleteIndex(ctx context.Context, idxName string) error { } defer res.Body.Close() - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - if res.StatusCode != http.StatusAccepted { - return handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to delete index: ") + return handleErrorResponseBody(res, "failed to delete index: ") } return nil @@ -305,17 +280,12 @@ func (c *Client) ListCollections(ctx context.Context) ([]*Collection, error) { } defer res.Body.Close() - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - if res.StatusCode != http.StatusOK { - return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to list collections: ") + return nil, handleErrorResponseBody(res, "failed to list collections: ") } var collectionsResponse control.CollectionList - if err := json.Unmarshal(resBodyBytes, &collectionsResponse); err != nil { + if err := json.NewDecoder(res.Body).Decode(&collectionsResponse); err != nil { return nil, err } @@ -334,16 +304,11 @@ func (c *Client) DescribeCollection(ctx context.Context, collectionName string) } defer res.Body.Close() - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return nil, err - } - if res.StatusCode != http.StatusOK { - return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to describe collection: ") + return nil, handleErrorResponseBody(res, "failed to describe collection: ") } - return decodeCollection(resBodyBytes) + return decodeCollection(res.Body) } type CreateCollectionRequest struct { @@ -363,16 +328,11 @@ func (c *Client) CreateCollection(ctx context.Context, in *CreateCollectionReque } defer res.Body.Close() - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - if res.StatusCode != http.StatusCreated { - return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to create collection: ") + return nil, handleErrorResponseBody(res, "failed to create collection: ") } - return decodeCollection(resBodyBytes) + return decodeCollection(res.Body) } func (c *Client) DeleteCollection(ctx context.Context, collectionName string) error { @@ -382,13 +342,8 @@ func (c *Client) DeleteCollection(ctx context.Context, collectionName string) er } defer res.Body.Close() - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - if res.StatusCode != http.StatusAccepted { - return handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to delete collection: ") + return handleErrorResponseBody(res, "failed to delete collection: ") } return nil @@ -451,9 +406,9 @@ func toIndex(idx *control.IndexModel) *Index { } } -func decodeIndex(resBodyBytes []byte) (*Index, error) { +func decodeIndex(resBody io.ReadCloser) (*Index, error) { var idx control.IndexModel - err := json.Unmarshal(resBodyBytes, &idx) + err := json.NewDecoder(resBody).Decode(&idx) if err != nil { return nil, fmt.Errorf("failed to decode idx response: %w", err) } @@ -476,9 +431,9 @@ func toCollection(cm *control.CollectionModel) *Collection { } } -func decodeCollection(resBodyBytes []byte) (*Collection, error) { +func decodeCollection(resBody io.ReadCloser) (*Collection, error) { var collectionModel control.CollectionModel - err := json.Unmarshal(resBodyBytes, &collectionModel) + err := json.NewDecoder(resBody).Decode(&collectionModel) if err != nil { return nil, fmt.Errorf("failed to decode collection response: %w", err) } @@ -508,10 +463,14 @@ type errorResponseMap struct { Details string `json:"details,omitempty"` } -func handleErrorResponseBody(resBodyBytes []byte, statusCode int, errMsgPrefix string) error { - var errMap errorResponseMap +func handleErrorResponseBody(response *http.Response, errMsgPrefix string) error { + resBodyBytes, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } - errMap.StatusCode = statusCode + var errMap errorResponseMap + errMap.StatusCode = response.StatusCode // try and decode ErrorResponse if json.Valid(resBodyBytes) { diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 2f72c9c..f8b2c94 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -3,6 +3,7 @@ package pinecone import ( "context" "fmt" + "io" "net/http" "os" "reflect" @@ -32,29 +33,29 @@ func TestClient(t *testing.T) { func TestHandleErrorResponseBody(t *testing.T) { tests := []struct { name string - responseBody string + responseBody *http.Response statusCode int prefix string errorOutput string }{ { name: "test ErrorResponse body", - responseBody: `{"error": { "code": "INVALID_ARGUMENT", "message": "test error message"}, "status": 400}`, + responseBody: mockResponse(`{"error": { "code": "INVALID_ARGUMENT", "message": "test error message"}, "status": 400}`, http.StatusBadRequest), statusCode: http.StatusBadRequest, errorOutput: `{"status_code":400,"body":"{\"error\": { \"code\": \"INVALID_ARGUMENT\", \"message\": \"test error message\"}, \"status\": 400}","error_code":"INVALID_ARGUMENT","message":"test error message"}`, }, { name: "test JSON body", - responseBody: `{"message": "test error message", "extraCode": 665}`, + responseBody: mockResponse(`{"message": "test error message", "extraCode": 665}`, http.StatusBadRequest), statusCode: http.StatusBadRequest, errorOutput: `{"status_code":400,"body":"{\"message\": \"test error message\", \"extraCode\": 665}"}`, }, { name: "test string body", - responseBody: `test error message`, + responseBody: mockResponse(`test error message`, http.StatusBadRequest), statusCode: http.StatusBadRequest, errorOutput: `{"status_code":400,"body":"test error message"}`, }, { name: "Test error response with empty response", - responseBody: `{}`, + responseBody: mockResponse(`{}`, http.StatusBadRequest), statusCode: http.StatusBadRequest, prefix: "test prefix", errorOutput: `{"status_code":400,"body":"{}"}`, @@ -63,7 +64,7 @@ func TestHandleErrorResponseBody(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := handleErrorResponseBody([]byte(tt.responseBody), tt.statusCode, tt.prefix) + err := handleErrorResponseBody(tt.responseBody, tt.prefix) assert.Equal(t, err.Error(), tt.errorOutput, "Expected error to be '%s', but got '%s'", tt.errorOutput, err.Error()) }) @@ -538,3 +539,12 @@ func deleteUUIDNamedResources(ctx context.Context, c *Client) error { return nil } + +func mockResponse(body string, statusCode int) *http.Response { + return &http.Response{ + Status: http.StatusText(statusCode), + StatusCode: statusCode, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + } +}