From 7d6ed315813dcca69e2e087a75818a007e3a880e Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Mon, 25 Mar 2024 17:39:24 -0700 Subject: [PATCH 01/10] First pass at functional impl for source_tag --- internal/util/user_agent.go | 44 ++++++++++++++++++++++++++++++++++++ pinecone/client.go | 28 ++++++++++++++++++++--- pinecone/index_connection.go | 4 +++- 3 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 internal/util/user_agent.go diff --git a/internal/util/user_agent.go b/internal/util/user_agent.go new file mode 100644 index 0000000..058ff89 --- /dev/null +++ b/internal/util/user_agent.go @@ -0,0 +1,44 @@ +package util + +import ( + "fmt" + "regexp" + "strings" +) + +func BuildUserAgent(sourceTag string) string { + return buildUserAgent("go-client", sourceTag) +} + +func BuildUserAgentGRPC(sourceTag string) string { + return buildUserAgent("go-client[grpc]", sourceTag) +} + +func buildUserAgent(appName string, sourceTag string) string { + // need to set to actual current version + appVersion := "0.0.1" + + sourceTagInfo := "" + if sourceTag != "" { + sourceTagInfo = buildSourceTagField(sourceTag) + } + userAgent := fmt.Sprintf("%s/%s%s", appName, appVersion, sourceTagInfo) + return userAgent +} + +func buildSourceTagField(userAgent string) string { + // Lowercase + userAgent = strings.ToLower(userAgent) + + // Limit charset to [a-z0-9_ ] + re := regexp.MustCompile(`[^a-z0-9_ ]`) + userAgent = re.ReplaceAllString(userAgent, "") + + // Trim left/right whitespace + userAgent = strings.TrimSpace(userAgent) + + // Condense multiple spaces to one, and replace with underscore + userAgent = strings.Join(strings.Fields(userAgent), "_") + + return fmt.Sprintf("; source_tag=%s;", userAgent) +} diff --git a/pinecone/client.go b/pinecone/client.go index 09a6ce8..36f419d 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" + "github.com/pinecone-io/go-pinecone/internal/util" "io" "net/http" ) @@ -13,10 +14,26 @@ import ( type Client struct { apiKey string restClient *control.Client + sourceTag string +} + +type CustomHeader struct { + userAgent string +} + +func NewUserAgentProvider(userAgent string) *CustomHeader { + return &CustomHeader{userAgent: userAgent} +} + +func (s *CustomHeader) Intercept(ctx context.Context, req *http.Request) error { + req.Header.Set("User-Agent", s.userAgent) + return nil } type NewClientParams struct { ApiKey string + // optional fields + SourceTag string } func NewClient(in NewClientParams) (*Client, error) { @@ -25,12 +42,17 @@ func NewClient(in NewClientParams) (*Client, error) { return nil, err } - client, err := control.NewClient("https://api.pinecone.io", control.WithRequestEditorFn(apiKeyProvider.Intercept)) + userAgentProvider := NewUserAgentProvider(util.BuildUserAgent(in.SourceTag)) + + client, err := control.NewClient("https://api.pinecone.io", + control.WithRequestEditorFn(apiKeyProvider.Intercept), + control.WithRequestEditorFn(userAgentProvider.Intercept), + ) if err != nil { return nil, err } - c := Client{apiKey: in.ApiKey, restClient: client} + c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag} return &c, nil } @@ -39,7 +61,7 @@ func (c *Client) Index(host string) (*IndexConnection, error) { } func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnection, error) { - idx, err := newIndexConnection(c.apiKey, host, namespace) + idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag) if err != nil { return nil, err } diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index 6d1a3ab..8472969 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "github.com/pinecone-io/go-pinecone/internal/gen/data" + "github.com/pinecone-io/go-pinecone/internal/util" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" @@ -18,7 +19,7 @@ type IndexConnection struct { grpcConn *grpc.ClientConn } -func newIndexConnection(apiKey string, host string, namespace string) (*IndexConnection, error) { +func newIndexConnection(apiKey string, host string, namespace string, sourceTag string) (*IndexConnection, error) { config := &tls.Config{} target := fmt.Sprintf("%s:443", host) conn, err := grpc.Dial( @@ -26,6 +27,7 @@ func newIndexConnection(apiKey string, host string, namespace string) (*IndexCon grpc.WithTransportCredentials(credentials.NewTLS(config)), grpc.WithAuthority(target), grpc.WithBlock(), + grpc.WithUserAgent(util.BuildUserAgentGRPC(sourceTag)), ) if err != nil { From 6211bd9ee4862172a118ee3b48c9029db136b3aa Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Tue, 26 Mar 2024 16:32:07 -0700 Subject: [PATCH 02/10] Move user-agent header provider to its own package Also makes the provider generic to any header --- internal/provider/header.go | 20 +++++++++++++++++++ .../user_agent.go => useragent/useragent.go} | 2 +- pinecone/client.go | 20 ++++--------------- pinecone/index_connection.go | 4 ++-- 4 files changed, 27 insertions(+), 19 deletions(-) create mode 100644 internal/provider/header.go rename internal/{util/user_agent.go => useragent/useragent.go} (98%) diff --git a/internal/provider/header.go b/internal/provider/header.go new file mode 100644 index 0000000..7c880e7 --- /dev/null +++ b/internal/provider/header.go @@ -0,0 +1,20 @@ +package provider + +import ( + "context" + "net/http" +) + +type customHeader struct { + name string + value string +} + +func NewHeaderProvider(name string, value string) *customHeader { + return &customHeader{name: name, value: value} +} + +func (h *customHeader) Intercept(ctx context.Context, req *http.Request) error { + req.Header.Set(h.name, h.value) + return nil +} diff --git a/internal/util/user_agent.go b/internal/useragent/useragent.go similarity index 98% rename from internal/util/user_agent.go rename to internal/useragent/useragent.go index 058ff89..08771ab 100644 --- a/internal/util/user_agent.go +++ b/internal/useragent/useragent.go @@ -1,4 +1,4 @@ -package util +package useragent import ( "fmt" diff --git a/pinecone/client.go b/pinecone/client.go index 36f419d..9123298 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -6,7 +6,8 @@ import ( "fmt" "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" - "github.com/pinecone-io/go-pinecone/internal/util" + "github.com/pinecone-io/go-pinecone/internal/provider" + "github.com/pinecone-io/go-pinecone/internal/useragent" "io" "net/http" ) @@ -17,23 +18,10 @@ type Client struct { sourceTag string } -type CustomHeader struct { - userAgent string -} - -func NewUserAgentProvider(userAgent string) *CustomHeader { - return &CustomHeader{userAgent: userAgent} -} - -func (s *CustomHeader) Intercept(ctx context.Context, req *http.Request) error { - req.Header.Set("User-Agent", s.userAgent) - return nil -} - type NewClientParams struct { ApiKey string // optional fields - SourceTag string + SourceTag string } func NewClient(in NewClientParams) (*Client, error) { @@ -42,7 +30,7 @@ func NewClient(in NewClientParams) (*Client, error) { return nil, err } - userAgentProvider := NewUserAgentProvider(util.BuildUserAgent(in.SourceTag)) + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) client, err := control.NewClient("https://api.pinecone.io", control.WithRequestEditorFn(apiKeyProvider.Intercept), diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index 8472969..59fc5c8 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -5,7 +5,7 @@ import ( "crypto/tls" "fmt" "github.com/pinecone-io/go-pinecone/internal/gen/data" - "github.com/pinecone-io/go-pinecone/internal/util" + "github.com/pinecone-io/go-pinecone/internal/useragent" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" @@ -27,7 +27,7 @@ func newIndexConnection(apiKey string, host string, namespace string, sourceTag grpc.WithTransportCredentials(credentials.NewTLS(config)), grpc.WithAuthority(target), grpc.WithBlock(), - grpc.WithUserAgent(util.BuildUserAgentGRPC(sourceTag)), + grpc.WithUserAgent(useragent.BuildUserAgentGRPC(sourceTag)), ) if err != nil { From 715e5099c3e145817fb12258f073a68a97d11345 Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Tue, 26 Mar 2024 16:51:54 -0700 Subject: [PATCH 03/10] Avoid regexp for simple string matching --- internal/useragent/useragent.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/internal/useragent/useragent.go b/internal/useragent/useragent.go index 08771ab..1d86ab5 100644 --- a/internal/useragent/useragent.go +++ b/internal/useragent/useragent.go @@ -2,7 +2,6 @@ package useragent import ( "fmt" - "regexp" "strings" ) @@ -31,8 +30,13 @@ func buildSourceTagField(userAgent string) string { userAgent = strings.ToLower(userAgent) // Limit charset to [a-z0-9_ ] - re := regexp.MustCompile(`[^a-z0-9_ ]`) - userAgent = re.ReplaceAllString(userAgent, "") + var strBldr strings.Builder + for _, char := range userAgent { + if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '_' || char == ' ' { + strBldr.WriteRune(char) + } + } + userAgent = strBldr.String() // Trim left/right whitespace userAgent = strings.TrimSpace(userAgent) From 145b9fee0890cb8f17325e8b0e982f1ec057549e Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Tue, 26 Mar 2024 16:55:03 -0700 Subject: [PATCH 04/10] Fix existing test --- pinecone/index_connection_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 6c1a2df..0bc5a4b 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -59,7 +59,7 @@ func (ts *IndexConnectionTests) SetupSuite() { namespace, err := uuid.NewV7() assert.NoError(ts.T(), err) - idxConn, err := newIndexConnection(ts.apiKey, ts.host, namespace.String()) + idxConn, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), "") assert.NoError(ts.T(), err) ts.idxConn = idxConn From 491e12ddee6fd2a5c3dd9a87911120892586bfe1 Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Tue, 26 Mar 2024 17:25:14 -0700 Subject: [PATCH 05/10] Add tests for building useragent --- internal/useragent/useragent_test.go | 80 ++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 internal/useragent/useragent_test.go diff --git a/internal/useragent/useragent_test.go b/internal/useragent/useragent_test.go new file mode 100644 index 0000000..c70246a --- /dev/null +++ b/internal/useragent/useragent_test.go @@ -0,0 +1,80 @@ +package useragent + +import ( + "strings" + "testing" +) + +func TestBuildUserAgentNoSourceTag(t *testing.T) { + sourceTag := "" + expectedStartWith := "go-client/" + result := BuildUserAgent(sourceTag) + if !strings.HasPrefix(result, expectedStartWith) { + t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) + } + if strings.Contains(result, "source_tag") { + t.Errorf("BuildUserAgent(): expected user-agent to not contain 'source_tag', but got %s", result) + } +} + +func TestBuildUserAgentWithSourceTag(t *testing.T) { + sourceTag := "my_source_tag" + expectedStartWith := "go-client/" + result := BuildUserAgent(sourceTag) + if !strings.HasPrefix(result, expectedStartWith) { + t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) + } + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(): expected user-agent to contain 'source_tag=my_source_tag', but got %s", result) + } +} + +func TestBuildUserAgentGRPCNoSourceTag(t *testing.T) { + sourceTag := "" + expectedStartWith := "go-client[grpc]/" + result := BuildUserAgentGRPC(sourceTag) + if !strings.HasPrefix(result, expectedStartWith) { + t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) + } + if strings.Contains(result, "source_tag") { + t.Errorf("BuildUserAgent(): expected user-agent to not contain 'source_tag', but got %s", result) + } +} + +func TestBuildUserAgentGRPCWithSourceTag(t *testing.T) { + sourceTag := "my_source_tag" + expectedStartWith := "go-client[grpc]/" + result := BuildUserAgentGRPC(sourceTag) + if !strings.HasPrefix(result, expectedStartWith) { + t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) + } + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(): expected user-agent to contain 'source_tag=my_source_tag', but got %s", result) + } +} + +func TestBuildUserAgentSourceTagIsNormalized(t *testing.T) { + sourceTag := "my source tag!!!!" + result := BuildUserAgent(sourceTag) + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag', but got %s", sourceTag, result) + } + + sourceTag = "My Source Tag" + result = BuildUserAgent(sourceTag) + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag', but got %s", sourceTag, result) + } + + sourceTag = " My Source Tag 123 " + result = BuildUserAgent(sourceTag) + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag_123', but got %s", sourceTag, result) + } + + sourceTag = " My Source Tag 123 #### !! " + result = BuildUserAgent(sourceTag) + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag_123', but got %s", sourceTag, result) + } +} From 0150639065d86d772ec6aecacdf925d6bbf93783 Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Wed, 27 Mar 2024 06:32:44 -0700 Subject: [PATCH 06/10] Add unit test for header provider --- internal/provider/header_test.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 internal/provider/header_test.go diff --git a/internal/provider/header_test.go b/internal/provider/header_test.go new file mode 100644 index 0000000..2802afb --- /dev/null +++ b/internal/provider/header_test.go @@ -0,0 +1,32 @@ +package provider + +import ( + "context" + "net/http" + "testing" +) + +func TestCustomHeaderIntercept(t *testing.T) { + expectedName := "X-Custom-Header" + expectedValue := "Custom-Value" + header := NewHeaderProvider(expectedName, expectedValue) + + req, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatalf("Failed to create HTTP request: %v", err) + } + + ctx := context.Background() + + // Call the Intercept method (method being tested) + err = header.Intercept(ctx, req) + if err != nil { + t.Errorf("Intercept failed: %v", err) + } + + // Verify that the custom header is set correctly + if req.Header.Get(expectedName) != expectedValue { + t.Errorf("Expected header '%s' to have value '%s', got '%s'", expectedName, expectedValue, + req.Header.Get(expectedName)) + } +} From 8a5e4681e3ddf4c28879b744e3c81f1f3c7aabe0 Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Wed, 27 Mar 2024 08:39:12 -0700 Subject: [PATCH 07/10] Set version in useragent --- internal/useragent/useragent.go | 8 ++++++-- internal/useragent/useragent_test.go | 9 +++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/internal/useragent/useragent.go b/internal/useragent/useragent.go index 1d86ab5..e52abec 100644 --- a/internal/useragent/useragent.go +++ b/internal/useragent/useragent.go @@ -5,6 +5,11 @@ import ( "strings" ) +func getPackageVersion() string { + // update at release time + return "v0.5.0-pre" +} + func BuildUserAgent(sourceTag string) string { return buildUserAgent("go-client", sourceTag) } @@ -14,8 +19,7 @@ func BuildUserAgentGRPC(sourceTag string) string { } func buildUserAgent(appName string, sourceTag string) string { - // need to set to actual current version - appVersion := "0.0.1" + appVersion := getPackageVersion() sourceTagInfo := "" if sourceTag != "" { diff --git a/internal/useragent/useragent_test.go b/internal/useragent/useragent_test.go index c70246a..16bf7f0 100644 --- a/internal/useragent/useragent_test.go +++ b/internal/useragent/useragent_test.go @@ -1,13 +1,14 @@ package useragent import ( + "fmt" "strings" "testing" ) func TestBuildUserAgentNoSourceTag(t *testing.T) { sourceTag := "" - expectedStartWith := "go-client/" + expectedStartWith := fmt.Sprintf("go-client/%s", getPackageVersion()) result := BuildUserAgent(sourceTag) if !strings.HasPrefix(result, expectedStartWith) { t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) @@ -19,7 +20,7 @@ func TestBuildUserAgentNoSourceTag(t *testing.T) { func TestBuildUserAgentWithSourceTag(t *testing.T) { sourceTag := "my_source_tag" - expectedStartWith := "go-client/" + expectedStartWith := fmt.Sprintf("go-client/%s", getPackageVersion()) result := BuildUserAgent(sourceTag) if !strings.HasPrefix(result, expectedStartWith) { t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) @@ -31,7 +32,7 @@ func TestBuildUserAgentWithSourceTag(t *testing.T) { func TestBuildUserAgentGRPCNoSourceTag(t *testing.T) { sourceTag := "" - expectedStartWith := "go-client[grpc]/" + expectedStartWith := fmt.Sprintf("go-client[grpc]/%s", getPackageVersion()) result := BuildUserAgentGRPC(sourceTag) if !strings.HasPrefix(result, expectedStartWith) { t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) @@ -43,7 +44,7 @@ func TestBuildUserAgentGRPCNoSourceTag(t *testing.T) { func TestBuildUserAgentGRPCWithSourceTag(t *testing.T) { sourceTag := "my_source_tag" - expectedStartWith := "go-client[grpc]/" + expectedStartWith := fmt.Sprintf("go-client[grpc]/%s", getPackageVersion()) result := BuildUserAgentGRPC(sourceTag) if !strings.HasPrefix(result, expectedStartWith) { t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) From b0ceea9d3324491ade5d5d09dfca80799408ebc0 Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Wed, 27 Mar 2024 09:19:54 -0700 Subject: [PATCH 08/10] Formatting --- pinecone/client.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 9123298..d701480 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -19,9 +19,8 @@ type Client struct { } type NewClientParams struct { - ApiKey string - // optional fields - SourceTag string + ApiKey string + SourceTag string // optional } func NewClient(in NewClientParams) (*Client, error) { From a8ce2fda4a84c8aa9d3574f0fd2d7eabdae686aa Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Wed, 27 Mar 2024 09:52:32 -0700 Subject: [PATCH 09/10] Add client/index tests when including a source_tag --- pinecone/client_test.go | 27 +++++++++++ pinecone/index_connection_test.go | 78 +++++++++++++++++++++++++++++-- 2 files changed, 100 insertions(+), 5 deletions(-) diff --git a/pinecone/client_test.go b/pinecone/client_test.go index a458f90..941967c 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -12,6 +12,8 @@ import ( type ClientTests struct { suite.Suite client Client + clientSourceTag Client + sourceTag string podIndex string serverlessIndex string } @@ -36,6 +38,13 @@ func (ts *ClientTests) SetupSuite() { } ts.client = *client + ts.sourceTag = "test_source_tag" + clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: ts.sourceTag}) + if err != nil { + ts.FailNow(err.Error()) + } + ts.clientSourceTag = *clientSourceTag + // this will clean up the project deleting all indexes and collections that are // named a UUID. Generally not needed as all tests are cleaning up after themselves // Left here as a convenience during active development. @@ -49,6 +58,12 @@ func (ts *ClientTests) TestListIndexes() { require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist") } +func (ts *ClientTests) TestListIndexesSourceTag() { + indexes, err := ts.clientSourceTag.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist") +} + func (ts *ClientTests) TestCreatePodIndex() { name := uuid.New().String() @@ -93,12 +108,24 @@ func (ts *ClientTests) TestDescribeServerlessIndex() { require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match") } +func (ts *ClientTests) TestDescribeServerlessIndexSourceTag() { + index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.serverlessIndex) + require.NoError(ts.T(), err) + require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match") +} + func (ts *ClientTests) TestDescribePodIndex() { index, err := ts.client.DescribeIndex(context.Background(), ts.podIndex) require.NoError(ts.T(), err) require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match") } +func (ts *ClientTests) TestDescribePodIndexSourceTag() { + index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.podIndex) + require.NoError(ts.T(), err) + require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match") +} + func (ts *ClientTests) TestListCollections() { ctx := context.Background() diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 0bc5a4b..cce6bf1 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -12,11 +12,13 @@ import ( type IndexConnectionTests struct { suite.Suite - host string - dimension int32 - apiKey string - idxConn *IndexConnection - vectorIds []string + host string + dimension int32 + apiKey string + idxConn *IndexConnection + sourceTag string + idxConnSourceTag *IndexConnection + vectorIds []string } // Runs the test suite with `go test` @@ -63,6 +65,11 @@ func (ts *IndexConnectionTests) SetupSuite() { assert.NoError(ts.T(), err) ts.idxConn = idxConn + ts.sourceTag = "test_source_tag" + idxConnSourceTag, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), ts.sourceTag) + assert.NoError(ts.T(), err) + ts.idxConnSourceTag = idxConnSourceTag + ts.loadData() } @@ -80,6 +87,13 @@ func (ts *IndexConnectionTests) TestFetchVectors() { assert.NotNil(ts.T(), res) } +func (ts *IndexConnectionTests) TestFetchVectorsSourceTag() { + ctx := context.Background() + res, err := ts.idxConnSourceTag.FetchVectors(&ctx, ts.vectorIds) + assert.NoError(ts.T(), err) + assert.NotNil(ts.T(), res) +} + func (ts *IndexConnectionTests) TestQueryByVector() { vec := make([]float32, ts.dimension) for i := range vec { @@ -97,6 +111,23 @@ func (ts *IndexConnectionTests) TestQueryByVector() { assert.NotNil(ts.T(), res) } +func (ts *IndexConnectionTests) TestQueryByVectorSourceTag() { + vec := make([]float32, ts.dimension) + for i := range vec { + vec[i] = 0.01 + } + + req := &QueryByVectorValuesRequest{ + Vector: vec, + TopK: 5, + } + + ctx := context.Background() + res, err := ts.idxConnSourceTag.QueryByVectorValues(&ctx, req) + assert.NoError(ts.T(), err) + assert.NotNil(ts.T(), res) +} + func (ts *IndexConnectionTests) TestQueryById() { req := &QueryByVectorIdRequest{ VectorId: ts.vectorIds[0], @@ -109,6 +140,18 @@ func (ts *IndexConnectionTests) TestQueryById() { assert.NotNil(ts.T(), res) } +func (ts *IndexConnectionTests) TestQueryByIdSourceTag() { + req := &QueryByVectorIdRequest{ + VectorId: ts.vectorIds[0], + TopK: 5, + } + + ctx := context.Background() + res, err := ts.idxConnSourceTag.QueryByVectorId(&ctx, req) + assert.NoError(ts.T(), err) + assert.NotNil(ts.T(), res) +} + func (ts *IndexConnectionTests) TestDeleteVectorsById() { ctx := context.Background() err := ts.idxConn.DeleteVectorsById(&ctx, ts.vectorIds) @@ -182,6 +225,31 @@ func (ts *IndexConnectionTests) loadData() { assert.NoError(ts.T(), err) } +func (ts *IndexConnectionTests) loadDataSourceTag() { + vals := []float32{0.01, 0.02, 0.03, 0.04, 0.05} + vectors := make([]*Vector, len(vals)) + ts.vectorIds = make([]string, len(vals)) + + for i, val := range vals { + vec := make([]float32, ts.dimension) + for i := range vec { + vec[i] = val + } + + id := fmt.Sprintf("vec-%d", i+1) + ts.vectorIds[i] = id + + vectors[i] = &Vector{ + Id: id, + Values: vec, + } + } + + ctx := context.Background() + _, err := ts.idxConnSourceTag.UpsertVectors(&ctx, vectors) + assert.NoError(ts.T(), err) +} + func (ts *IndexConnectionTests) truncateData() { ctx := context.Background() err := ts.idxConn.DeleteAllVectorsInNamespace(&ctx) From b075e9b047e3ca862161542843f6de09728fece8 Mon Sep 17 00:00:00 2001 From: Silas Smith Date: Wed, 27 Mar 2024 11:09:13 -0700 Subject: [PATCH 10/10] Add tests for client/idxConnection creation --- pinecone/client_test.go | 41 ++++++++++++++++++++++++++++-- pinecone/index_connection_test.go | 42 +++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 941967c..446d2d3 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -2,11 +2,13 @@ package pinecone import ( "context" + "fmt" + "os" + "testing" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "os" - "testing" ) type ClientTests struct { @@ -52,6 +54,41 @@ func (ts *ClientTests) SetupSuite() { } +func (ts *ClientTests) TestNewClientParamsSet() { + apiKey := "test-api-key" + client, err := NewClient(NewClientParams{ApiKey: apiKey}) + if err != nil { + ts.FailNow(err.Error()) + } + if client.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) + } + if client.sourceTag != "" { + ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag)) + } + if len(client.restClient.RequestEditors) != 2 { + ts.FailNow("Expected 2 request editors on client") + } +} + +func (ts *ClientTests) TestNewClientParamsSetSourceTag() { + apiKey := "test-api-key" + sourceTag := "test-source-tag" + client, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: sourceTag}) + if err != nil { + ts.FailNow(err.Error()) + } + if client.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) + } + if client.sourceTag != sourceTag { + ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)) + } + if len(client.restClient.RequestEditors) != 2 { + ts.FailNow("Expected 2 request editors on client") + } +} + func (ts *ClientTests) TestListIndexes() { indexes, err := ts.client.ListIndexes(context.Background()) require.NoError(ts.T(), err) diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index cce6bf1..7449160 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -80,6 +80,48 @@ func (ts *IndexConnectionTests) TearDownSuite() { assert.NoError(ts.T(), err) } +func (ts *IndexConnectionTests) TestNewIndexConnection() { + apiKey := "test-api-key" + namespace := "" + sourceTag := "" + idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag) + assert.NoError(ts.T(), err) + + if idxConn.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) + } + if idxConn.Namespace != "" { + ts.FailNow(fmt.Sprintf("Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace)) + } + if idxConn.dataClient == nil { + ts.FailNow("Expected idxConn to have non-nil dataClient") + } + if idxConn.grpcConn == nil { + ts.FailNow("Expected idxConn to have non-nil grpcConn") + } +} + +func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { + apiKey := "test-api-key" + namespace := "test-namespace" + sourceTag := "test-source-tag" + idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag) + assert.NoError(ts.T(), err) + + if idxConn.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) + } + if idxConn.Namespace != namespace { + ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) + } + if idxConn.dataClient == nil { + ts.FailNow("Expected idxConn to have non-nil dataClient") + } + if idxConn.grpcConn == nil { + ts.FailNow("Expected idxConn to have non-nil grpcConn") + } +} + func (ts *IndexConnectionTests) TestFetchVectors() { ctx := context.Background() res, err := ts.idxConn.FetchVectors(&ctx, ts.vectorIds)