Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds optional source_tag to User-Agent #16

Merged
merged 10 commits into from
Mar 27, 2024
20 changes: 20 additions & 0 deletions internal/provider/header.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package provider

import (
"context"
"net/http"
)

type customHeader struct {
name string
value string
}

func NewHeaderProvider(name string, value string) *customHeader {
return &customHeader{name: name, value: value}
}

func (h *customHeader) Intercept(ctx context.Context, req *http.Request) error {
req.Header.Set(h.name, h.value)
return nil
}
32 changes: 32 additions & 0 deletions internal/provider/header_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package provider

import (
"context"
"net/http"
"testing"
)

func TestCustomHeaderIntercept(t *testing.T) {
expectedName := "X-Custom-Header"
expectedValue := "Custom-Value"
header := NewHeaderProvider(expectedName, expectedValue)

req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("Failed to create HTTP request: %v", err)
}

ctx := context.Background()

// Call the Intercept method (method being tested)
err = header.Intercept(ctx, req)
if err != nil {
t.Errorf("Intercept failed: %v", err)
}

// Verify that the custom header is set correctly
if req.Header.Get(expectedName) != expectedValue {
t.Errorf("Expected header '%s' to have value '%s', got '%s'", expectedName, expectedValue,
req.Header.Get(expectedName))
}
}
52 changes: 52 additions & 0 deletions internal/useragent/useragent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package useragent

import (
"fmt"
"strings"
)

func getPackageVersion() string {
// update at release time
return "v0.5.0-pre"
}
ssmith-pc marked this conversation as resolved.
Show resolved Hide resolved

func BuildUserAgent(sourceTag string) string {
return buildUserAgent("go-client", sourceTag)
}

func BuildUserAgentGRPC(sourceTag string) string {
return buildUserAgent("go-client[grpc]", sourceTag)
}

func buildUserAgent(appName string, sourceTag string) string {
appVersion := getPackageVersion()

sourceTagInfo := ""
if sourceTag != "" {
sourceTagInfo = buildSourceTagField(sourceTag)
}
userAgent := fmt.Sprintf("%s/%s%s", appName, appVersion, sourceTagInfo)
return userAgent
}

func buildSourceTagField(userAgent string) string {
// Lowercase
userAgent = strings.ToLower(userAgent)

// Limit charset to [a-z0-9_ ]
var strBldr strings.Builder
for _, char := range userAgent {
if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '_' || char == ' ' {
strBldr.WriteRune(char)
}
}
userAgent = strBldr.String()

// Trim left/right whitespace
userAgent = strings.TrimSpace(userAgent)

// Condense multiple spaces to one, and replace with underscore
userAgent = strings.Join(strings.Fields(userAgent), "_")

return fmt.Sprintf("; source_tag=%s;", userAgent)
}
81 changes: 81 additions & 0 deletions internal/useragent/useragent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package useragent

import (
"fmt"
"strings"
"testing"
)

func TestBuildUserAgentNoSourceTag(t *testing.T) {
sourceTag := ""
expectedStartWith := fmt.Sprintf("go-client/%s", getPackageVersion())
result := BuildUserAgent(sourceTag)
if !strings.HasPrefix(result, expectedStartWith) {
t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result)
}
if strings.Contains(result, "source_tag") {
t.Errorf("BuildUserAgent(): expected user-agent to not contain 'source_tag', but got %s", result)
}
}

func TestBuildUserAgentWithSourceTag(t *testing.T) {
sourceTag := "my_source_tag"
expectedStartWith := fmt.Sprintf("go-client/%s", getPackageVersion())
result := BuildUserAgent(sourceTag)
if !strings.HasPrefix(result, expectedStartWith) {
t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result)
}
if !strings.Contains(result, "source_tag=my_source_tag") {
t.Errorf("BuildUserAgent(): expected user-agent to contain 'source_tag=my_source_tag', but got %s", result)
}
}

func TestBuildUserAgentGRPCNoSourceTag(t *testing.T) {
sourceTag := ""
expectedStartWith := fmt.Sprintf("go-client[grpc]/%s", getPackageVersion())
result := BuildUserAgentGRPC(sourceTag)
if !strings.HasPrefix(result, expectedStartWith) {
t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result)
}
if strings.Contains(result, "source_tag") {
t.Errorf("BuildUserAgent(): expected user-agent to not contain 'source_tag', but got %s", result)
}
}

func TestBuildUserAgentGRPCWithSourceTag(t *testing.T) {
sourceTag := "my_source_tag"
expectedStartWith := fmt.Sprintf("go-client[grpc]/%s", getPackageVersion())
result := BuildUserAgentGRPC(sourceTag)
if !strings.HasPrefix(result, expectedStartWith) {
t.Errorf("BuildUserAgent(): expected user-agent to start with %s, but got %s", expectedStartWith, result)
}
if !strings.Contains(result, "source_tag=my_source_tag") {
t.Errorf("BuildUserAgent(): expected user-agent to contain 'source_tag=my_source_tag', but got %s", result)
}
}

func TestBuildUserAgentSourceTagIsNormalized(t *testing.T) {
sourceTag := "my source tag!!!!"
result := BuildUserAgent(sourceTag)
if !strings.Contains(result, "source_tag=my_source_tag") {
t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag', but got %s", sourceTag, result)
}

sourceTag = "My Source Tag"
result = BuildUserAgent(sourceTag)
if !strings.Contains(result, "source_tag=my_source_tag") {
t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag', but got %s", sourceTag, result)
}

sourceTag = " My Source Tag 123 "
result = BuildUserAgent(sourceTag)
if !strings.Contains(result, "source_tag=my_source_tag") {
t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag_123', but got %s", sourceTag, result)
}

sourceTag = " My Source Tag 123 #### !! "
result = BuildUserAgent(sourceTag)
if !strings.Contains(result, "source_tag=my_source_tag") {
t.Errorf("BuildUserAgent(\"%s\"): expected user-agent to contain 'source_tag=my_source_tag_123', but got %s", sourceTag, result)
}
}
17 changes: 13 additions & 4 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@ import (
"fmt"
"github.com/deepmap/oapi-codegen/v2/pkg/securityprovider"
"github.com/pinecone-io/go-pinecone/internal/gen/control"
"github.com/pinecone-io/go-pinecone/internal/provider"
"github.com/pinecone-io/go-pinecone/internal/useragent"
"io"
"net/http"
)

type Client struct {
apiKey string
restClient *control.Client
sourceTag string
ssmith-pc marked this conversation as resolved.
Show resolved Hide resolved
}

type NewClientParams struct {
ApiKey string
ApiKey string
SourceTag string // optional
}

func NewClient(in NewClientParams) (*Client, error) {
Expand All @@ -25,12 +29,17 @@ func NewClient(in NewClientParams) (*Client, error) {
return nil, err
}

client, err := control.NewClient("https://api.pinecone.io", control.WithRequestEditorFn(apiKeyProvider.Intercept))
userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag))

client, err := control.NewClient("https://api.pinecone.io",
control.WithRequestEditorFn(apiKeyProvider.Intercept),
control.WithRequestEditorFn(userAgentProvider.Intercept),
)
if err != nil {
return nil, err
}

c := Client{apiKey: in.ApiKey, restClient: client}
c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag}
return &c, nil
}

Expand All @@ -39,7 +48,7 @@ func (c *Client) Index(host string) (*IndexConnection, error) {
}

func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnection, error) {
idx, err := newIndexConnection(c.apiKey, host, namespace)
idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag)
if err != nil {
return nil, err
}
Expand Down
68 changes: 66 additions & 2 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@ package pinecone

import (
"context"
"fmt"
"os"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"os"
"testing"
)

type ClientTests struct {
suite.Suite
client Client
clientSourceTag Client
sourceTag string
podIndex string
serverlessIndex string
}
Expand All @@ -36,19 +40,67 @@ func (ts *ClientTests) SetupSuite() {
}
ts.client = *client

ts.sourceTag = "test_source_tag"
clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: ts.sourceTag})
if err != nil {
ts.FailNow(err.Error())
}
ts.clientSourceTag = *clientSourceTag

// this will clean up the project deleting all indexes and collections that are
// named a UUID. Generally not needed as all tests are cleaning up after themselves
// Left here as a convenience during active development.
//deleteUUIDNamedResources(context.Background(), &ts.client)

}

func (ts *ClientTests) TestNewClientParamsSet() {
apiKey := "test-api-key"
client, err := NewClient(NewClientParams{ApiKey: apiKey})
if err != nil {
ts.FailNow(err.Error())
}
if client.apiKey != apiKey {
ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey))
}
if client.sourceTag != "" {
ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag))
}
if len(client.restClient.RequestEditors) != 2 {
ts.FailNow("Expected 2 request editors on client")
}
}

func (ts *ClientTests) TestNewClientParamsSetSourceTag() {
apiKey := "test-api-key"
sourceTag := "test-source-tag"
client, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: sourceTag})
if err != nil {
ts.FailNow(err.Error())
}
if client.apiKey != apiKey {
ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey))
}
if client.sourceTag != sourceTag {
ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag))
}
if len(client.restClient.RequestEditors) != 2 {
ts.FailNow("Expected 2 request editors on client")
}
}

func (ts *ClientTests) TestListIndexes() {
indexes, err := ts.client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist")
}

func (ts *ClientTests) TestListIndexesSourceTag() {
indexes, err := ts.clientSourceTag.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.Greater(ts.T(), len(indexes), 0, "Expected at least one index to exist")
}

func (ts *ClientTests) TestCreatePodIndex() {
name := uuid.New().String()

Expand Down Expand Up @@ -93,12 +145,24 @@ func (ts *ClientTests) TestDescribeServerlessIndex() {
require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match")
}

func (ts *ClientTests) TestDescribeServerlessIndexSourceTag() {
index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.serverlessIndex)
require.NoError(ts.T(), err)
require.Equal(ts.T(), ts.serverlessIndex, index.Name, "Index name does not match")
}

func (ts *ClientTests) TestDescribePodIndex() {
index, err := ts.client.DescribeIndex(context.Background(), ts.podIndex)
require.NoError(ts.T(), err)
require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match")
}

func (ts *ClientTests) TestDescribePodIndexSourceTag() {
index, err := ts.clientSourceTag.DescribeIndex(context.Background(), ts.podIndex)
require.NoError(ts.T(), err)
require.Equal(ts.T(), ts.podIndex, index.Name, "Index name does not match")
}

func (ts *ClientTests) TestListCollections() {
ctx := context.Background()

Expand Down
4 changes: 3 additions & 1 deletion pinecone/index_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"fmt"
"github.com/pinecone-io/go-pinecone/internal/gen/data"
"github.com/pinecone-io/go-pinecone/internal/useragent"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
Expand All @@ -18,14 +19,15 @@ type IndexConnection struct {
grpcConn *grpc.ClientConn
}

func newIndexConnection(apiKey string, host string, namespace string) (*IndexConnection, error) {
func newIndexConnection(apiKey string, host string, namespace string, sourceTag string) (*IndexConnection, error) {
config := &tls.Config{}
target := fmt.Sprintf("%s:443", host)
conn, err := grpc.Dial(
target,
grpc.WithTransportCredentials(credentials.NewTLS(config)),
grpc.WithAuthority(target),
grpc.WithBlock(),
grpc.WithUserAgent(useragent.BuildUserAgentGRPC(sourceTag)),
)

if err != nil {
Expand Down
Loading
Loading