From d173edc4c6cc5670ac83465a435e1c8534f89ca4 Mon Sep 17 00:00:00 2001 From: Audrey Sage Lorberfeld Date: Fri, 2 Aug 2024 11:05:05 -0700 Subject: [PATCH] Add error handling for missing required fields in passed structs (#55) ## Problem Currently, users can pass structs that are missing fields that are actually required on the backend. ## Solution - Explicitly note required and optional fields in doc comments - Add error handling for when structs are passed that are missing required fields - Add unit tests that confirm an error is thrown when structs are missing required fields ## Type of Change - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here) ## Test Plan CI passes --- - To see the specific tasks where the Asana app for GitHub is being used, see below: - https://app.asana.com/0/0/1207649770516337 --- pinecone/client.go | 67 ++++++++++++++++++------------- pinecone/client_test.go | 32 ++++++++++++--- pinecone/index_connection.go | 36 +++++++++-------- pinecone/index_connection_test.go | 8 ++++ 4 files changed, 94 insertions(+), 49 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 15bfabb..a1e1b74 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -79,7 +79,7 @@ type Client struct { // - ApiKey: (Required) The API key used to authenticate with the Pinecone control plane API. // This value must be passed by the user unless it is set as an environment variable ("PINECONE_API_KEY"). // - Headers: An optional map of additional HTTP headers to include in each API request to the control plane. -// - Host: The host URL of the Pinecone control plane API. 
If not provided, +// - Host: (Optional) The host URL of the Pinecone control plane API. If not provided, // the default value is "https://api.pinecone.io". // - RestClient: An optional HTTP client to use for communication with the control plane API. // - SourceTag: An optional string used to help Pinecone attribute API activity. @@ -99,7 +99,7 @@ type NewClientParams struct { // Fields: // - Headers: An optional map of additional HTTP headers to include in each API request to the control plane. // "Authorization" and "X-Project-Id" headers are required if authenticating using a JWT. -// - Host: The host URL of the Pinecone control plane API. If not provided, +// - Host: (Optional) The host URL of the Pinecone control plane API. If not provided, // the default value is "https://api.pinecone.io". // - RestClient: An optional *http.Client object to use for communication with the control plane API. // - SourceTag: An optional string used to help Pinecone attribute API activity. @@ -115,10 +115,10 @@ type NewClientBaseParams struct { // NewIndexConnParams holds the parameters for creating an IndexConnection to a Pinecone index. // // Fields: -// - Host: The host URL of the Pinecone index. This is required. To find your host url use the DescribeIndex or ListIndexes methods. +// - Host: (Required) The host URL of the Pinecone index. To find your host url use the DescribeIndex or ListIndexes methods. // Alternatively, the host is displayed in the Pinecone web console. // - Namespace: Optional index namespace to use for operations. If not provided, the default namespace of "" will be used. -// - AdditionalMetdata: Optional additional metdata to be sent with each RPC request. +// - AdditionalMetadata: Optional additional metadata to be sent with each RPC request. // // See Client.Index for code example. 
type NewIndexConnParams struct { @@ -155,7 +155,7 @@ type NewIndexConnParams struct { // } func NewClient(in NewClientParams) (*Client, error) { osApiKey := os.Getenv("PINECONE_API_KEY") - hasApiKey := (valueOrFallback(in.ApiKey, osApiKey) != "") + hasApiKey := valueOrFallback(in.ApiKey, osApiKey) != "" if !hasApiKey { return nil, fmt.Errorf("no API key provided, please pass an API key for authorization through NewClientParams or set the PINECONE_API_KEY environment variable") @@ -184,7 +184,7 @@ func NewClient(in NewClientParams) (*Client, error) { // Notes: // - It is important to handle the error returned by this function to ensure that the // control plane client has been created successfully before attempting to make API calls. -// - A Pinecone API key is not requried when using NewClientBase. +// - A Pinecone API key is not required when using NewClientBase. // // Returns a pointer to an initialized Client instance or an error. // @@ -280,6 +280,10 @@ func (c *Client) Index(in NewIndexConnParams, dialOpts ...grpc.DialOption) (*Ind in.AdditionalMetadata = make(map[string]string) } + if in.Host == "" { + return nil, fmt.Errorf("field Host is required to create an IndexConnection. Find your Host from calling DescribeIndex or via the Pinecone console") + } + // add api version header if not provided if _, ok := in.AdditionalMetadata["X-Pinecone-Api-Version"]; !ok { in.AdditionalMetadata["X-Pinecone-Api-Version"] = gen.PineconeApiVersion @@ -365,21 +369,21 @@ func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { // CreatePodIndexRequest holds the parameters for creating a new pods-based Index. // // Fields: -// - Name: The name of the Index. Resource name must be 1-45 characters long, +// - Name: (Required) The name of the Index. Resource name must be 1-45 characters long, // start and end with an alphanumeric character, // and consist only of lower case alphanumeric characters or '-'. 
-// - Dimension: The [dimensionality] of the vectors to be inserted in the Index. -// - Metric: The distance metric to be used for [similarity] search. You can use +// - Dimension: (Required) The [dimensionality] of the vectors to be inserted in the Index. +// - Metric: (Required) The distance metric to be used for [similarity] search. You can use // 'euclidean', 'cosine', or 'dotproduct'. -// - Environment: The [cloud environment] where the Index will be hosted. -// - PodType: The [type of pod] to use for the Index. One of `s1`, `p1`, or `p2` appended with `.` and +// - Environment: (Required) The [cloud environment] where the Index will be hosted. +// - PodType: (Required) The [type of pod] to use for the Index. One of `s1`, `p1`, or `p2` appended with `.` and // one of `x1`, `x2`, `x4`, or `x8`. -// - Shards: The number of shards to use for the Index (defaults to 1). +// - Shards: (Optional) The number of shards to use for the Index (defaults to 1). // Shards split your data across multiple pods, so you can fit more data into an Index. -// - Replicas: The number of [replicas] to use for the Index (defaults to 1). Replicas duplicate your Index. +// - Replicas: (Optional) The number of [replicas] to use for the Index (defaults to 1). Replicas duplicate your Index. // They provide higher availability and throughput. Replicas can be scaled up or down as your needs change. -// - SourceCollection: The name of the Collection to be used as the source for the Index. -// - MetadataConfig: The [metadata configuration] for the behavior of Pinecone's internal metadata Index. By +// - SourceCollection: (Optional) The name of the Collection to be used as the source for the Index. +// - MetadataConfig: (Optional) The [metadata configuration] for the behavior of Pinecone's internal metadata Index. By // default, all metadata is indexed; when `metadata_config` is present, // only specified metadata fields are indexed. 
These configurations are // only valid for use with pod-based Indexes. @@ -504,6 +508,10 @@ func (req CreatePodIndexRequest) TotalCount() int { // fmt.Printf("Successfully created pod index: %s", idx.Name) // } func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest) (*Index, error) { + if in.Name == "" || in.Dimension == 0 || in.Metric == "" || in.Environment == "" || in.PodType == "" { + return nil, fmt.Errorf("fields Name, Dimension, Metric, Environment, and Podtype must be included in CreatePodIndexRequest") + } + deletionProtection := pointerOrNil(control.DeletionProtection(in.DeletionProtection)) metric := pointerOrNil(control.CreateIndexRequestMetric(in.Metric)) @@ -549,14 +557,14 @@ func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest) // CreateServerlessIndexRequest holds the parameters for creating a new [Serverless] Index. // // Fields: -// - Name: The name of the Index. Resource name must be 1-45 characters long, +// - Name: (Required) The name of the Index. Resource name must be 1-45 characters long, // start and end with an alphanumeric character, // and consist only of lower case alphanumeric characters or '-'. -// - Dimension: The [dimensionality] of the vectors to be inserted in the Index. -// - Metric: The metric used to measure the [similarity] between vectors ('euclidean', 'cosine', or 'dotproduct'). -// - Cloud: The public [cloud provider] where you would like your Index hosted. +// - Dimension: (Required) The [dimensionality] of the vectors to be inserted in the Index. +// - Metric: (Required) The metric used to measure the [similarity] between vectors ('euclidean', 'cosine', or 'dotproduct'). +// - Cloud: (Required) The public [cloud provider] where you would like your Index hosted. // For serverless Indexes, you define only the cloud and region where the Index should be hosted. -// - Region: The [region] where you would like your Index to be created. 
+// - Region: (Required) The [region] where you would like your Index to be created. // - DeletionProtection: (Optional) determines whether [deletion protection] is "enabled" or "disabled" for the index. // When "enabled", the index cannot be deleted. Defaults to "disabled". // @@ -646,6 +654,10 @@ type CreateServerlessIndexRequest struct { // fmt.Printf("Successfully created serverless index: %s", idx.Name) // } func (c *Client) CreateServerlessIndex(ctx context.Context, in *CreateServerlessIndexRequest) (*Index, error) { + if in.Name == "" || in.Dimension == 0 || in.Metric == "" || in.Cloud == "" || in.Region == "" { + return nil, fmt.Errorf("fields Name, Dimension, Metric, Cloud, and Region must be included in CreateServerlessIndexRequest") + } + deletionProtection := pointerOrNil(control.DeletionProtection(in.DeletionProtection)) metric := pointerOrNil(control.CreateIndexRequestMetric(in.Metric)) @@ -770,7 +782,7 @@ func (c *Client) DeleteIndex(ctx context.Context, idxName string) error { // ConfigureIndexParams contains parameters for configuring an index. For both pod-based // and serverless indexes you can configure the DeletionProtection status for an index. // For pod-based indexes you can also configure the number of Replicas and the PodType. -// Each of the fields are optional, but at least one field must be set. +// Each of the fields is optional, but at least one field must be set. // See [scale a pods-based index] for more information. // // Fields: @@ -894,11 +906,6 @@ func (c *Client) ConfigureIndex(ctx context.Context, name string, in ConfigureIn return decodeIndex(res.Body) } -func PrettifyStruct(obj interface{}) string { - bytes, _ := json.MarshalIndent(obj, "", " ") - return string(bytes) -} - // ListCollections retrieves a list of all Collections in a Pinecone [project]. See Collection for more information. 
// // Parameters: @@ -1026,8 +1033,8 @@ func (c *Client) DescribeCollection(ctx context.Context, collectionName string) // CreateCollectionRequest holds the parameters for creating a new [Collection]. // // Fields: -// - Name: The name of the Collection. -// - Source: The name of the Index to be used as the source for the Collection. +// - Name: (Required) The name of the Collection. +// - Source: (Required) The name of the Index to be used as the source for the Collection. // // To create a new Collection, use the CreateCollection method on the Client object. // @@ -1104,6 +1111,10 @@ type CreateCollectionRequest struct { // // [Collection]: https://docs.pinecone.io/guides/indexes/understanding-collections func (c *Client) CreateCollection(ctx context.Context, in *CreateCollectionRequest) (*Collection, error) { + if in.Source == "" || in.Name == "" { + return nil, fmt.Errorf("fields Name and Source must be included in CreateCollectionRequest") + } + req := control.CreateCollectionRequest{ Name: in.Name, Source: in.Source, diff --git a/pinecone/client_test.go b/pinecone/client_test.go index d7218df..b0b649e 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -703,6 +703,33 @@ func (ts *IntegrationTests) TestApiKeyPassedToIndexConnection() { } // Unit tests: +func TestIndexConnectionMissingReqdFieldsUnit(t *testing.T) { + client := &Client{} + _, err := client.Index(NewIndexConnParams{}) + require.ErrorContainsf(t, err, "field Host is required", err.Error()) +} + +func TestCreatePodIndexMissingReqdFieldsUnit(t *testing.T) { + client := &Client{} + _, err := client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{}) + require.Error(t, err) + require.ErrorContainsf(t, err, "fields Name, Dimension, Metric, Environment, and Podtype must be included in CreatePodIndexRequest", err.Error()) //_, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{}) +} + +func TestCreateServerlessIndexMissingReqdFieldsUnit(t *testing.T) 
{ + client := &Client{} + _, err := client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{}) + require.Error(t, err) + require.ErrorContainsf(t, err, "fields Name, Dimension, Metric, Cloud, and Region must be included in CreateServerlessIndexRequest", err.Error()) +} + +func TestCreateCollectionMissingReqdFieldsUnit(t *testing.T) { + client := &Client{} + _, err := client.CreateCollection(context.Background(), &CreateCollectionRequest{}) + require.Error(t, err) + require.ErrorContains(t, err, "fields Name and Source must be included in CreateCollectionRequest") +} + func TestHandleErrorResponseBodyUnit(t *testing.T) { tests := []struct { name string @@ -1292,11 +1319,6 @@ func TestBuildClientBaseOptionsUnit(t *testing.T) { } // Helper functions: -func isValidUUID(u string) bool { - _, err := uuid.Parse(u) - return err == nil -} - func mockResponse(body string, statusCode int) *http.Response { return &http.Response{ Status: http.StatusText(statusCode), diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index b7888e8..296d6be 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -281,10 +281,10 @@ func (idx *IndexConnection) FetchVectors(ctx context.Context, ids []string) (*Fe // which is passed into the ListVectors method. // // Fields: -// - Prefix: The prefix by which to filter. If unspecified, +// - Prefix: (Optional) The prefix by which to filter. If unspecified, // an empty string will be used which will list all vector ids in the namespace -// - Limit: The maximum number of vectors to return. If unspecified, the server will use a default value. -// - PaginationToken: The token for paginating through results. +// - Limit: (Optional) The maximum number of vectors to return. If unspecified, the server will use a default value. +// - PaginationToken: (Optional) The token for paginating through results. 
type ListVectorsRequest struct { Prefix *string Limit *uint32 @@ -391,12 +391,12 @@ func (idx *IndexConnection) ListVectors(ctx context.Context, in *ListVectorsRequ // which is passed into the QueryByVectorValues method. // // Fields: -// - Vector: The query vector used to find similar vectors. -// - TopK: The number of vectors to return. -// - MetadataFilter: The filter to apply to your query. -// - IncludeValues: Whether to include the values of the vectors in the response. -// - IncludeMetadata: Whether to include the metadata associated with the vectors in the response. -// - SparseValues: The sparse values of the query vector, if applicable. +// - Vector: (Required) The query vector used to find similar vectors. +// - TopK: (Required) The number of vectors to return. +// - MetadataFilter: (Optional) The filter to apply to your query. +// - IncludeValues: (Optional) Whether to include the values of the vectors in the response. +// - IncludeMetadata: (Optional) Whether to include the metadata associated with the vectors in the response. +// - SparseValues: (Optional) The sparse values of the query vector, if applicable. type QueryByVectorValuesRequest struct { Vector []float32 TopK uint32 @@ -510,12 +510,12 @@ func (idx *IndexConnection) QueryByVectorValues(ctx context.Context, in *QueryBy // which is passed into the QueryByVectorId method. // // Fields: -// - VectorId: The unique ID of the vector used to find similar vectors. -// - TopK: The number of vectors to return. -// - MetadataFilter: The filter to apply to your query. -// - IncludeValues: Whether to include the values of the vectors in the response. -// - IncludeMetadata: Whether to include the metadata associated with the vectors in the response. -// - SparseValues: The sparse values of the query vector, if applicable. +// - VectorId: (Required) The unique ID of the vector used to find similar vectors. +// - TopK: (Required) The number of vectors to return. 
+// - MetadataFilter: (Optional) The filter to apply to your query. +// - IncludeValues: (Optional) Whether to include the values of the vectors in the response. +// - IncludeMetadata: (Optional) Whether to include the metadata associated with the vectors in the response. +// - SparseValues: (Optional) The sparse values of the query vector, if applicable. type QueryByVectorIdRequest struct { VectorId string TopK uint32 @@ -772,7 +772,7 @@ func (idx *IndexConnection) DeleteAllVectorsInNamespace(ctx context.Context) err // which is passed into the UpdateVector method. // // Fields: -// - Id: The unique ID of the vector to update. +// - Id: (Required) The unique ID of the vector to update. // - Values: The values with which you want to update the vector. // - SparseValues: The sparse values with which you want to update the vector. // - Metadata: The metadata with which you want to update the vector. @@ -830,6 +830,10 @@ type UpdateVectorRequest struct { // log.Fatalf("Failed to update vector with ID %s. 
Error: %s", id, err) // } func (idx *IndexConnection) UpdateVector(ctx context.Context, in *UpdateVectorRequest) error { + if in.Id == "" { + return fmt.Errorf("a vector ID plus at least one of Values, SparseValues, or Metadata must be provided to update a vector") + } + req := &data.UpdateRequest{ Id: in.Id, Values: in.Values, diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index c0a3ea4..2bb4e70 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -317,6 +317,14 @@ func (ts *IntegrationTests) TestUpdateVectorSparseValues() error { } // Unit tests: +func TestUpdateVectorMissingReqdFieldsUnit(t *testing.T) { + ctx := context.Background() + idxConn := &IndexConnection{} + err := idxConn.UpdateVector(ctx, &UpdateVectorRequest{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "a vector ID plus at least one of Values, SparseValues, or Metadata must be provided to update a vector") +} + func TestMarshalFetchVectorsResponseUnit(t *testing.T) { tests := []struct { name string