diff --git a/pinecone/client.go b/pinecone/client.go index 2461344..a6573d4 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -107,24 +107,6 @@ func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, addi return idx, nil } -func (c *Client) extractAuthHeader() map[string]string { - possibleAuthKeys := []string{ - "api-key", - "authorization", - "access_token", - } - - for key, value := range c.headers { - for _, checkKey := range possibleAuthKeys { - if strings.ToLower(key) == checkKey { - return map[string]string{key: value} - } - } - } - - return nil -} - func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { res, err := c.restClient.ListIndexes(ctx) if err != nil { @@ -132,12 +114,17 @@ func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { } defer res.Body.Close() + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) + return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to list indexes: ") } var indexList control.IndexList - err = json.NewDecoder(res.Body).Decode(&indexList) + err = json.Unmarshal(resBodyBytes, &indexList) if err != nil { return nil, err } @@ -221,16 +208,16 @@ func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest) } defer res.Body.Close() + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if res.StatusCode != http.StatusCreated { - var errResp control.ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errResp) - if err != nil { - return nil, fmt.Errorf("failed to decode error response: %w", err) - } - return nil, fmt.Errorf("failed to create index: %s", errResp.Error.Message) + return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to create index: ") } - return decodeIndex(res.Body) + return decodeIndex(resBodyBytes) } type CreateServerlessIndexRequest struct { @@ -261,16 +248,16 @@ func (c *Client) CreateServerlessIndex(ctx context.Context, in *CreateServerless } defer res.Body.Close() + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if res.StatusCode != http.StatusCreated { - var errResp control.ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errResp) - if err != nil { - return nil, fmt.Errorf("failed to decode error response: %w", err) - } - return nil, fmt.Errorf("failed to create index: %s", errResp.Error.Message) + return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to create index: ") } - return decodeIndex(res.Body) + return decodeIndex(resBodyBytes) } func (c *Client) DescribeIndex(ctx context.Context, idxName string) (*Index, error) { @@ -280,16 +267,16 @@ func (c *Client) DescribeIndex(ctx context.Context, idxName string) (*Index, err } defer res.Body.Close() + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if res.StatusCode != http.StatusOK { - var errResp control.ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errResp) - if err != nil { - return nil, fmt.Errorf("failed to decode error response: %w", err) - } - return nil, fmt.Errorf("failed to describe idx: %s", errResp.Error.Message) + return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to describe index: ") } - return decodeIndex(res.Body) + return decodeIndex(resBodyBytes) } func (c *Client) DeleteIndex(ctx context.Context, idxName string) error { @@ -299,13 +286,13 @@ func (c *Client) DeleteIndex(ctx context.Context, idxName string) error { } defer res.Body.Close() + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + if res.StatusCode != http.StatusAccepted { - var errResp control.ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errResp) - if err != nil { - return fmt.Errorf("failed to decode error response: %w", err) - } - return fmt.Errorf("failed to delete index: %s", errResp.Error.Message) + return handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to delete index: ") } return nil @@ -318,12 +305,17 @@ func (c *Client) ListCollections(ctx context.Context) ([]*Collection, error) { } defer res.Body.Close() + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) + return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to list collections: ") } var collectionsResponse control.CollectionList - if err := json.NewDecoder(res.Body).Decode(&collectionsResponse); err != nil { + if err := json.Unmarshal(resBodyBytes, &collectionsResponse); err != nil { return nil, err } @@ -342,11 +334,16 @@ func (c *Client) DescribeCollection(ctx context.Context, collectionName string) } defer res.Body.Close() + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) + return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to describe collection: ") } - return decodeCollection(res.Body) + return decodeCollection(resBodyBytes) } type CreateCollectionRequest struct { @@ -366,17 +363,16 @@ func (c *Client) CreateCollection(ctx context.Context, in *CreateCollectionReque } defer res.Body.Close() - if res.StatusCode != http.StatusCreated { - var errorResponse control.ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errorResponse) - if err != nil { - return nil, err - } + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } - return nil, fmt.Errorf("failed to create collection: %s", errorResponse.Error.Message) + if res.StatusCode != http.StatusCreated { + return nil, handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to create collection: ") } - return decodeCollection(res.Body) + return decodeCollection(resBodyBytes) } func (c *Client) DeleteCollection(ctx context.Context, collectionName string) error { @@ -386,14 +382,31 @@ func (c *Client) DeleteCollection(ctx context.Context, collectionName string) er } defer res.Body.Close() - // Check for successful response, consider successful HTTP codes like 200 or 204 as successful deletion + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + if res.StatusCode != http.StatusAccepted { - var errResp control.ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errResp) - if err != nil { - return fmt.Errorf("failed to decode error response: %w", err) + return handleErrorResponseBody(resBodyBytes, res.StatusCode, "failed to delete collection: ") + } + + return nil +} + +func (c *Client) extractAuthHeader() map[string]string { + possibleAuthKeys := []string{ + "api-key", + "authorization", + "access_token", + } + + for key, value := range c.headers { + for _, checkKey := range possibleAuthKeys { + if strings.ToLower(key) == checkKey { + return map[string]string{key: value} + } } - return fmt.Errorf("failed to delete collection '%s': %s", collectionName, errResp.Error.Message) } return nil @@ -438,9 +451,9 @@ func toIndex(idx *control.IndexModel) *Index { } } -func decodeIndex(resBody io.ReadCloser) (*Index, error) { +func decodeIndex(resBodyBytes []byte) (*Index, error) { var idx control.IndexModel - err := json.NewDecoder(resBody).Decode(&idx) + err := json.Unmarshal(resBodyBytes, &idx) if err != nil { return nil, fmt.Errorf("failed to decode idx response: %w", err) } @@ -463,9 +476,9 @@ func toCollection(cm *control.CollectionModel) *Collection { } } -func decodeCollection(resBody io.ReadCloser) (*Collection, error) { +func decodeCollection(resBodyBytes []byte) (*Collection, error) { var collectionModel control.CollectionModel - err := json.NewDecoder(resBody).Decode(&collectionModel) + err := json.Unmarshal(resBodyBytes, &collectionModel) if err != nil { return nil, fmt.Errorf("failed to decode collection response: %w", err) } @@ -473,6 +486,65 @@ func decodeCollection(resBody io.ReadCloser) (*Collection, error) { return toCollection(&collectionModel), nil } +func decodeErrorResponse(resBodyBytes []byte) (*control.ErrorResponse, error) { + var errorResponse control.ErrorResponse + err := json.Unmarshal(resBodyBytes, &errorResponse) + if err != nil { + return nil, fmt.Errorf("failed to decode error response: %w", err) + } + + if errorResponse.Status == 0 { + return nil, fmt.Errorf("unable to parse ErrorResponse: %v", string(resBodyBytes)) + } + + return &errorResponse, nil +} + +type errorResponseMap struct { + StatusCode int `json:"status_code"` + Body string `json:"body,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + Message string `json:"message,omitempty"` + Details string `json:"details,omitempty"` +} + +func handleErrorResponseBody(resBodyBytes []byte, statusCode int, errMsgPrefix string) error { + var errMap errorResponseMap + + errMap.StatusCode = statusCode + + // try and decode ErrorResponse + if json.Valid(resBodyBytes) { + errorResponse, err := decodeErrorResponse(resBodyBytes) + if err == nil { + errMap.Message = errorResponse.Error.Message + errMap.ErrorCode = string(errorResponse.Error.Code) + + if errorResponse.Error.Details != nil { + errMap.Details = fmt.Sprintf("%+v", errorResponse.Error.Details) + } + } + } + + errMap.Body = string(resBodyBytes) + + if errMap.Message != "" { + errMap.Message = errMsgPrefix + errMap.Message + } + + return formatError(errMap) +} + +func formatError(errMap errorResponseMap) error { + jsonString, err := json.Marshal(errMap) + if err != nil { + return err + } + baseError := fmt.Errorf(string(jsonString)) + + return &PineconeError{Code: errMap.StatusCode, Msg: baseError} +} + func buildClientBaseOptions(in NewClientBaseParams) []control.ClientOption { clientOptions := []control.ClientOption{} @@ -546,3 +618,8 @@ func minOne(x int32) int32 { } return x } + +func PrettifyStruct(obj interface{}) string { + bytes, _ := json.MarshalIndent(obj, "", " ") + return string(bytes) +} diff --git a/pinecone/client_test.go b/pinecone/client_test.go index c5533d1..2f72c9c 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -3,7 +3,9 @@ package pinecone import ( "context" "fmt" + "net/http" "os" + "reflect" "strings" "testing" @@ -27,6 +29,47 @@ func TestClient(t *testing.T) { suite.Run(t, new(ClientTests)) } +func TestHandleErrorResponseBody(t *testing.T) { + tests := []struct { + name string + responseBody string + statusCode int + prefix string + errorOutput string + }{ + { + name: "test ErrorResponse body", + responseBody: `{"error": { "code": "INVALID_ARGUMENT", "message": "test error message"}, "status": 400}`, + statusCode: http.StatusBadRequest, + errorOutput: `{"status_code":400,"body":"{\"error\": { \"code\": \"INVALID_ARGUMENT\", \"message\": \"test error message\"}, \"status\": 400}","error_code":"INVALID_ARGUMENT","message":"test error message"}`, + }, { + name: "test JSON body", + responseBody: `{"message": "test error message", "extraCode": 665}`, + statusCode: http.StatusBadRequest, + errorOutput: `{"status_code":400,"body":"{\"message\": \"test error message\", \"extraCode\": 665}"}`, + }, { + name: "test string body", + responseBody: `test error message`, + statusCode: http.StatusBadRequest, + errorOutput: `{"status_code":400,"body":"test error message"}`, + }, { + name: "Test error response with empty response", + responseBody: `{}`, + statusCode: http.StatusBadRequest, + prefix: "test prefix", + errorOutput: `{"status_code":400,"body":"{}"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := handleErrorResponseBody([]byte(tt.responseBody), tt.statusCode, tt.prefix) + assert.Equal(t, err.Error(), tt.errorOutput, "Expected error to be '%s', but got '%s'", tt.errorOutput, err.Error()) + + }) + } +} + func (ts *ClientTests) SetupSuite() { apiKey := os.Getenv("PINECONE_API_KEY") require.NotEmpty(ts.T(), apiKey, "PINECONE_API_KEY env variable not set") @@ -51,7 +94,6 @@ func (ts *ClientTests) SetupSuite() { // named a UUID. Generally not needed as all tests are cleaning up after themselves // Left here as a convenience during active development. //deleteUUIDNamedResources(context.Background(), &ts.client) - } func (ts *ClientTests) TestNewClientParamsSet() { @@ -285,6 +327,34 @@ func (ts *ClientTests) TestCreatePodIndex() { require.Equal(ts.T(), name, idx.Name, "Index name does not match") } +func (ts *ClientTests) TestCreatePodIndexInvalidDimension() { + name := uuid.New().String() + + _, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{ + Name: name, + Dimension: -1, + Metric: Cosine, + Environment: "us-east1-gcp", + PodType: "p1.x1", + }) + require.Error(ts.T(), err) + require.Equal(ts.T(), reflect.TypeOf(err), reflect.TypeOf(&PineconeError{}), "Expected error to be of type PineconeError") +} + +func (ts *ClientTests) TestCreateServerlessIndexInvalidDimension() { + name := uuid.New().String() + + _, err := ts.client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{ + Name: name, + Dimension: -1, + Metric: Cosine, + Cloud: Aws, + Region: "us-west-2", + }) + require.Error(ts.T(), err) + require.Equal(ts.T(), reflect.TypeOf(err), reflect.TypeOf(&PineconeError{}), "Expected error to be of type PineconeError") +} + func (ts *ClientTests) TestCreateServerlessIndex() { name := uuid.New().String() @@ -310,6 +380,12 @@ func (ts *ClientTests) TestDescribeServerlessIndex() { require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match") } +func (ts *ClientTests) TestDescribeNonExistentIndex() { + _, err := ts.client.DescribeIndex(context.Background(), "non-existent-index") + require.Error(ts.T(), err) + require.Equal(ts.T(), reflect.TypeOf(err), reflect.TypeOf(&PineconeError{}), "Expected error to be of type PineconeError") +} + func (ts *ClientTests) TestDescribeServerlessIndexSourceTag() { index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.serverlessIndex) require.NoError(ts.T(), err) diff --git a/pinecone/errors.go b/pinecone/errors.go new file mode 100644 index 0000000..980dd0c --- /dev/null +++ b/pinecone/errors.go @@ -0,0 +1,12 @@ +package pinecone + +import "fmt" + +type PineconeError struct { + Code int + Msg error +} + +func (pe *PineconeError) Error() string { + return fmt.Sprintf("%+v", pe.Msg) +}