Skip to content

Commit

Permalink
Merge branch 'main' into Audrey/address-bug-bash-feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
aulorbe authored Aug 2, 2024
2 parents b40bc46 + d173edc commit 2cc389f
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 49 deletions.
67 changes: 39 additions & 28 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type Client struct {
// - ApiKey: (Required) The API key used to authenticate with the Pinecone control plane API.
// This value must be passed by the user unless it is set as an environment variable ("PINECONE_API_KEY").
// - Headers: An optional map of additional HTTP headers to include in each API request to the control plane.
// - Host: The host URL of the Pinecone control plane API. If not provided,
// - Host: (Optional) The host URL of the Pinecone control plane API. If not provided,
// the default value is "https://api.pinecone.io".
// - RestClient: An optional HTTP client to use for communication with the control plane API.
// - SourceTag: An optional string used to help Pinecone attribute API activity.
Expand All @@ -99,7 +99,7 @@ type NewClientParams struct {
// Fields:
// - Headers: An optional map of additional HTTP headers to include in each API request to the control plane.
// "Authorization" and "X-Project-Id" headers are required if authenticating using a JWT.
// - Host: The host URL of the Pinecone control plane API. If not provided,
// - Host: (Optional) The host URL of the Pinecone control plane API. If not provided,
// the default value is "https://api.pinecone.io".
// - RestClient: An optional *http.Client object to use for communication with the control plane API.
// - SourceTag: An optional string used to help Pinecone attribute API activity.
Expand All @@ -115,10 +115,10 @@ type NewClientBaseParams struct {
// NewIndexConnParams holds the parameters for creating an IndexConnection to a Pinecone index.
//
// Fields:
// - Host: The host URL of the Pinecone index. This is required. To find your host url use the DescribeIndex or ListIndexes methods.
// - Host: (Required) The host URL of the Pinecone index. To find your host url use the DescribeIndex or ListIndexes methods.
// Alternatively, the host is displayed in the Pinecone web console.
// - Namespace: Optional index namespace to use for operations. If not provided, the default namespace of "" will be used.
// - AdditionalMetdata: Optional additional metdata to be sent with each RPC request.
// - AdditionalMetadata: Optional additional metadata to be sent with each RPC request.
//
// See Client.Index for code example.
type NewIndexConnParams struct {
Expand Down Expand Up @@ -155,7 +155,7 @@ type NewIndexConnParams struct {
// }
func NewClient(in NewClientParams) (*Client, error) {
osApiKey := os.Getenv("PINECONE_API_KEY")
hasApiKey := (valueOrFallback(in.ApiKey, osApiKey) != "")
hasApiKey := valueOrFallback(in.ApiKey, osApiKey) != ""

if !hasApiKey {
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization through NewClientParams or set the PINECONE_API_KEY environment variable")
Expand Down Expand Up @@ -184,7 +184,7 @@ func NewClient(in NewClientParams) (*Client, error) {
// Notes:
// - It is important to handle the error returned by this function to ensure that the
// control plane client has been created successfully before attempting to make API calls.
// - A Pinecone API key is not requried when using NewClientBase.
// - A Pinecone API key is not required when using NewClientBase.
//
// Returns a pointer to an initialized Client instance or an error.
//
Expand Down Expand Up @@ -280,6 +280,10 @@ func (c *Client) Index(in NewIndexConnParams, dialOpts ...grpc.DialOption) (*Ind
in.AdditionalMetadata = make(map[string]string)
}

if in.Host == "" {
return nil, fmt.Errorf("field Host is required to create an IndexConnection. Find your Host from calling DescribeIndex or via the Pinecone console")
}

// add api version header if not provided
if _, ok := in.AdditionalMetadata["X-Pinecone-Api-Version"]; !ok {
in.AdditionalMetadata["X-Pinecone-Api-Version"] = gen.PineconeApiVersion
Expand Down Expand Up @@ -365,21 +369,21 @@ func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) {
// CreatePodIndexRequest holds the parameters for creating a new pods-based Index.
//
// Fields:
// - Name: The name of the Index. Resource name must be 1-45 characters long,
// - Name: (Required) The name of the Index. Resource name must be 1-45 characters long,
// start and end with an alphanumeric character,
// and consist only of lower case alphanumeric characters or '-'.
// - Dimension: The [dimensionality] of the vectors to be inserted in the Index.
// - Metric: The distance metric to be used for [similarity] search. You can use
// - Dimension: (Required) The [dimensionality] of the vectors to be inserted in the Index.
// - Metric: (Required) The distance metric to be used for [similarity] search. You can use
// 'euclidean', 'cosine', or 'dotproduct'.
// - Environment: The [cloud environment] where the Index will be hosted.
// - PodType: The [type of pod] to use for the Index. One of `s1`, `p1`, or `p2` appended with `.` and
// - Environment: (Required) The [cloud environment] where the Index will be hosted.
// - PodType: (Required) The [type of pod] to use for the Index. One of `s1`, `p1`, or `p2` appended with `.` and
// one of `x1`, `x2`, `x4`, or `x8`.
// - Shards: The number of shards to use for the Index (defaults to 1).
// - Shards: (Optional) The number of shards to use for the Index (defaults to 1).
// Shards split your data across multiple pods, so you can fit more data into an Index.
// - Replicas: The number of [replicas] to use for the Index (defaults to 1). Replicas duplicate your Index.
// - Replicas: (Optional) The number of [replicas] to use for the Index (defaults to 1). Replicas duplicate your Index.
// They provide higher availability and throughput. Replicas can be scaled up or down as your needs change.
// - SourceCollection: The name of the Collection to be used as the source for the Index.
// - MetadataConfig: The [metadata configuration] for the behavior of Pinecone's internal metadata Index. By
// - SourceCollection: (Optional) The name of the Collection to be used as the source for the Index.
// - MetadataConfig: (Optional) The [metadata configuration] for the behavior of Pinecone's internal metadata Index. By
// default, all metadata is indexed; when `metadata_config` is present,
// only specified metadata fields are indexed. These configurations are
// only valid for use with pod-based Indexes.
Expand Down Expand Up @@ -532,6 +536,10 @@ func (req CreatePodIndexRequest) TotalCount() int {
// fmt.Printf("Successfully created pod index: %s", idx.Name)
// }
func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest) (*Index, error) {
if in.Name == "" || in.Dimension == 0 || in.Metric == "" || in.Environment == "" || in.PodType == "" {
return nil, fmt.Errorf("fields Name, Dimension, Metric, Environment, and Podtype must be included in CreatePodIndexRequest")
}

deletionProtection := pointerOrNil(control.DeletionProtection(in.DeletionProtection))
metric := pointerOrNil(control.CreateIndexRequestMetric(in.Metric))

Expand Down Expand Up @@ -577,14 +585,14 @@ func (c *Client) CreatePodIndex(ctx context.Context, in *CreatePodIndexRequest)
// CreateServerlessIndexRequest holds the parameters for creating a new [Serverless] Index.
//
// Fields:
// - Name: The name of the Index. Resource name must be 1-45 characters long,
// - Name: (Required) The name of the Index. Resource name must be 1-45 characters long,
// start and end with an alphanumeric character,
// and consist only of lower case alphanumeric characters or '-'.
// - Dimension: The [dimensionality] of the vectors to be inserted in the Index.
// - Metric: The metric used to measure the [similarity] between vectors ('euclidean', 'cosine', or 'dotproduct').
// - Cloud: The public [cloud provider] where you would like your Index hosted.
// - Dimension: (Required) The [dimensionality] of the vectors to be inserted in the Index.
// - Metric: (Required) The metric used to measure the [similarity] between vectors ('euclidean', 'cosine', or 'dotproduct').
// - Cloud: (Required) The public [cloud provider] where you would like your Index hosted.
// For serverless Indexes, you define only the cloud and region where the Index should be hosted.
// - Region: The [region] where you would like your Index to be created.
// - Region: (Required) The [region] where you would like your Index to be created.
// - DeletionProtection: (Optional) determines whether [deletion protection] is "enabled" or "disabled" for the index.
// When "enabled", the index cannot be deleted. Defaults to "disabled".
//
Expand Down Expand Up @@ -702,6 +710,10 @@ type CreateServerlessIndexRequest struct {
// fmt.Printf("Successfully created serverless index: %s", idx.Name)
// }
func (c *Client) CreateServerlessIndex(ctx context.Context, in *CreateServerlessIndexRequest) (*Index, error) {
if in.Name == "" || in.Dimension == 0 || in.Metric == "" || in.Cloud == "" || in.Region == "" {
return nil, fmt.Errorf("fields Name, Dimension, Metric, Cloud, and Region must be included in CreateServerlessIndexRequest")
}

deletionProtection := pointerOrNil(control.DeletionProtection(in.DeletionProtection))
metric := pointerOrNil(control.CreateIndexRequestMetric(in.Metric))

Expand Down Expand Up @@ -833,7 +845,7 @@ func (c *Client) DeleteIndex(ctx context.Context, idxName string) error {
// ConfigureIndexParams contains parameters for configuring an index. For both pod-based
// and serverless indexes you can configure the DeletionProtection status for an index.
// For pod-based indexes you can also configure the number of Replicas and the PodType.
// Each of the fields are optional, but at least one field must be set.
// Each of the fields is optional, but at least one field must be set.
// See [scale a pods-based index] for more information.
//
// Fields:
Expand Down Expand Up @@ -957,11 +969,6 @@ func (c *Client) ConfigureIndex(ctx context.Context, name string, in ConfigureIn
return decodeIndex(res.Body)
}

func PrettifyStruct(obj interface{}) string {
bytes, _ := json.MarshalIndent(obj, "", " ")
return string(bytes)
}

// ListCollections retrieves a list of all Collections in a Pinecone [project]. See Collection for more information.
//
// Parameters:
Expand Down Expand Up @@ -1089,8 +1096,8 @@ func (c *Client) DescribeCollection(ctx context.Context, collectionName string)
// CreateCollectionRequest holds the parameters for creating a new [Collection].
//
// Fields:
// - Name: The name of the Collection.
// - Source: The name of the Index to be used as the source for the Collection.
// - Name: (Required) The name of the Collection.
// - Source: (Required) The name of the Index to be used as the source for the Collection.
//
// To create a new Collection, use the CreateCollection method on the Client object.
//
Expand Down Expand Up @@ -1167,6 +1174,10 @@ type CreateCollectionRequest struct {
//
// [Collection]: https://docs.pinecone.io/guides/indexes/understanding-collections
func (c *Client) CreateCollection(ctx context.Context, in *CreateCollectionRequest) (*Collection, error) {
if in.Source == "" || in.Name == "" {
return nil, fmt.Errorf("fields Name and Source must be included in CreateCollectionRequest")
}

req := control.CreateCollectionRequest{
Name: in.Name,
Source: in.Source,
Expand Down
32 changes: 27 additions & 5 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,33 @@ func (ts *IntegrationTests) TestApiKeyPassedToIndexConnection() {
}

// Unit tests:
func TestIndexConnectionMissingReqdFieldsUnit(t *testing.T) {
client := &Client{}
_, err := client.Index(NewIndexConnParams{})
require.ErrorContainsf(t, err, "field Host is required", err.Error())
}

func TestCreatePodIndexMissingReqdFieldsUnit(t *testing.T) {
client := &Client{}
_, err := client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{})
require.Error(t, err)
require.ErrorContainsf(t, err, "fields Name, Dimension, Metric, Environment, and Podtype must be included in CreatePodIndexRequest", err.Error()) //_, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{})
}

func TestCreateServerlessIndexMissingReqdFieldsUnit(t *testing.T) {
client := &Client{}
_, err := client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{})
require.Error(t, err)
require.ErrorContainsf(t, err, "fields Name, Dimension, Metric, Cloud, and Region must be included in CreateServerlessIndexRequest", err.Error())
}

func TestCreateCollectionMissingReqdFieldsUnit(t *testing.T) {
client := &Client{}
_, err := client.CreateCollection(context.Background(), &CreateCollectionRequest{})
require.Error(t, err)
require.ErrorContains(t, err, "fields Name and Source must be included in CreateCollectionRequest")
}

func TestHandleErrorResponseBodyUnit(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -1292,11 +1319,6 @@ func TestBuildClientBaseOptionsUnit(t *testing.T) {
}

// Helper functions:
func isValidUUID(u string) bool {
_, err := uuid.Parse(u)
return err == nil
}

func mockResponse(body string, statusCode int) *http.Response {
return &http.Response{
Status: http.StatusText(statusCode),
Expand Down
36 changes: 20 additions & 16 deletions pinecone/index_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,10 @@ func (idx *IndexConnection) FetchVectors(ctx context.Context, ids []string) (*Fe
// which is passed into the ListVectors method.
//
// Fields:
// - Prefix: The prefix by which to filter. If unspecified,
// - Prefix: (Optional) The prefix by which to filter. If unspecified,
// an empty string will be used which will list all vector ids in the namespace
// - Limit: The maximum number of vectors to return. If unspecified, the server will use a default value.
// - PaginationToken: The token for paginating through results.
// - Limit: (Optional) The maximum number of vectors to return. If unspecified, the server will use a default value.
// - PaginationToken: (Optional) The token for paginating through results.
type ListVectorsRequest struct {
Prefix *string
Limit *uint32
Expand Down Expand Up @@ -391,12 +391,12 @@ func (idx *IndexConnection) ListVectors(ctx context.Context, in *ListVectorsRequ
// which is passed into the QueryByVectorValues method.
//
// Fields:
// - Vector: The query vector used to find similar vectors.
// - TopK: The number of vectors to return.
// - MetadataFilter: The filter to apply to your query.
// - IncludeValues: Whether to include the values of the vectors in the response.
// - IncludeMetadata: Whether to include the metadata associated with the vectors in the response.
// - SparseValues: The sparse values of the query vector, if applicable.
// - Vector: (Required) The query vector used to find similar vectors.
// - TopK: (Required) The number of vectors to return.
// - MetadataFilter: (Optional) The filter to apply to your query.
// - IncludeValues: (Optional) Whether to include the values of the vectors in the response.
// - IncludeMetadata: (Optional) Whether to include the metadata associated with the vectors in the response.
// - SparseValues: (Optional) The sparse values of the query vector, if applicable.
type QueryByVectorValuesRequest struct {
Vector []float32
TopK uint32
Expand Down Expand Up @@ -510,12 +510,12 @@ func (idx *IndexConnection) QueryByVectorValues(ctx context.Context, in *QueryBy
// which is passed into the QueryByVectorId method.
//
// Fields:
// - VectorId: The unique ID of the vector used to find similar vectors.
// - TopK: The number of vectors to return.
// - MetadataFilter: The filter to apply to your query.
// - IncludeValues: Whether to include the values of the vectors in the response.
// - IncludeMetadata: Whether to include the metadata associated with the vectors in the response.
// - SparseValues: The sparse values of the query vector, if applicable.
// - VectorId: (Required) The unique ID of the vector used to find similar vectors.
// - TopK: (Required) The number of vectors to return.
// - MetadataFilter: (Optional) The filter to apply to your query.
// - IncludeValues: (Optional) Whether to include the values of the vectors in the response.
// - IncludeMetadata: (Optional) Whether to include the metadata associated with the vectors in the response.
// - SparseValues: (Optional) The sparse values of the query vector, if applicable.
type QueryByVectorIdRequest struct {
VectorId string
TopK uint32
Expand Down Expand Up @@ -772,7 +772,7 @@ func (idx *IndexConnection) DeleteAllVectorsInNamespace(ctx context.Context) err
// which is passed into the UpdateVector method.
//
// Fields:
// - Id: The unique ID of the vector to update.
// - Id: (Required) The unique ID of the vector to update.
// - Values: The values with which you want to update the vector.
// - SparseValues: The sparse values with which you want to update the vector.
// - Metadata: The metadata with which you want to update the vector.
Expand Down Expand Up @@ -830,6 +830,10 @@ type UpdateVectorRequest struct {
// log.Fatalf("Failed to update vector with ID %s. Error: %s", id, err)
// }
func (idx *IndexConnection) UpdateVector(ctx context.Context, in *UpdateVectorRequest) error {
if in.Id == "" {
return fmt.Errorf("a vector ID plus at least one of Values, SparseValues, or Metadata must be provided to update a vector")
}

req := &data.UpdateRequest{
Id: in.Id,
Values: in.Values,
Expand Down
8 changes: 8 additions & 0 deletions pinecone/index_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,14 @@ func (ts *IntegrationTests) TestUpdateVectorSparseValues() error {
}

// Unit tests:
func TestUpdateVectorMissingReqdFieldsUnit(t *testing.T) {
ctx := context.Background()
idxConn := &IndexConnection{}
err := idxConn.UpdateVector(ctx, &UpdateVectorRequest{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "a vector ID plus at least one of Values, SparseValues, or Metadata must be provided to update a vector")
}

func TestMarshalFetchVectorsResponseUnit(t *testing.T) {
tests := []struct {
name string
Expand Down

0 comments on commit 2cc389f

Please sign in to comment.