diff --git a/internal/provider/header.go b/internal/provider/header.go new file mode 100644 index 0000000..7c880e7 --- /dev/null +++ b/internal/provider/header.go @@ -0,0 +1,20 @@ +package provider + +import ( + "context" + "net/http" +) + +type customHeader struct { + name string + value string +} + +func NewHeaderProvider(name string, value string) *customHeader { + return &customHeader{name: name, value: value} +} + +func (h *customHeader) Intercept(ctx context.Context, req *http.Request) error { + req.Header.Set(h.name, h.value) + return nil +} diff --git a/internal/provider/header_test.go b/internal/provider/header_test.go new file mode 100644 index 0000000..2802afb --- /dev/null +++ b/internal/provider/header_test.go @@ -0,0 +1,32 @@ +package provider + +import ( + "context" + "net/http" + "testing" +) + +func TestCustomHeaderIntercept(t *testing.T) { + expectedName := "X-Custom-Header" + expectedValue := "Custom-Value" + header := NewHeaderProvider(expectedName, expectedValue) + + req, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatalf("Failed to create HTTP request: %v", err) + } + + ctx := context.Background() + + // Call the Intercept method (method being tested) + err = header.Intercept(ctx, req) + if err != nil { + t.Errorf("Intercept failed: %v", err) + } + + // Verify that the custom header is set correctly + if req.Header.Get(expectedName) != expectedValue { + t.Errorf("Expected header '%s' to have value '%s', got '%s'", expectedName, expectedValue, + req.Header.Get(expectedName)) + } +} diff --git a/internal/useragent/useragent.go b/internal/useragent/useragent.go new file mode 100644 index 0000000..e52abec --- /dev/null +++ b/internal/useragent/useragent.go @@ -0,0 +1,52 @@ +package useragent + +import ( + "fmt" + "strings" +) + +func getPackageVersion() string { + // update at release time + return "v0.5.0-pre" +} + +func BuildUserAgent(sourceTag string) string { + return buildUserAgent("go-client", sourceTag) +} + +func BuildUserAgentGRPC(sourceTag string) string { + return buildUserAgent("go-client[grpc]", sourceTag) +} + +func buildUserAgent(appName string, sourceTag string) string { + appVersion := getPackageVersion() + + sourceTagInfo := "" + if sourceTag != "" { + sourceTagInfo = buildSourceTagField(sourceTag) + } + userAgent := fmt.Sprintf("%s/%s%s", appName, appVersion, sourceTagInfo) + return userAgent +} + +func buildSourceTagField(userAgent string) string { + // Lowercase + userAgent = strings.ToLower(userAgent) + + // Limit charset to [a-z0-9_ ] + var strBldr strings.Builder + for _, char := range userAgent { + if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '_' || char == ' ' { + strBldr.WriteRune(char) + } + } + userAgent = strBldr.String() + + // Trim left/right whitespace + userAgent = strings.TrimSpace(userAgent) + + // Condense multiple spaces to one, and replace with underscore + userAgent = strings.Join(strings.Fields(userAgent), "_") + + return fmt.Sprintf("; source_tag=%s;", userAgent) +} diff --git a/internal/useragent/useragent_test.go b/internal/useragent/useragent_test.go new file mode 100644 index 0000000..16bf7f0 --- /dev/null +++ b/internal/useragent/useragent_test.go @@ -0,0 +1,81 @@ +package useragent + +import ( + "fmt" + "strings" + "testing" +) + +func TestBuildUserAgentNoSourceTag(t *testing.T) { + sourceTag := "" + expectedStartWith := fmt.Sprintf("go-client/%s", getPackageVersion()) + result := BuildUserAgent(sourceTag) + if !strings.HasPrefix(result, expectedStartWith) { + t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) + } + if strings.Contains(result, "source_tag") { + t.Errorf("BuildUserAgent(): expected user-agent to not contain 'source_tag', but got %s", result) + } +} + +func TestBuildUserAgentWithSourceTag(t *testing.T) { + sourceTag := "my_source_tag" + expectedStartWith := fmt.Sprintf("go-client/%s", getPackageVersion()) + result := BuildUserAgent(sourceTag) + if !strings.HasPrefix(result, expectedStartWith) { + t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) + } + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(): expected user-agent to contain 'source_tag=my_source_tag', but got %s", result) + } +} + +func TestBuildUserAgentGRPCNoSourceTag(t *testing.T) { + sourceTag := "" + expectedStartWith := fmt.Sprintf("go-client[grpc]/%s", getPackageVersion()) + result := BuildUserAgentGRPC(sourceTag) + if !strings.HasPrefix(result, expectedStartWith) { + t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) + } + if strings.Contains(result, "source_tag") { + t.Errorf("BuildUserAgent(): expected user-agent to not contain 'source_tag', but got %s", result) + } +} + +func TestBuildUserAgentGRPCWithSourceTag(t *testing.T) { + sourceTag := "my_source_tag" + expectedStartWith := fmt.Sprintf("go-client[grpc]/%s", getPackageVersion()) + result := BuildUserAgentGRPC(sourceTag) + if !strings.HasPrefix(result, expectedStartWith) { + t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result) + } + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(): expected user-agent to contain 'source_tag=my_source_tag', but got %s", result) + } +} + +func TestBuildUserAgentSourceTagIsNormalized(t *testing.T) { + sourceTag := "my source tag!!!!" + result := BuildUserAgent(sourceTag) + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag', but got %s", sourceTag, result) + } + + sourceTag = "My Source Tag" + result = BuildUserAgent(sourceTag) + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag', but got %s", sourceTag, result) + } + + sourceTag = " My Source Tag 123 " + result = BuildUserAgent(sourceTag) + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag_123', but got %s", sourceTag, result) + } + + sourceTag = " My Source Tag 123 #### !! " + result = BuildUserAgent(sourceTag) + if !strings.Contains(result, "source_tag=my_source_tag") { + t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag_123', but got %s", sourceTag, result) + } +} diff --git a/pinecone/client.go b/pinecone/client.go index 09a6ce8..d701480 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -6,6 +6,8 @@ import ( "fmt" "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" + "github.com/pinecone-io/go-pinecone/internal/provider" + "github.com/pinecone-io/go-pinecone/internal/useragent" "io" "net/http" ) @@ -13,10 +15,12 @@ import ( type Client struct { apiKey string restClient *control.Client + sourceTag string } type NewClientParams struct { - ApiKey string + ApiKey string + SourceTag string // optional } func NewClient(in NewClientParams) (*Client, error) { @@ -25,12 +29,17 @@ func NewClient(in NewClientParams) (*Client, error) { return nil, err } - client, err := control.NewClient("https://api.pinecone.io", control.WithRequestEditorFn(apiKeyProvider.Intercept)) + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + + client, err := control.NewClient("https://api.pinecone.io", + control.WithRequestEditorFn(apiKeyProvider.Intercept), + control.WithRequestEditorFn(userAgentProvider.Intercept), + ) if err != nil { return nil, err } - c := Client{apiKey: in.ApiKey, restClient: client} + c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag} return &c, nil } @@ -39,7 +48,7 @@ func (c *Client) Index(host string) (*IndexConnection, error) { } func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnection, error) { - idx, err := newIndexConnection(c.apiKey, host, namespace) + idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag) if err != nil { return nil, err } diff --git a/pinecone/client_test.go b/pinecone/client_test.go index a458f90..446d2d3 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -2,16 +2,20 @@ package pinecone import ( "context" + "fmt" + "os" + "testing" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "os" - "testing" ) type ClientTests struct { suite.Suite client Client + clientSourceTag Client + sourceTag string podIndex string serverlessIndex string } @@ -36,6 +40,13 @@ func (ts *ClientTests) SetupSuite() { } ts.client = *client + ts.sourceTag = "test_source_tag" + clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: ts.sourceTag}) + if err != nil { + ts.FailNow(err.Error()) + } + ts.clientSourceTag = *clientSourceTag + // this will clean up the project deleting all indexes and collections that are // named a UUID. Generally not needed as all tests are cleaning up after themselves // Left here as a convenience during active development. @@ -43,12 +54,53 @@ func (ts *ClientTests) SetupSuite() { } +func (ts *ClientTests) TestNewClientParamsSet() { + apiKey := "test-api-key" + client, err := NewClient(NewClientParams{ApiKey: apiKey}) + if err != nil { + ts.FailNow(err.Error()) + } + if client.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) + } + if client.sourceTag != "" { + ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag)) + } + if len(client.restClient.RequestEditors) != 2 { + ts.FailNow("Expected 2 request editors on client") + } +} + +func (ts *ClientTests) TestNewClientParamsSetSourceTag() { + apiKey := "test-api-key" + sourceTag := "test-source-tag" + client, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: sourceTag}) + if err != nil { + ts.FailNow(err.Error()) + } + if client.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) + } + if client.sourceTag != sourceTag { + ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)) + } + if len(client.restClient.RequestEditors) != 2 { + ts.FailNow("Expected 2 request editors on client") + } +} + func (ts *ClientTests) TestListIndexes() { indexes, err := ts.client.ListIndexes(context.Background()) require.NoError(ts.T(), err) require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist") } +func (ts *ClientTests) TestListIndexesSourceTag() { + indexes, err := ts.clientSourceTag.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist") +} + func (ts *ClientTests) TestCreatePodIndex() { name := uuid.New().String() @@ -93,12 +145,24 @@ func (ts *ClientTests) TestDescribeServerlessIndex() { require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match") } +func (ts *ClientTests) TestDescribeServerlessIndexSourceTag() { + index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.serverlessIndex) + require.NoError(ts.T(), err) + require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match") +} + func (ts *ClientTests) TestDescribePodIndex() { index, err := ts.client.DescribeIndex(context.Background(), ts.podIndex) require.NoError(ts.T(), err) require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match") } +func (ts *ClientTests) TestDescribePodIndexSourceTag() { + index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.podIndex) + require.NoError(ts.T(), err) + require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match") +} + func (ts *ClientTests) TestListCollections() { ctx := context.Background() diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index 6d1a3ab..59fc5c8 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "github.com/pinecone-io/go-pinecone/internal/gen/data" + "github.com/pinecone-io/go-pinecone/internal/useragent" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" @@ -18,7 +19,7 @@ type IndexConnection struct { grpcConn *grpc.ClientConn } -func newIndexConnection(apiKey string, host string, namespace string) (*IndexConnection, error) { +func newIndexConnection(apiKey string, host string, namespace string, sourceTag string) (*IndexConnection, error) { config := &tls.Config{} target := fmt.Sprintf("%s:443", host) conn, err := grpc.Dial( @@ -26,6 +27,7 @@ func newIndexConnection(apiKey string, host string, namespace string) (*IndexCon grpc.WithTransportCredentials(credentials.NewTLS(config)), grpc.WithAuthority(target), grpc.WithBlock(), + grpc.WithUserAgent(useragent.BuildUserAgentGRPC(sourceTag)), ) if err != nil { diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 6c1a2df..7449160 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -12,11 +12,13 @@ import ( type IndexConnectionTests struct { suite.Suite - host string - dimension int32 - apiKey string - idxConn *IndexConnection - vectorIds []string + host string + dimension int32 + apiKey string + idxConn *IndexConnection + sourceTag string + idxConnSourceTag *IndexConnection + vectorIds []string } // Runs the test suite with `go test` @@ -59,10 +61,15 @@ func (ts *IndexConnectionTests) SetupSuite() { namespace, err := uuid.NewV7() assert.NoError(ts.T(), err) - idxConn, err := newIndexConnection(ts.apiKey, ts.host, namespace.String()) + idxConn, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), "") assert.NoError(ts.T(), err) ts.idxConn = idxConn + ts.sourceTag = "test_source_tag" + idxConnSourceTag, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), ts.sourceTag) + assert.NoError(ts.T(), err) + ts.idxConnSourceTag = idxConnSourceTag + ts.loadData() } @@ -73,6 +80,48 @@ func (ts *IndexConnectionTests) TearDownSuite() { assert.NoError(ts.T(), err) } +func (ts *IndexConnectionTests) TestNewIndexConnection() { + apiKey := "test-api-key" + namespace := "" + sourceTag := "" + idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag) + assert.NoError(ts.T(), err) + + if idxConn.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) + } + if idxConn.Namespace != "" { + ts.FailNow(fmt.Sprintf("Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace)) + } + if idxConn.dataClient == nil { + ts.FailNow("Expected idxConn to have non-nil dataClient") + } + if idxConn.grpcConn == nil { + ts.FailNow("Expected idxConn to have non-nil grpcConn") + } +} + +func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { + apiKey := "test-api-key" + namespace := "test-namespace" + sourceTag := "test-source-tag" + idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag) + assert.NoError(ts.T(), err) + + if idxConn.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) + } + if idxConn.Namespace != namespace { + ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) + } + if idxConn.dataClient == nil { + ts.FailNow("Expected idxConn to have non-nil dataClient") + } + if idxConn.grpcConn == nil { + ts.FailNow("Expected idxConn to have non-nil grpcConn") + } +} + func (ts *IndexConnectionTests) TestFetchVectors() { ctx := context.Background() res, err := ts.idxConn.FetchVectors(&ctx, ts.vectorIds) @@ -80,6 +129,13 @@ func (ts *IndexConnectionTests) TestFetchVectors() { assert.NotNil(ts.T(), res) } +func (ts *IndexConnectionTests) TestFetchVectorsSourceTag() { + ctx := context.Background() + res, err := ts.idxConnSourceTag.FetchVectors(&ctx, ts.vectorIds) + assert.NoError(ts.T(), err) + assert.NotNil(ts.T(), res) +} + func (ts *IndexConnectionTests) TestQueryByVector() { vec := make([]float32, ts.dimension) for i := range vec { @@ -97,6 +153,23 @@ func (ts *IndexConnectionTests) TestQueryByVector() { assert.NotNil(ts.T(), res) } +func (ts *IndexConnectionTests) TestQueryByVectorSourceTag() { + vec := make([]float32, ts.dimension) + for i := range vec { + vec[i] = 0.01 + } + + req := &QueryByVectorValuesRequest{ + Vector: vec, + TopK: 5, + } + + ctx := context.Background() + res, err := ts.idxConnSourceTag.QueryByVectorValues(&ctx, req) + assert.NoError(ts.T(), err) + assert.NotNil(ts.T(), res) +} + func (ts *IndexConnectionTests) TestQueryById() { req := &QueryByVectorIdRequest{ VectorId: ts.vectorIds[0], @@ -109,6 +182,18 @@ func (ts *IndexConnectionTests) TestQueryById() { assert.NotNil(ts.T(), res) } +func (ts *IndexConnectionTests) TestQueryByIdSourceTag() { + req := &QueryByVectorIdRequest{ + VectorId: ts.vectorIds[0], + TopK: 5, + } + + ctx := context.Background() + res, err := ts.idxConnSourceTag.QueryByVectorId(&ctx, req) + assert.NoError(ts.T(), err) + assert.NotNil(ts.T(), res) +} + func (ts *IndexConnectionTests) TestDeleteVectorsById() { ctx := context.Background() err := ts.idxConn.DeleteVectorsById(&ctx, ts.vectorIds) @@ -182,6 +267,31 @@ func (ts *IndexConnectionTests) loadData() { assert.NoError(ts.T(), err) } +func (ts *IndexConnectionTests) loadDataSourceTag() { + vals := []float32{0.01, 0.02, 0.03, 0.04, 0.05} + vectors := make([]*Vector, len(vals)) + ts.vectorIds = make([]string, len(vals)) + + for i, val := range vals { + vec := make([]float32, ts.dimension) + for i := range vec { + vec[i] = val + } + + id := fmt.Sprintf("vec-%d", i+1) + ts.vectorIds[i] = id + + vectors[i] = &Vector{ + Id: id, + Values: vec, + } + } + + ctx := context.Background() + _, err := ts.idxConnSourceTag.UpsertVectors(&ctx, vectors) + assert.NoError(ts.T(), err) +} + func (ts *IndexConnectionTests) truncateData() { ctx := context.Background() err := ts.idxConn.DeleteAllVectorsInNamespace(&ctx)