diff --git a/.env.example b/.env.example index 49ffe3b..4fe3d20 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1 @@ PINECONE_API_KEY="" -TEST_PODS_INDEX_NAME="" -TEST_SERVERLESS_INDEX_NAME="" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c059c95..ea5f5d3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -17,8 +17,6 @@ jobs: run: | go get ./pinecone - name: Run tests - run: go test ./pinecone + run: go test -count=1 -v ./pinecone env: - TEST_PODS_INDEX_NAME: ${{ secrets.TEST_PODS_INDEX_NAME }} - TEST_SERVERLESS_INDEX_NAME: ${{ secrets.TEST_SERVERLESS_INDEX_NAME }} PINECONE_API_KEY: ${{ secrets.API_KEY }} diff --git a/README.md b/README.md index 523e369..4a928d9 100644 --- a/README.md +++ b/README.md @@ -10,9 +10,11 @@ Official Pinecone Go Client ## Documentation + To see the latest documentation on `main`, visit https://pkg.go.dev/github.com/pinecone-io/go-pinecone@main/pinecone. -To see the latest versioned-release's documentation, visit https://pkg.go.dev/github.com/pinecone-io/go-pinecone/pinecone. +To see the latest versioned-release's documentation, +visit https://pkg.go.dev/github.com/pinecone-io/go-pinecone/pinecone. ## Features @@ -100,9 +102,11 @@ Then, execute `just bootstrap` to install the necessary Go packages ### .env Setup -To avoid race conditions or having to wait for index creation, the tests require a project with at least one pod index -and one serverless index. Copy the api key and index names to a `.env` file. See `.env.example` for a template. +An easy way to keep track of necessary environment variables is to create a `.env` file in the root of the project. +This project comes with a sample `.env` file (`.env.sample`) that you can copy and modify. At the very least, you +will need to include the `PINECONE_API_KEY` variable in your `.env` file for the tests to run locally. +```shell ### API Definitions submodule The API Definitions are in a private submodule. To checkout or update the submodules execute in the root of the project: @@ -111,7 +115,8 @@ The API Definitions are in a private submodule. To checkout or update the submod git submodule update --init --recursive ``` -For working with submodules, see the [Git Submodules](https://git-scm.com/book/en/v2/Git-Tools-Submodules) documentation. +For working with submodules, see the [Git Submodules](https://git-scm.com/book/en/v2/Git-Tools-Submodules) +documentation. ### Just commands diff --git a/justfile b/justfile index eadc3f2..0bdc7e4 100644 --- a/justfile +++ b/justfile @@ -6,12 +6,6 @@ test: source .env set +o allexport go test -count=1 -v ./pinecone -test-integration: - #!/usr/bin/env bash - set -o allexport - source .env - set +o allexport - go test -v -run Integration ./pinecone test-unit: #!/usr/bin/env bash set -o allexport diff --git a/pinecone/client_test.go b/pinecone/client_test.go index faec69a..594c075 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -11,60 +11,20 @@ import ( "strings" "testing" - "github.com/pinecone-io/go-pinecone/internal/gen" - "github.com/pinecone-io/go-pinecone/internal/provider" - "github.com/google/go-cmp/cmp" + "github.com/pinecone-io/go-pinecone/internal/gen" "github.com/pinecone-io/go-pinecone/internal/gen/control" + "github.com/pinecone-io/go-pinecone/internal/provider" "github.com/google/uuid" "github.com/pinecone-io/go-pinecone/internal/utils" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" ) // Integration tests: -type ClientTestsIntegration struct { - suite.Suite - client Client - clientSourceTag Client - sourceTag string - podIndex string - serverlessIndex string -} - -func TestIntegrationClient(t *testing.T) { - suite.Run(t, new(ClientTestsIntegration)) -} - -func (ts *ClientTestsIntegration) SetupSuite() { - apiKey := os.Getenv("PINECONE_API_KEY") - require.NotEmpty(ts.T(), apiKey, "PINECONE_API_KEY env variable not set") - - ts.podIndex = os.Getenv("TEST_PODS_INDEX_NAME") - require.NotEmpty(ts.T(), ts.podIndex, "TEST_PODS_INDEX_NAME env variable not set") - - ts.serverlessIndex = os.Getenv("TEST_SERVERLESS_INDEX_NAME") - require.NotEmpty(ts.T(), ts.serverlessIndex, "TEST_SERVERLESS_INDEX_NAME env variable not set") - - client, err := NewClient(NewClientParams{}) - require.NoError(ts.T(), err) - - ts.client = *client - - ts.sourceTag = "test_source_tag" - clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: ts.sourceTag}) - require.NoError(ts.T(), err) - ts.clientSourceTag = *clientSourceTag - - // this will clean up the project deleting all indexes and collections that are - // named a UUID. Generally not needed as all tests are cleaning up after themselves - // Left here as a convenience during active development. - //deleteUUIDNamedResources(context.Background(), &ts.client) -} - -func (ts *ClientTestsIntegration) TestNewClientParamsSet() { +func (ts *IntegrationTests) TestNewClientParamsSet() { apiKey := "test-api-key" client, err := NewClient(NewClientParams{ApiKey: apiKey}) @@ -77,7 +37,7 @@ func (ts *ClientTestsIntegration) TestNewClientParamsSet() { require.Equal(ts.T(), 3, len(client.restClient.RequestEditors), "Expected client to have correct number of request editors") } -func (ts *ClientTestsIntegration) TestNewClientParamsSetSourceTag() { +func (ts *IntegrationTests) TestNewClientParamsSetSourceTag() { apiKey := "test-api-key" sourceTag := "test-source-tag" client, err := NewClient(NewClientParams{ @@ -93,7 +53,7 @@ func (ts *ClientTestsIntegration) TestNewClientParamsSetSourceTag() { require.Equal(ts.T(), 3, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 2, len(client.restClient.RequestEditors)) } -func (ts *ClientTestsIntegration) TestNewClientParamsSetHeaders() { +func (ts *IntegrationTests) TestNewClientParamsSetHeaders() { apiKey := "test-api-key" headers := map[string]string{"test-header": "test-ptr"} client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers}) @@ -106,7 +66,7 @@ func (ts *ClientTestsIntegration) TestNewClientParamsSetHeaders() { require.Equal(ts.T(), 4, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 3, len(client.restClient.RequestEditors)) } -func (ts *ClientTestsIntegration) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { +func (ts *IntegrationTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { apiKey := os.Getenv("PINECONE_API_KEY") os.Unsetenv("PINECONE_API_KEY") @@ -121,7 +81,7 @@ func (ts *ClientTestsIntegration) TestNewClientParamsNoApiKeyNoAuthorizationHead os.Setenv("PINECONE_API_KEY", apiKey) } -func (ts *ClientTestsIntegration) TestHeadersAppliedToRequests() { +func (ts *IntegrationTests) TestHeadersAppliedToRequests() { apiKey := "test-api-key" headers := map[string]string{"test-header": "123456"} @@ -140,7 +100,7 @@ func (ts *ClientTestsIntegration) TestHeadersAppliedToRequests() { assert.Equal(ts.T(), "123456", testHeaderValue, "Expected request to have header ptr '123456', but got '%s'", testHeaderValue) } -func (ts *ClientTestsIntegration) TestAdditionalHeadersAppliedToRequest() { +func (ts *IntegrationTests) TestAdditionalHeadersAppliedToRequest() { os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`) apiKey := "test-api-key" @@ -162,7 +122,7 @@ func (ts *ClientTestsIntegration) TestAdditionalHeadersAppliedToRequest() { os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } -func (ts *ClientTestsIntegration) TestHeadersOverrideAdditionalHeaders() { +func (ts *IntegrationTests) TestHeadersOverrideAdditionalHeaders() { os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`) apiKey := "test-api-key" @@ -185,7 +145,7 @@ func (ts *ClientTestsIntegration) TestHeadersOverrideAdditionalHeaders() { os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } -func (ts *ClientTestsIntegration) TestControllerHostOverride() { +func (ts *IntegrationTests) TestControllerHostOverride() { apiKey := "test-api-key" httpClient := utils.CreateMockClient(`{"indexes": []}`) client, err := NewClient(NewClientParams{ApiKey: apiKey, Host: "https://test-controller-host.io", RestClient: httpClient}) @@ -200,7 +160,7 @@ func (ts *ClientTestsIntegration) TestControllerHostOverride() { assert.Equal(ts.T(), "test-controller-host.io", mockTransport.Req.Host, "Expected request to be made to 'test-controller-host.io', but got '%s'", mockTransport.Req.URL.Host) } -func (ts *ClientTestsIntegration) TestControllerHostOverrideFromEnv() { +func (ts *IntegrationTests) TestControllerHostOverrideFromEnv() { os.Setenv("PINECONE_CONTROLLER_HOST", "https://env-controller-host.io") apiKey := "test-api-key" @@ -219,7 +179,7 @@ func (ts *ClientTestsIntegration) TestControllerHostOverrideFromEnv() { os.Unsetenv("PINECONE_CONTROLLER_HOST") } -func (ts *ClientTestsIntegration) TestControllerHostNormalization() { +func (ts *IntegrationTests) TestControllerHostNormalization() { tests := []struct { name string host string @@ -264,22 +224,22 @@ func (ts *ClientTestsIntegration) TestControllerHostNormalization() { } } -func (ts *ClientTestsIntegration) TestListIndexes() { +func (ts *IntegrationTests) TestListIndexes() { indexes, err := ts.client.ListIndexes(context.Background()) require.NoError(ts.T(), err) require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist") } -func (ts *ClientTestsIntegration) TestListIndexesSourceTag() { +func (ts *IntegrationTests) TestListIndexesSourceTag() { indexes, err := ts.clientSourceTag.ListIndexes(context.Background()) require.NoError(ts.T(), err) require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist") } -func (ts *ClientTestsIntegration) TestCreatePodIndex() { +func (ts *IntegrationTests) TestCreatePodIndex() { name := uuid.New().String() - defer func(ts *ClientTestsIntegration, name string) { + defer func(ts *IntegrationTests, name string) { err := ts.deleteIndex(name) require.NoError(ts.T(), err) }(ts, name) @@ -295,7 +255,7 @@ func (ts *ClientTestsIntegration) TestCreatePodIndex() { require.Equal(ts.T(), name, idx.Name, "Index name does not match") } -func (ts *ClientTestsIntegration) TestCreatePodIndexInvalidDimension() { +func (ts *IntegrationTests) TestCreatePodIndexInvalidDimension() { name := uuid.New().String() _, err := ts.client.CreatePodIndex(context.Background(), &CreatePodIndexRequest{ @@ -309,7 +269,7 @@ func (ts *ClientTestsIntegration) TestCreatePodIndexInvalidDimension() { require.Equal(ts.T(), reflect.TypeOf(err), reflect.TypeOf(&PineconeError{}), "Expected error to be of type PineconeError") } -func (ts *ClientTestsIntegration) TestCreateServerlessIndexInvalidDimension() { +func (ts *IntegrationTests) TestCreateServerlessIndexInvalidDimension() { name := uuid.New().String() _, err := ts.client.CreateServerlessIndex(context.Background(), &CreateServerlessIndexRequest{ @@ -323,10 +283,10 @@ func (ts *ClientTestsIntegration) TestCreateServerlessIndexInvalidDimension() { require.Equal(ts.T(), reflect.TypeOf(err), reflect.TypeOf(&PineconeError{}), "Expected error to be of type PineconeError") } -func (ts *ClientTestsIntegration) TestCreateServerlessIndex() { +func (ts *IntegrationTests) TestCreateServerlessIndex() { name := uuid.New().String() - defer func(ts *ClientTestsIntegration, name string) { + defer func(ts *IntegrationTests, name string) { err := ts.deleteIndex(name) require.NoError(ts.T(), err) }(ts, name) @@ -342,37 +302,52 @@ func (ts *ClientTestsIntegration) TestCreateServerlessIndex() { require.Equal(ts.T(), name, idx.Name, "Index name does not match") } -func (ts *ClientTestsIntegration) TestDescribeServerlessIndex() { - index, err := ts.client.DescribeIndex(context.Background(), ts.serverlessIndex) +func (ts *IntegrationTests) TestDescribeServerlessIndex() { + if ts.indexType == "pods" { + ts.T().Skip("No serverless index to test") + } + index, err := ts.client.DescribeIndex(context.Background(), ts.idxName) require.NoError(ts.T(), err) - require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match") + require.Equal(ts.T(), ts.idxName, index.Name, "Index name does not match") } -func (ts *ClientTestsIntegration) TestDescribeNonExistentIndex() { +func (ts *IntegrationTests) TestDescribeNonExistentIndex() { _, err := ts.client.DescribeIndex(context.Background(), "non-existent-index") require.Error(ts.T(), err) require.Equal(ts.T(), reflect.TypeOf(err), reflect.TypeOf(&PineconeError{}), "Expected error to be of type PineconeError") } -func (ts *ClientTestsIntegration) TestDescribeServerlessIndexSourceTag() { - index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.serverlessIndex) +func (ts *IntegrationTests) TestDescribeServerlessIndexSourceTag() { + if ts.indexType == "pods" { + ts.T().Skip("No serverless index to test") + } + index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.idxName) require.NoError(ts.T(), err) - require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match") + require.Equal(ts.T(), ts.idxName, index.Name, "Index name does not match") } -func (ts *ClientTestsIntegration) TestDescribePodIndex() { - index, err := ts.client.DescribeIndex(context.Background(), ts.podIndex) +func (ts *IntegrationTests) TestDescribePodIndex() { + if ts.indexType == "serverless" { + ts.T().Skip("No pod index to test") + } + index, err := ts.client.DescribeIndex(context.Background(), ts.idxName) require.NoError(ts.T(), err) - require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match") + require.Equal(ts.T(), ts.idxName, index.Name, "Index name does not match") } -func (ts *ClientTestsIntegration) TestDescribePodIndexSourceTag() { - index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.podIndex) +func (ts *IntegrationTests) TestDescribePodIndexSourceTag() { + if ts.indexType == "serverless" { + ts.T().Skip("No pod index to test") + } + index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.idxName) require.NoError(ts.T(), err) - require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match") + require.Equal(ts.T(), ts.idxName, index.Name, "Index name does not match") } -func (ts *ClientTestsIntegration) TestListCollections() { +func (ts *IntegrationTests) TestListCollections() { + if ts.indexType == "serverless" { + ts.T().Skip("No pod index to test") + } ctx := context.Background() var collectionNames []string @@ -381,7 +356,7 @@ func (ts *ClientTestsIntegration) TestListCollections() { collectionNames = append(collectionNames, collectionName) } - defer func(ts *ClientTestsIntegration, collectionNames []string) { + defer func(ts *IntegrationTests, collectionNames []string) { for _, name := range collectionNames { err := ts.client.DeleteCollection(ctx, name) require.NoError(ts.T(), err, "Error deleting collection") @@ -391,7 +366,7 @@ func (ts *ClientTestsIntegration) TestListCollections() { for _, name := range collectionNames { _, err := ts.client.CreateCollection(ctx, &CreateCollectionRequest{ Name: name, - Source: ts.podIndex, + Source: ts.idxName, }) require.NoError(ts.T(), err, "Error creating collection") } @@ -414,18 +389,21 @@ func (ts *ClientTestsIntegration) TestListCollections() { require.Equal(ts.T(), len(collectionNames), found, "Not all created collections were listed") } -func (ts *ClientTestsIntegration) TestDescribeCollection() { +func (ts *IntegrationTests) TestDescribeCollection() { + if ts.indexType == "serverless" { + ts.T().Skip("No pod index to test") + } ctx := context.Background() collectionName := uuid.New().String() defer func(client *Client, ctx context.Context, collectionName string) { err := client.DeleteCollection(ctx, collectionName) require.NoError(ts.T(), err) - }(&ts.client, ctx, collectionName) + }(ts.client, ctx, collectionName) _, err := ts.client.CreateCollection(ctx, &CreateCollectionRequest{ Name: collectionName, - Source: ts.podIndex, + Source: ts.idxName, }) require.NoError(ts.T(), err) @@ -434,9 +412,12 @@ func (ts *ClientTestsIntegration) TestDescribeCollection() { require.Equal(ts.T(), collectionName, collection.Name, "Collection name does not match") } -func (ts *ClientTestsIntegration) TestCreateCollection() { +func (ts *IntegrationTests) TestCreateCollection() { + if ts.indexType == "serverless" { + ts.T().Skip("No pod index to test") + } name := uuid.New().String() - sourceIndex := ts.podIndex + sourceIndex := ts.idxName defer func() { err := ts.client.DeleteCollection(context.Background(), name) @@ -451,11 +432,14 @@ func (ts *ClientTestsIntegration) TestCreateCollection() { require.Equal(ts.T(), name, collection.Name, "Collection name does not match") } -func (ts *ClientTestsIntegration) TestDeleteCollection() { +func (ts *IntegrationTests) TestDeleteCollection() { + if ts.indexType == "serverless" { + ts.T().Skip("No pod index to test") + } collectionName := uuid.New().String() _, err := ts.client.CreateCollection(context.Background(), &CreateCollectionRequest{ Name: collectionName, - Source: ts.podIndex, + Source: ts.idxName, }) require.NoError(ts.T(), err) @@ -463,10 +447,10 @@ func (ts *ClientTestsIntegration) TestDeleteCollection() { require.NoError(ts.T(), err) } -func (ts *ClientTestsIntegration) TestConfigureIndexIllegalScaleDown() { +func (ts *IntegrationTests) TestConfigureIndexIllegalScaleDown() { name := uuid.New().String() - defer func(ts *ClientTestsIntegration, name string) { + defer func(ts *IntegrationTests, name string) { err := ts.deleteIndex(name) require.NoError(ts.T(), err) }(ts, name) @@ -488,10 +472,10 @@ func (ts *ClientTestsIntegration) TestConfigureIndexIllegalScaleDown() { require.ErrorContainsf(ts.T(), err, "Cannot scale down", err.Error()) } -func (ts *ClientTestsIntegration) TestConfigureIndexScaleUpNoPods() { +func (ts *IntegrationTests) TestConfigureIndexScaleUpNoPods() { name := uuid.New().String() - defer func(ts *ClientTestsIntegration, name string) { + defer func(ts *IntegrationTests, name string) { err := ts.deleteIndex(name) require.NoError(ts.T(), err) }(ts, name) @@ -512,10 +496,10 @@ func (ts *ClientTestsIntegration) TestConfigureIndexScaleUpNoPods() { require.NoError(ts.T(), err) } -func (ts *ClientTestsIntegration) TestConfigureIndexScaleUpNoReplicas() { +func (ts *IntegrationTests) TestConfigureIndexScaleUpNoReplicas() { name := uuid.New().String() - defer func(ts *ClientTestsIntegration, name string) { + defer func(ts *IntegrationTests, name string) { err := ts.deleteIndex(name) require.NoError(ts.T(), err) }(ts, name) @@ -536,10 +520,10 @@ func (ts *ClientTestsIntegration) TestConfigureIndexScaleUpNoReplicas() { require.NoError(ts.T(), err) } -func (ts *ClientTestsIntegration) TestConfigureIndexIllegalNoPodsOrReplicas() { +func (ts *IntegrationTests) TestConfigureIndexIllegalNoPodsOrReplicas() { name := uuid.New().String() - defer func(ts *ClientTestsIntegration, name string) { + defer func(ts *IntegrationTests, name string) { err := ts.deleteIndex(name) require.NoError(ts.T(), err) }(ts, name) @@ -559,10 +543,10 @@ func (ts *ClientTestsIntegration) TestConfigureIndexIllegalNoPodsOrReplicas() { require.ErrorContainsf(ts.T(), err, "must specify either podType or replicas", err.Error()) } -func (ts *ClientTestsIntegration) TestConfigureIndexHitPodLimit() { +func (ts *IntegrationTests) TestConfigureIndexHitPodLimit() { name := uuid.New().String() - defer func(ts *ClientTestsIntegration, name string) { + defer func(ts *IntegrationTests, name string) { err := ts.deleteIndex(name) require.NoError(ts.T(), err) }(ts, name) @@ -583,11 +567,11 @@ func (ts *ClientTestsIntegration) TestConfigureIndexHitPodLimit() { require.ErrorContainsf(ts.T(), err, "You've reached the max pods allowed", err.Error()) } -func (ts *ClientTestsIntegration) deleteIndex(name string) error { +func (ts *IntegrationTests) deleteIndex(name string) error { return ts.client.DeleteIndex(context.Background(), name) } -func (ts *ClientTestsIntegration) TestExtractAuthHeader() { +func (ts *IntegrationTests) TestExtractAuthHeader() { globalApiKey := os.Getenv("PINECONE_API_KEY") os.Unsetenv("PINECONE_API_KEY") @@ -631,7 +615,7 @@ func (ts *ClientTestsIntegration) TestExtractAuthHeader() { os.Setenv("PINECONE_API_KEY", globalApiKey) } -func (ts *ClientTestsIntegration) TestApiKeyPassedToIndexConnection() { +func (ts *IntegrationTests) TestApiKeyPassedToIndexConnection() { apiKey := "test-api-key" client, err := NewClient(NewClientParams{ApiKey: apiKey}) @@ -1244,40 +1228,6 @@ func isValidUUID(u string) bool { return err == nil } -func deleteUUIDNamedResources(ctx context.Context, c *Client) error { - // Delete UUID-named indexes - indexes, err := c.ListIndexes(ctx) - if err != nil { - return err - } - - for _, index := range indexes { - if isValidUUID(index.Name) { - err := c.DeleteIndex(ctx, index.Name) - if err != nil { - return err - } - } - } - - // Delete UUID-named collections - collections, err := c.ListCollections(ctx) - if err != nil { - return err - } - - for _, collection := range collections { - if isValidUUID(collection.Name) { - err := c.DeleteCollection(ctx, collection.Name) - if err != nil { - return err - } - } - } - - return nil -} - func mockResponse(body string, statusCode int) *http.Response { return &http.Response{ Status: http.StatusText(statusCode), diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index bf0d25b..872c696 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -4,114 +4,22 @@ import ( "context" "encoding/json" "fmt" - "os" + "log" "testing" + "time" "github.com/pinecone-io/go-pinecone/internal/gen/data" + "github.com/pinecone-io/go-pinecone/internal/utils" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/types/known/structpb" - "github.com/google/uuid" - "github.com/pinecone-io/go-pinecone/internal/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - "google.golang.org/grpc" - "google.golang.org/protobuf/types/known/structpb" ) -// Integration tests: -type IndexConnectionTestsIntegration struct { - suite.Suite - host string - dimension int32 - apiKey string - indexType string - idxConn *IndexConnection - sourceTag string - idxConnSourceTag *IndexConnection - vectorIds []string -} - -func TestIntegrationIndexConnection(t *testing.T) { - apiKey := os.Getenv("PINECONE_API_KEY") - assert.NotEmptyf(t, apiKey, "PINECONE_API_KEY env variable not set") - - client, err := NewClient(NewClientParams{ApiKey: apiKey}) - if err != nil { - t.FailNow() - } - - podIndexName := os.Getenv("TEST_PODS_INDEX_NAME") - assert.NotEmptyf(t, podIndexName, "TEST_PODS_INDEX_NAME env variable not set") - - podIdx, err := client.DescribeIndex(context.Background(), podIndexName) - if err != nil { - t.FailNow() - } - - podTestSuite := new(IndexConnectionTestsIntegration) - podTestSuite.indexType = "pod" - podTestSuite.host = podIdx.Host - podTestSuite.dimension = podIdx.Dimension - podTestSuite.apiKey = apiKey - - serverlessIndexName := os.Getenv("TEST_SERVERLESS_INDEX_NAME") - assert.NotEmptyf(t, serverlessIndexName, "TEST_SERVERLESS_INDEX_NAME env variable not set") - - serverlessIdx, err := client.DescribeIndex(context.Background(), serverlessIndexName) - if err != nil { - t.FailNow() - } - - serverlessTestSuite := new(IndexConnectionTestsIntegration) - serverlessTestSuite.indexType = "serverless" - serverlessTestSuite.host = serverlessIdx.Host - serverlessTestSuite.dimension = serverlessIdx.Dimension - serverlessTestSuite.apiKey = apiKey - - suite.Run(t, podTestSuite) - suite.Run(t, serverlessTestSuite) -} - -func (ts *IndexConnectionTestsIntegration) SetupSuite() { - assert.NotEmptyf(ts.T(), ts.host, "HOST env variable not set") - assert.NotEmptyf(ts.T(), ts.apiKey, "API_KEY env variable not set") - additionalMetadata := map[string]string{"api-key": ts.apiKey} - - namespace, err := uuid.NewV7() - assert.NoError(ts.T(), err) - - idxConn, err := newIndexConnection(newIndexParameters{ - additionalMetadata: additionalMetadata, - host: ts.host, - namespace: namespace.String(), - sourceTag: ""}) - assert.NoError(ts.T(), err) - ts.idxConn = idxConn - - ts.sourceTag = "test_source_tag" - idxConnSourceTag, err := newIndexConnection(newIndexParameters{ - additionalMetadata: additionalMetadata, - host: ts.host, - namespace: namespace.String(), - sourceTag: ts.sourceTag}) - assert.NoError(ts.T(), err) - ts.idxConnSourceTag = idxConnSourceTag - - ts.loadData() -} - -func (ts *IndexConnectionTestsIntegration) TearDownSuite() { - ts.truncateData() - - err := ts.idxConn.Close() - assert.NoError(ts.T(), err) - - err = ts.idxConnSourceTag.Close() - assert.NoError(ts.T(), err) -} - -func (ts *IndexConnectionTestsIntegration) TestNewIndexConnection() { +// Integration tests +func (ts *IntegrationTests) TestNewIndexConnection() { apiKey := "test-api-key" namespace := "" sourceTag := "" @@ -131,7 +39,7 @@ func (ts *IndexConnectionTestsIntegration) TestNewIndexConnection() { require.NotNil(ts.T(), idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") } -func (ts *IndexConnectionTestsIntegration) TestNewIndexConnectionNamespace() { +func (ts *IntegrationTests) TestNewIndexConnectionNamespace() { apiKey := "test-api-key" namespace := "test-namespace" sourceTag := "test-source-tag" @@ -151,21 +59,21 @@ func (ts *IndexConnectionTestsIntegration) TestNewIndexConnectionNamespace() { require.NotNil(ts.T(), idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") } -func (ts *IndexConnectionTestsIntegration) TestFetchVectors() { +func (ts *IntegrationTests) TestFetchVectors() { ctx := context.Background() res, err := ts.idxConn.FetchVectors(ctx, ts.vectorIds) assert.NoError(ts.T(), err) assert.NotNil(ts.T(), res) } -func (ts *IndexConnectionTestsIntegration) TestFetchVectorsSourceTag() { +func (ts *IntegrationTests) TestFetchVectorsSourceTag() { ctx := context.Background() res, err := ts.idxConnSourceTag.FetchVectors(ctx, ts.vectorIds) assert.NoError(ts.T(), err) assert.NotNil(ts.T(), res) } -func (ts *IndexConnectionTestsIntegration) TestQueryByVector() { +func (ts *IntegrationTests) TestQueryByVector() { vec := make([]float32, ts.dimension) for i := range vec { vec[i] = 0.01 @@ -182,7 +90,7 @@ func (ts *IndexConnectionTestsIntegration) TestQueryByVector() { assert.NotNil(ts.T(), res) } -func (ts *IndexConnectionTestsIntegration) TestQueryByVectorSourceTag() { +func (ts *IntegrationTests) TestQueryByVectorSourceTag() { vec := make([]float32, ts.dimension) for i := range vec { vec[i] = 0.01 @@ -199,7 +107,7 @@ func (ts *IndexConnectionTestsIntegration) TestQueryByVectorSourceTag() { assert.NotNil(ts.T(), res) } -func (ts *IndexConnectionTestsIntegration) TestQueryById() { +func (ts *IntegrationTests) TestQueryById() { req := &QueryByVectorIdRequest{ VectorId: ts.vectorIds[0], TopK: 5, @@ -211,7 +119,7 @@ func (ts *IndexConnectionTestsIntegration) TestQueryById() { assert.NotNil(ts.T(), res) } -func (ts *IndexConnectionTestsIntegration) TestQueryByIdSourceTag() { +func (ts *IntegrationTests) TestQueryByIdSourceTag() { req := &QueryByVectorIdRequest{ VectorId: ts.vectorIds[0], TopK: 5, @@ -223,15 +131,18 @@ func (ts *IndexConnectionTestsIntegration) TestQueryByIdSourceTag() { assert.NotNil(ts.T(), res) } -func (ts *IndexConnectionTestsIntegration) TestDeleteVectorsById() { +func (ts *IntegrationTests) TestDeleteVectorsById() { ctx := context.Background() err := ts.idxConn.DeleteVectorsById(ctx, ts.vectorIds) assert.NoError(ts.T(), err) - ts.loadData() //reload deleted data + _, err = ts.idxConn.UpsertVectors(ctx, createVectorsForUpsert()) + if err != nil { + log.Fatalf("Failed to upsert vectors in TestDeleteVectorsById test. Error: %v", err) + } } -func (ts *IndexConnectionTestsIntegration) TestDeleteVectorsByFilter() { +func (ts *IntegrationTests) TestDeleteVectorsByFilter() { metadataFilter := map[string]interface{}{ "genre": "classical", } @@ -250,32 +161,39 @@ func (ts *IndexConnectionTestsIntegration) TestDeleteVectorsByFilter() { assert.NoError(ts.T(), err) } - ts.loadData() //reload deleted data + _, err = ts.idxConn.UpsertVectors(ctx, createVectorsForUpsert()) + if err != nil { + log.Fatalf("Failed to upsert vectors in TestDeleteVectorsById test. Error: %v", err) + } } -func (ts *IndexConnectionTestsIntegration) TestDeleteAllVectorsInNamespace() { +func (ts *IntegrationTests) TestDeleteAllVectorsInNamespace() { ctx := context.Background() err := ts.idxConn.DeleteAllVectorsInNamespace(ctx) assert.NoError(ts.T(), err) - ts.loadData() //reload deleted data + _, err = ts.idxConn.UpsertVectors(ctx, createVectorsForUpsert()) + if err != nil { + log.Fatalf("Failed to upsert vectors in TestDeleteVectorsById test. Error: %v", err) + } + } -func (ts *IndexConnectionTestsIntegration) TestDescribeIndexStats() { +func (ts *IntegrationTests) TestDescribeIndexStats() { ctx := context.Background() res, err := ts.idxConn.DescribeIndexStats(ctx) assert.NoError(ts.T(), err) assert.NotNil(ts.T(), res) } -func (ts *IndexConnectionTestsIntegration) TestDescribeIndexStatsFiltered() { +func (ts *IntegrationTests) TestDescribeIndexStatsFiltered() { ctx := context.Background() res, err := ts.idxConn.DescribeIndexStatsFiltered(ctx, &MetadataFilter{}) assert.NoError(ts.T(), err) assert.NotNil(ts.T(), res) } -func (ts *IndexConnectionTestsIntegration) TestListVectors() { +func (ts *IntegrationTests) TestListVectors() { ts.T().Skip() req := &ListVectorsRequest{} @@ -285,63 +203,7 @@ func (ts *IndexConnectionTestsIntegration) TestListVectors() { assert.NotNil(ts.T(), res) } -func (ts *IndexConnectionTestsIntegration) loadData() { - vals := []float32{0.01, 0.02, 0.03, 0.04, 0.05} - vectors := make([]*Vector, len(vals)) - ts.vectorIds = make([]string, len(vals)) - - for i, val := range vals { - vec := make([]float32, ts.dimension) - for i := range vec { - vec[i] = val - } - - id := fmt.Sprintf("vec-%d", i+1) - ts.vectorIds[i] = id - - vectors[i] = &Vector{ - Id: id, - Values: vec, - } - } - - ctx := context.Background() - _, err := ts.idxConn.UpsertVectors(ctx, vectors) - assert.NoError(ts.T(), err) -} - -func (ts *IndexConnectionTestsIntegration) loadDataSourceTag() { - vals := []float32{0.01, 0.02, 0.03, 0.04, 0.05} - vectors := make([]*Vector, len(vals)) - ts.vectorIds = make([]string, len(vals)) - - for i, val := range vals { - vec := make([]float32, ts.dimension) - for i := range vec { - vec[i] = val - } - - id := fmt.Sprintf("vec-%d", i+1) - ts.vectorIds[i] = id - - vectors[i] = &Vector{ - Id: id, - Values: vec, - } - } - - ctx := context.Background() - _, err := ts.idxConnSourceTag.UpsertVectors(ctx, vectors) - assert.NoError(ts.T(), err) -} - -func (ts *IndexConnectionTestsIntegration) truncateData() { - ctx := context.Background() - err := ts.idxConn.DeleteAllVectorsInNamespace(ctx) - assert.NoError(ts.T(), err) -} - -func (ts *IndexConnectionTestsIntegration) TestMetadataAppliedToRequests() { +func (ts *IntegrationTests) TestMetadataAppliedToRequests() { apiKey := "test-api-key" namespace := "test-namespace" sourceTag := "test-source-tag" @@ -373,6 +235,87 @@ func (ts *IndexConnectionTestsIntegration) TestMetadataAppliedToRequests() { require.NotNil(ts.T(), stats) } +func (ts *IntegrationTests) TestUpdateVectorValues() { + ctx := context.Background() + + expectedVals := []float32{7.2, 7.2, 7.2, 7.2, 7.2} + err := ts.idxConn.UpdateVector(ctx, &UpdateVectorRequest{ + Id: ts.vectorIds[0], + Values: expectedVals, + }) + assert.NoError(ts.T(), err) + + time.Sleep(5 * time.Second) + + vector, err := ts.idxConn.FetchVectors(ctx, []string{ts.vectorIds[0]}) + if err != nil { + ts.FailNow(fmt.Sprintf("Failed to fetch vector: %v", err)) + } + actualVals := vector.Vectors[ts.vectorIds[0]].Values + + assert.ElementsMatch(ts.T(), expectedVals, actualVals, "Values do not match") +} + +func (ts *IntegrationTests) TestUpdateVectorMetadata() { + ctx := context.Background() + + expectedMetadata := map[string]interface{}{ + "genre": "death-metal", + } + expectedMetadataMap, err := structpb.NewStruct(expectedMetadata) + + err = ts.idxConn.UpdateVector(ctx, &UpdateVectorRequest{ + Id: ts.vectorIds[0], + Metadata: expectedMetadataMap, + }) + assert.NoError(ts.T(), err) + + time.Sleep(5 * time.Second) + + vector, err := ts.idxConn.FetchVectors(ctx, []string{ts.vectorIds[0]}) + if err != nil { + ts.FailNow(fmt.Sprintf("Failed to fetch vector: %v", err)) + } + + expectedGenre := expectedMetadataMap.Fields["genre"].GetStringValue() + actualGenre := vector.Vectors[ts.vectorIds[0]].Metadata.Fields["genre"].GetStringValue() + + assert.Equal(ts.T(), expectedGenre, actualGenre, "Metadata does not match") +} + +func (ts *IntegrationTests) TestUpdateVectorSparseValues() error { + ctx := context.Background() + + dims := int(ts.dimension) + indices := generateUint32Array(dims) + vals := generateFloat32Array(dims) + expectedSparseValues := SparseValues{ + Indices: indices, + Values: vals, + } + + fmt.Printf("Updating sparse values in host \"%s\"...\n", ts.host) + err := ts.idxConn.UpdateVector(ctx, &UpdateVectorRequest{ + Id: ts.vectorIds[0], + SparseValues: &expectedSparseValues, + }) + require.NoError(ts.T(), err) + + // Wait for updates to propagate + time.Sleep(5 * time.Second) + + // Fetch updated vector and verify sparse values + vector, err := ts.idxConn.FetchVectors(ctx, []string{ts.vectorIds[0]}) + if err != nil { + ts.FailNow(fmt.Sprintf("Failed to fetch vector: %v", err)) + } + actualSparseValues := vector.Vectors[ts.vectorIds[0]].SparseValues.Values + + assert.ElementsMatch(ts.T(), expectedSparseValues.Values, actualSparseValues, "Sparse values do not match") + + return nil +} + // Unit tests: func TestMarshalFetchVectorsResponseUnit(t *testing.T) { tests := []struct { @@ -1073,3 +1016,20 @@ func TestToPaginationToken(t *testing.T) { }) } } + +// Helper funcs +func generateFloat32Array(n int) []float32 { + array := make([]float32, n) + for i := 0; i < n; i++ { + array[i] = float32(i) + } + return array +} + +func generateUint32Array(n int) []uint32 { + array := make([]uint32, n) + for i := 0; i < n; i++ { + array[i] = uint32(i) + } + return array +} diff --git a/pinecone/suite_runner_test.go b/pinecone/suite_runner_test.go new file mode 100644 index 0000000..6e02dd2 --- /dev/null +++ b/pinecone/suite_runner_test.go @@ -0,0 +1,57 @@ +// This file is used to run all the test suites in the package pinecone +package pinecone + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func RunSuites(t *testing.T) { + apiKey, present := os.LookupEnv("PINECONE_API_KEY") + assert.True(t, present, "PINECONE_API_KEY env variable not set") + + client, err := NewClient(NewClientParams{ApiKey: apiKey}) + require.NotNil(t, client, "Client should not be nil after creation") + require.NoError(t, err) + + sourceTag := "test_source_tag" + clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: sourceTag}) + require.NoError(t, err) + + serverlessIdx := BuildServerlessTestIndex(client, "serverless-"+GenerateTestIndexName()) + podIdx := BuildPodTestIndex(client, "pods-"+GenerateTestIndexName()) + + podTestSuite := &IntegrationTests{ + apiKey: apiKey, + indexType: "pods", + host: podIdx.Host, + dimension: podIdx.Dimension, + client: client, + clientSourceTag: *clientSourceTag, + sourceTag: sourceTag, + idxName: podIdx.Name, + } + + serverlessTestSuite := &IntegrationTests{ + host: serverlessIdx.Host, + dimension: serverlessIdx.Dimension, + apiKey: apiKey, + indexType: "serverless", + client: client, + clientSourceTag: *clientSourceTag, + sourceTag: sourceTag, + idxName: serverlessIdx.Name, + } + + suite.Run(t, podTestSuite) + suite.Run(t, serverlessTestSuite) + +} + +func TestRunSuites(t *testing.T) { + RunSuites(t) +} diff --git a/pinecone/test_suite.go b/pinecone/test_suite.go new file mode 100644 index 0000000..9fbcf96 --- /dev/null +++ b/pinecone/test_suite.go @@ -0,0 +1,205 @@ +package pinecone + +import ( + "context" + "fmt" + "log" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type IntegrationTests struct { + suite.Suite + apiKey string + client *Client + host string + dimension int32 + indexType string + vectorIds []string + idxName string + idxConn *IndexConnection + sourceTag string + clientSourceTag Client + idxConnSourceTag *IndexConnection +} + +func (ts *IntegrationTests) SetupSuite() { + ctx := context.Background() + + additionalMetadata := map[string]string{"api-key": ts.apiKey} + + namespace, err := uuid.NewUUID() + require.NoError(ts.T(), err) + + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace.String(), + sourceTag: ""}) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), idxConn, "Failed to create idxConn") + + ts.idxConn = idxConn + + // Deterministically create vectors + vectors := createVectorsForUpsert() + + // Set vector IDs + vectorIds := make([]string, len(vectors)) + for i, v := range vectors { + vectorIds[i] = v.Id + } + ts.vectorIds = vectorIds + + // Upsert vectors + err = upsertVectors(ts, ctx, vectors) + if err != nil { + log.Fatalf("Failed to upsert vectors in SetupSuite: %v", err) + } + + ts.sourceTag = "test_source_tag" + idxConnSourceTag, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace.String(), + sourceTag: ts.sourceTag}) + require.NoError(ts.T(), err) + ts.idxConnSourceTag = idxConnSourceTag + + fmt.Printf("\n %s set up suite completed successfully\n", ts.indexType) +} + +func (ts *IntegrationTests) TearDownSuite() { + ctx := context.Background() + + // Delete test indexes + err := ts.client.DeleteIndex(ctx, ts.idxName) + + err = ts.idxConn.Close() + require.NoError(ts.T(), err) + + err = ts.idxConnSourceTag.Close() + require.NoError(ts.T(), err) + fmt.Printf("\n %s setup suite torn down successfully\n", ts.indexType) +} + +// Helper funcs +func GenerateTestIndexName() string { + return fmt.Sprintf("index-%d", time.Now().UnixMilli()) +} + +func upsertVectors(ts *IntegrationTests, ctx context.Context, vectors []*Vector) error { + maxRetries := 12 + delay := 12 * time.Second + fmt.Printf("Attempting to upsert vectors into host \"%s\"...\n", ts.host) + for i := 0; i < maxRetries; i++ { + ready, err := GetIndexStatus(ts, ctx) + if err != nil { + fmt.Printf("Error getting index ready: %v\n", err) + return err + } + if ready { + upsertVectors, err := ts.idxConn.UpsertVectors(ctx, vectors) + require.NoError(ts.T(), err) + fmt.Printf("Upserted vectors: %v into host: %s\n", upsertVectors, ts.host) + break + } else { + time.Sleep(delay) + fmt.Printf("Host \"%s\" not ready for upserting yet, retrying... (%d/%d)\n", ts.host, i, maxRetries) + } + } + return nil +} + +func GetIndexStatus(ts *IntegrationTests, ctx context.Context) (bool, error) { + var desc *Index + var err error + maxRetries := 12 + delay := 12 * time.Second + for i := 0; i < maxRetries; i++ { + desc, err = ts.client.DescribeIndex(ctx, ts.idxName) + if err == nil { + break + } + if status.Code(err) == codes.Unknown { + fmt.Printf("Index \"%s\" not found, retrying... (%d/%d)\n", ts.idxName, i+1, maxRetries) + time.Sleep(delay) + } else { + fmt.Printf("Status code = %v\n", status.Code(err)) + return false, err + } + } + if err != nil { + return false, fmt.Errorf("failed to describe index \"%s\" after retries: %v", err, ts.idxName) + } + return desc.Status.Ready, nil +} + +func createVectorsForUpsert() []*Vector { + vectors := make([]*Vector, 5) + for i := 0; i < 5; i++ { + vectors[i] = &Vector{ + Id: fmt.Sprintf("vector-%d", i+1), + Values: []float32{float32(i), float32(i) + 0.1, float32(i) + 0.2, float32(i) + 0.3, float32(i) + 0.4}, + SparseValues: &SparseValues{ + Indices: []uint32{0, 1, 2, 3, 4}, + Values: []float32{float32(i), float32(i) + 0.1, float32(i) + 0.2, float32(i) + 0.3, float32(i) + 0.4}, + }, + Metadata: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "genre": {Kind: &structpb.Value_StringValue{StringValue: "classical"}}, + }, + }, + } + } + return vectors +} + +func BuildServerlessTestIndex(in *Client, idxName string) *Index { + ctx := context.Background() + + fmt.Printf("Creating Serverless index: %s\n", idxName) + serverlessIdx, err := in.CreateServerlessIndex(ctx, &CreateServerlessIndexRequest{ + Name: idxName, + Dimension: int32(setDimensionsForTestIndexes()), + Metric: Cosine, + Region: "us-east-1", + Cloud: "aws", + }) + if err != nil { + log.Fatalf("Failed to create Serverless index \"%s\" in integration test: %v", err, idxName) + } else { + fmt.Printf("Successfully created a new Serverless index: %s!\n", idxName) + } + return serverlessIdx +} + +func BuildPodTestIndex(in *Client, name string) *Index { + ctx := context.Background() + + fmt.Printf("Creating pod index: %s\n", name) + podIdx, err := in.CreatePodIndex(ctx, &CreatePodIndexRequest{ + Name: name, + Dimension: int32(setDimensionsForTestIndexes()), + Metric: Cosine, + Environment: "us-east-1-aws", + PodType: "p1", + }) + if err != nil { + log.Fatalf("Failed to create pod index in buildPodTestIndex test: %v", err) + } else { + fmt.Printf("Successfully created a new pod index: %s!\n", name) + } + return podIdx +} + +func setDimensionsForTestIndexes() uint32 { + return uint32(5) +}