Skip to content

Commit

Permalink
Add index_connection struct validations + integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aulorbe committed Jul 31, 2024
1 parent cf95d9d commit 1b0734d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 16 deletions.
44 changes: 28 additions & 16 deletions pinecone/index_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"log"

"github.com/pinecone-io/go-pinecone/internal/utils"

"github.com/pinecone-io/go-pinecone/internal/gen/data"
"github.com/pinecone-io/go-pinecone/internal/useragent"
"google.golang.org/grpc"
Expand Down Expand Up @@ -177,6 +179,11 @@ func (idx *IndexConnection) Close() error {
// log.Fatalf("Successfully upserted %d vector(s)!\n", count)
// }
func (idx *IndexConnection) UpsertVectors(ctx context.Context, in []*Vector) (uint32, error) {
requiredFields := []string{"Id", "Values"}
if err := utils.CheckMissingFields(in, requiredFields); err != nil {
log.Fatalln("vectors must have at least ID and Values fields in order to be upserted")
}

vectors := make([]*data.Vector, len(in))
for i, v := range in {
vectors[i] = vecToGrpc(v)
Expand Down Expand Up @@ -281,10 +288,10 @@ func (idx *IndexConnection) FetchVectors(ctx context.Context, ids []string) (*Fe
// which is passed into the ListVectors method.
//
// Fields:
// - Prefix: The prefix by which to filter. If unspecified,
// - Prefix: (Optional) The prefix by which to filter. If unspecified,
// an empty string will be used which will list all vector ids in the namespace
// - Limit: The maximum number of vectors to return. If unspecified, the server will use a default value.
// - PaginationToken: The token for paginating through results.
// - Limit: (Optional) The maximum number of vectors to return. If unspecified, the server will use a default value.
// - PaginationToken: (Optional) The token for paginating through results.
type ListVectorsRequest struct {
Prefix *string
Limit *uint32
Expand Down Expand Up @@ -391,12 +398,12 @@ func (idx *IndexConnection) ListVectors(ctx context.Context, in *ListVectorsRequ
// which is passed into the QueryByVectorValues method.
//
// Fields:
// - Vector: The query vector used to find similar vectors.
// - TopK: The number of vectors to return.
// - MetadataFilter: The filter to apply to your query.
// - IncludeValues: Whether to include the values of the vectors in the response.
// - IncludeMetadata: Whether to include the metadata associated with the vectors in the response.
// - SparseValues: The sparse values of the query vector, if applicable.
// - Vector: (Required) The query vector used to find similar vectors.
// - TopK: (Required) The number of vectors to return.
// - MetadataFilter: (Optional) The filter to apply to your query.
// - IncludeValues: (Optional) Whether to include the values of the vectors in the response.
// - IncludeMetadata: (Optional) Whether to include the metadata associated with the vectors in the response.
// - SparseValues: (Optional) The sparse values of the query vector, if applicable.
type QueryByVectorValuesRequest struct {
Vector []float32
TopK uint32
Expand Down Expand Up @@ -510,12 +517,12 @@ func (idx *IndexConnection) QueryByVectorValues(ctx context.Context, in *QueryBy
// which is passed into the QueryByVectorId method.
//
// Fields:
// - VectorId: The unique ID of the vector used to find similar vectors.
// - TopK: The number of vectors to return.
// - MetadataFilter: The filter to apply to your query.
// - IncludeValues: Whether to include the values of the vectors in the response.
// - IncludeMetadata: Whether to include the metadata associated with the vectors in the response.
// - SparseValues: The sparse values of the query vector, if applicable.
// - VectorId: (Required) The unique ID of the vector used to find similar vectors.
// - TopK: (Required) The number of vectors to return.
// - MetadataFilter: (Optional) The filter to apply to your query.
// - IncludeValues: (Optional) Whether to include the values of the vectors in the response.
// - IncludeMetadata: (Optional) Whether to include the metadata associated with the vectors in the response.
// - SparseValues: (Optional) The sparse values of the query vector, if applicable.
type QueryByVectorIdRequest struct {
VectorId string
TopK uint32
Expand Down Expand Up @@ -772,7 +779,7 @@ func (idx *IndexConnection) DeleteAllVectorsInNamespace(ctx context.Context) err
// which is passed into the UpdateVector method.
//
// Fields:
// - Id: The unique ID of the vector to update.
// - Id: (Required) The unique ID of the vector to update.
// - Values: The values with which you want to update the vector.
// - SparseValues: The sparse values with which you want to update the vector.
// - Metadata: The metadata with which you want to update the vector.
Expand Down Expand Up @@ -830,6 +837,11 @@ type UpdateVectorRequest struct {
// log.Fatalf("Failed to update vector with ID %s. Error: %s", id, err)
// }
func (idx *IndexConnection) UpdateVector(ctx context.Context, in *UpdateVectorRequest) error {
requiredFields := []string{"Id"}
if err := utils.CheckMissingFields(in, requiredFields); err != nil {
log.Fatalln("a vector ID plus at least one of Values, SparseValues, or Metadata must be provided to update a vector")
}

req := &data.UpdateRequest{
Id: in.Id,
Values: in.Values,
Expand Down
17 changes: 17 additions & 0 deletions pinecone/index_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,16 @@ func (ts *IntegrationTests) TestDescribeIndexStatsFiltered() {
assert.NotNil(ts.T(), res)
}

func (ts *IntegrationTests) TestUpsertVectorsMissingReqdFields() {
ctx := context.Background()
vectors := []*Vector{
{},
}
_, err := ts.idxConn.UpsertVectors(ctx, vectors)
assert.Error(ts.T(), err)
assert.Contains(ts.T(), err.Error(), "vectors must have at least ID and Values fields in order to be upserted")
}

func (ts *IntegrationTests) TestListVectors() {
ts.T().Skip()
req := &ListVectorsRequest{}
Expand Down Expand Up @@ -256,6 +266,13 @@ func (ts *IntegrationTests) TestUpdateVectorValues() {
assert.ElementsMatch(ts.T(), expectedVals, actualVals, "Values do not match")
}

func (ts *IntegrationTests) TestUpdateVectorMissingReqdFields() {
ctx := context.Background()
err := ts.idxConn.UpdateVector(ctx, &UpdateVectorRequest{})
assert.Error(ts.T(), err)
assert.Contains(ts.T(), err.Error(), "a vector ID plus at least one of Values, SparseValues, or Metadata must be provided to update a vector")
}

func (ts *IntegrationTests) TestUpdateVectorMetadata() {
ctx := context.Background()

Expand Down

0 comments on commit 1b0734d

Please sign in to comment.