From 5a9897e1f40b0b782e2e24c99c854e0633e52c5e Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Mon, 6 May 2024 16:48:58 -0400 Subject: [PATCH] Optional configs for setting custom headers / metadata for REST / gRPC operations (#18) ## Problem Currently, the Go SDK does not support specifying custom headers or metadata for control & data plane operations (REST and gRPC). This is a useful feature to aid in debugging or tracking a specific request, and we'd like to enable this functionality in the Go SDK. ## Solution - Update `NewClientParams` and `Client` structs to support `headers`, also allow passing a custom `RestClient` to the new client. This is primarily for allowing mocking in unit tests, but we have a similar approach in other clients allowing users to customize the HTTP module. - Add new `buildClientOptions` function. - Update `NewClient` to support appending `Headers` and `RestClient` to `clientOptions`. - Add error message for empty `ApiKey` without any `Authorization` header provided. If the user passes both, avoid applying the `Api-Key` header in favor of the `Authorization` header. - Add new `Client.IndexWithAdditionalMetadata` constructor. - Update `IndexConnection` struct to include `additionalMetadata`, and update `newIndexConnection` to accept `additionalMetadata` as an argument. - Update `IndexConnection.akCtx` to handle appending the `api-key` and any `additionalMetadata` to requests. - Update unit tests, add new `mocks` package and `mock_transport` to facilitate validating requests made via `Client` including. I spent some time trying to figure out how to mock / test `IndexConnection` but I think it'll be a bit trickier than `Client`, so somewhat saving for a future PR atm. ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [X] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here) ## Test Plan `just test` to run the test suite locally. Validate tests pass in CI. --- internal/mocks/mock_transport.go | 30 +++++++++++ justfile | 2 +- pinecone/client.go | 69 +++++++++++++++++++------ pinecone/client_test.go | 84 ++++++++++++++++++++++++++++++- pinecone/index_connection.go | 26 ++++++++-- pinecone/index_connection_test.go | 51 +++++++++++++++++-- pinecone/models.go | 2 +- 7 files changed, 237 insertions(+), 27 deletions(-) create mode 100644 internal/mocks/mock_transport.go diff --git a/internal/mocks/mock_transport.go b/internal/mocks/mock_transport.go new file mode 100644 index 0000000..25ddb3f --- /dev/null +++ b/internal/mocks/mock_transport.go @@ -0,0 +1,30 @@ +package mocks + +import ( + "bytes" + "io" + "net/http" +) + +type MockTransport struct { + Req *http.Request + Resp *http.Response + Err error +} + +func (m *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + m.Req = req + return m.Resp, m.Err +} + +func CreateMockClient(jsonBody string) *http.Client { + return &http.Client { + Transport: &MockTransport{ + Resp: &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader([]byte(jsonBody))), + Header: make(http.Header), + }, + }, + } +} \ No newline at end of file diff --git a/justfile b/justfile index 2196223..64d19e0 100644 --- a/justfile +++ b/justfile @@ -3,7 +3,7 @@ test: set -o allexport source .env set +o allexport - go test -count=1 ./pinecone + go test -count=1 -v ./pinecone bootstrap: go install google.golang.org/protobuf/cmd/protoc-gen-go@latest go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest diff --git a/pinecone/client.go b/pinecone/client.go index d701480..97e8e91 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -4,51 +4,55 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" + "strings" + "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" "github.com/pinecone-io/go-pinecone/internal/provider" "github.com/pinecone-io/go-pinecone/internal/useragent" - "io" - "net/http" ) type Client struct { apiKey string restClient *control.Client sourceTag string + headers map[string]string } type NewClientParams struct { - ApiKey string - SourceTag string // optional + ApiKey string // required unless Authorization header provided + SourceTag string // optional + Headers map[string]string // optional + RestClient *http.Client // optional } func NewClient(in NewClientParams) (*Client, error) { - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) + clientOptions, err := buildClientOptions(in) if err != nil { return nil, err } - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) - - client, err := control.NewClient("https://api.pinecone.io", - control.WithRequestEditorFn(apiKeyProvider.Intercept), - control.WithRequestEditorFn(userAgentProvider.Intercept), - ) + client, err := control.NewClient("https://api.pinecone.io", clientOptions...) if err != nil { return nil, err } - c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag} + c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag, headers: in.Headers} return &c, nil } func (c *Client) Index(host string) (*IndexConnection, error) { - return c.IndexWithNamespace(host, "") + return c.IndexWithAdditionalMetadata(host, "", nil) } func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnection, error) { - idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag) + return c.IndexWithAdditionalMetadata(host, namespace, nil) +} + +func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) { + idx, err := newIndexConnection(newIndexParameters{apiKey: c.apiKey, host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata}) if err != nil { return nil, err } @@ -408,3 +412,40 @@ func minOne(x int32) int32 { } return x } + +func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { + clientOptions := []control.ClientOption{} + hasAuthorizationHeader := false + hasApiKey := in.ApiKey != "" + + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) + + for key, value := range in.Headers { + headerProvider := provider.NewHeaderProvider(key, value) + + if strings.Contains(strings.ToLower(key), "authorization") { + hasAuthorizationHeader = true + } + + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + } + + if !hasAuthorizationHeader { + apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) + if err != nil { + return nil, err + } + clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) + } + + if !hasAuthorizationHeader && !hasApiKey { + return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") + } + + if in.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) + } + + return clientOptions, nil +} diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 446d2d3..ecb8bbb 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -4,9 +4,12 @@ import ( "context" "fmt" "os" + "reflect" + "strings" "testing" "github.com/google/uuid" + "github.com/pinecone-io/go-pinecone/internal/mocks" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -66,8 +69,11 @@ func (ts *ClientTests) TestNewClientParamsSet() { if client.sourceTag != "" { ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag)) } + if client.headers != nil { + ts.FailNow(fmt.Sprintf("Expected client headers to be nil, but got '%v'", client.headers)) + } if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected 2 request editors on client") + ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) } } @@ -85,7 +91,81 @@ func (ts *ClientTests) TestNewClientParamsSetSourceTag() { ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)) } if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected 2 request editors on client") + ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) + } +} + +func (ts *ClientTests) TestNewClientParamsSetHeaders() { + apiKey := "test-api-key" + headers := map[string]string{"test-header": "test-value"} + client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers}) + if err != nil { + ts.FailNow(err.Error()) + } + if client.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) + } + if !reflect.DeepEqual(client.headers, headers) { + ts.FailNow(fmt.Sprintf("Expected client to have headers '%+v', but got '%+v'", headers, client.headers)) + } + if len(client.restClient.RequestEditors) != 3 { + ts.FailNow(fmt.Sprintf("Expected client to have '%v' request editors, but got '%v'", 3, len(client.restClient.RequestEditors))) + } +} + +func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { + client, err := NewClient(NewClientParams{}) + require.NotNil(ts.T(), err, "Expected error when creating client without an API key or Authorization header") + if !strings.Contains(err.Error(), "no API key provided, please pass an API key for authorization") { + ts.FailNow(fmt.Sprintf("Expected error to contain 'no API key provided, please pass an API key for authorization', but got '%s'", err.Error())) + } + + require.Nil(ts.T(), client, "Expected client to be nil when creating client without an API key or Authorization header") +} + +func (ts *ClientTests) TestHeadersAppliedToRequests() { + apiKey := "test-api-key" + headers := map[string]string{"test-header": "123456"} + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + testHeaderValue := mockTransport.Req.Header.Get("test-header") + if testHeaderValue != "123456" { + ts.FailNow(fmt.Sprintf("Expected request to have header value '123456', but got '%s'", testHeaderValue)) + } +} + +func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { + apiKey := "test-api-key" + headers := map[string]string{"Authorization": "bearer fooo"} + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") + authHeaderValue := mockTransport.Req.Header.Get("Authorization") + if authHeaderValue != "bearer fooo" { + ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) + } + if apiKeyHeaderValue != "" { + ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) } } diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index 9a2131e..a2daf3a 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -16,19 +16,28 @@ import ( type IndexConnection struct { Namespace string apiKey string + additionalMetadata map[string]string dataClient *data.VectorServiceClient grpcConn *grpc.ClientConn } -func newIndexConnection(apiKey string, host string, namespace string, sourceTag string) (*IndexConnection, error) { +type newIndexParameters struct { + apiKey string + host string + namespace string + sourceTag string + additionalMetadata map[string]string +} + +func newIndexConnection(in newIndexParameters) (*IndexConnection, error) { config := &tls.Config{} - target := fmt.Sprintf("%s:443", host) + target := fmt.Sprintf("%s:443", in.host) conn, err := grpc.Dial( target, grpc.WithTransportCredentials(credentials.NewTLS(config)), grpc.WithAuthority(target), grpc.WithBlock(), - grpc.WithUserAgent(useragent.BuildUserAgentGRPC(sourceTag)), + grpc.WithUserAgent(useragent.BuildUserAgentGRPC(in.sourceTag)), ) if err != nil { @@ -38,7 +47,7 @@ func newIndexConnection(apiKey string, host string, namespace string, sourceTag dataClient := data.NewVectorServiceClient(conn) - idx := IndexConnection{Namespace: namespace, apiKey: apiKey, dataClient: &dataClient, grpcConn: conn} + idx := IndexConnection{Namespace: in.namespace, apiKey: in.apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata} return &idx, nil } @@ -360,5 +369,12 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues { } func (idx *IndexConnection) akCtx(ctx context.Context) context.Context { - return metadata.AppendToOutgoingContext(ctx, "api-key", idx.apiKey) + newMetadata := []string{} + newMetadata = append(newMetadata, "api-key", idx.apiKey) + + for key, value := range idx.additionalMetadata{ + newMetadata = append(newMetadata, key, value) + } + + return metadata.AppendToOutgoingContext(ctx, newMetadata...) } diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index ffdca5f..fd709d6 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "reflect" "testing" "github.com/google/uuid" @@ -18,6 +19,7 @@ type IndexConnectionTests struct { apiKey string idxConn *IndexConnection sourceTag string + metadata map[string]string idxConnSourceTag *IndexConnection vectorIds []string } @@ -36,6 +38,10 @@ func TestIndexConnection(t *testing.T) { assert.NotEmptyf(t, podIndexName, "TEST_POD_INDEX_NAME env variable not set") podIdx, err := client.DescribeIndex(context.Background(), podIndexName) + if err != nil { + t.FailNow() + } + podTestSuite := new(IndexConnectionTests) podTestSuite.host = podIdx.Host podTestSuite.dimension = podIdx.Dimension @@ -45,6 +51,9 @@ func TestIndexConnection(t *testing.T) { assert.NotEmptyf(t, serverlessIndexName, "TEST_SERVERLESS_INDEX_NAME env variable not set") serverlessIdx, err := client.DescribeIndex(context.Background(), serverlessIndexName) + if err != nil { + t.FailNow() + } serverlessTestSuite := new(IndexConnectionTests) serverlessTestSuite.host = serverlessIdx.Host @@ -62,12 +71,12 @@ func (ts *IndexConnectionTests) SetupSuite() { namespace, err := uuid.NewV7() assert.NoError(ts.T(), err) - idxConn, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), "") + idxConn, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ""}) assert.NoError(ts.T(), err) ts.idxConn = idxConn ts.sourceTag = "test_source_tag" - idxConnSourceTag, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), ts.sourceTag) + idxConnSourceTag, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ts.sourceTag}) assert.NoError(ts.T(), err) ts.idxConnSourceTag = idxConnSourceTag @@ -79,13 +88,16 @@ func (ts *IndexConnectionTests) TearDownSuite() { err := ts.idxConn.Close() assert.NoError(ts.T(), err) + + err = ts.idxConnSourceTag.Close() + assert.NoError(ts.T(), err) } func (ts *IndexConnectionTests) TestNewIndexConnection() { apiKey := "test-api-key" namespace := "" sourceTag := "" - idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag) + idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) assert.NoError(ts.T(), err) if idxConn.apiKey != apiKey { @@ -94,19 +106,47 @@ func (ts *IndexConnectionTests) TestNewIndexConnection() { if idxConn.Namespace != "" { ts.FailNow(fmt.Sprintf("Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace)) } + if idxConn.additionalMetadata != nil { + ts.FailNow(fmt.Sprintf("Expected idxConn additionalMetadata to be nil, but got '%+v'", idxConn.additionalMetadata)) + } if idxConn.dataClient == nil { ts.FailNow("Expected idxConn to have non-nil dataClient") } if idxConn.grpcConn == nil { ts.FailNow("Expected idxConn to have non-nil grpcConn") } + if idxConn.additionalMetadata != nil { + ts.FailNow("Expected idxConn to have nil additionalMetadata") + } } func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { apiKey := "test-api-key" namespace := "test-namespace" sourceTag := "test-source-tag" - idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag) + idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) + assert.NoError(ts.T(), err) + + if idxConn.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) + } + if idxConn.Namespace != namespace { + ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) + } + if idxConn.dataClient == nil { + ts.FailNow("Expected idxConn to have non-nil dataClient") + } + if idxConn.grpcConn == nil { + ts.FailNow("Expected idxConn to have non-nil grpcConn") + } +} + +func (ts *IndexConnectionTests) TestNewIndexConnectionAdditionalMetadata() { + apiKey := "test-api-key" + namespace := "test-namespace" + sourceTag := "test-source-tag" + additionalMetadata := map[string]string{"test-header": "test-value"} + idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag, additionalMetadata: additionalMetadata}) assert.NoError(ts.T(), err) if idxConn.apiKey != apiKey { @@ -115,6 +155,9 @@ func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { if idxConn.Namespace != namespace { ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) } + if !reflect.DeepEqual(idxConn.additionalMetadata, additionalMetadata) { + ts.FailNow(fmt.Sprintf("Expected idxConn to have additionalMetadata '%+v', but got '%+v'", additionalMetadata, idxConn.additionalMetadata)) + } if idxConn.dataClient == nil { ts.FailNow("Expected idxConn to have non-nil dataClient") } diff --git a/pinecone/models.go b/pinecone/models.go index 4d00443..3ce895b 100644 --- a/pinecone/models.go +++ b/pinecone/models.go @@ -112,4 +112,4 @@ type Usage struct { } type Filter = structpb.Struct -type Metadata = structpb.Struct +type Metadata = structpb.Struct \ No newline at end of file