diff --git a/pinecone/client.go b/pinecone/client.go index 26bcdd0..2d98642 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -9,15 +9,14 @@ import ( "net/http" "net/url" "os" + "strings" - "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" "github.com/pinecone-io/go-pinecone/internal/provider" "github.com/pinecone-io/go-pinecone/internal/useragent" ) type Client struct { - apiKey string headers map[string]string restClient *control.Client sourceTag string @@ -39,30 +38,29 @@ type NewClientParams struct { } func NewClient(in NewClientParams) (*Client, error) { - clientOptions, err := in.buildClientOptions() - if err != nil { - return nil, err - } + osApiKey := os.Getenv("PINECONE_API_KEY") + hasApiKey := (in.ApiKey != "" || osApiKey != "") - controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) - if controlHostOverride != "" { - controlHostOverride, err = ensureHTTPS(controlHostOverride) - if err != nil { - return nil, err - } + if !hasApiKey { + return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") } - client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...) - if err != nil { - return nil, err + apiKeyHeader := struct{ Key, Value string }{"Api-Key", valueOrFallback(in.ApiKey, osApiKey)} + + clientHeaders := in.Headers + if clientHeaders == nil { + clientHeaders = make(map[string]string) + clientHeaders[apiKeyHeader.Key] = apiKeyHeader.Value + + } else { + clientHeaders[apiKeyHeader.Key] = apiKeyHeader.Value } - c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag, headers: in.Headers} - return &c, nil + return NewClientBase(NewClientBaseParams{Headers: clientHeaders, Host: in.Host, RestClient: in.RestClient, SourceTag: in.SourceTag}) } func NewClientBase(in NewClientBaseParams) (*Client, error) { - clientOptions := in.buildClientBaseOptions() + clientOptions := buildClientBaseOptions(in) var err error controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) @@ -82,6 +80,45 @@ func NewClientBase(in NewClientBaseParams) (*Client, error) { return &c, nil } +func buildClientBaseOptions(in NewClientBaseParams) []control.ClientOption { + clientOptions := []control.ClientOption{} + + // build and apply user agent + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) + + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") + additionalHeaders := make(map[string]string) + + // add headers from environment + if hasEnvAdditionalHeaders { + err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders) + if err != nil { + log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err) + } + } + + // merge headers from parameters if passed + if in.Headers != nil { + for key, value := range in.Headers { + additionalHeaders[key] = value + } + } + + // add headers to client options + for key, value := range additionalHeaders { + headerProvider := provider.NewHeaderProvider(key, value) + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + } + + // apply custom http client if provided + if in.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) + } + + return clientOptions +} + func (c *Client) Index(host string) (*IndexConnection, error) { return c.IndexWithAdditionalMetadata(host, "", nil) } @@ -91,13 +128,42 @@ func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnec } func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) { - idx, err := newIndexConnection(newIndexParameters{apiKey: c.apiKey, host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata}) + authHeader := c.extractAuthHeader() + + // merge additionalMetadata with authHeader + if additionalMetadata != nil { + for _, key := range authHeader { + additionalMetadata[key] = authHeader[key] + } + } else { + additionalMetadata = authHeader + } + + idx, err := newIndexConnection(newIndexParameters{host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata}) if err != nil { return nil, err } return idx, nil } +func (c *Client) extractAuthHeader() map[string]string { + possibleAuthKeys := []string{ + "api-key", + "authorization", + "access_token", + } + + for key, value := range c.headers { + for _, checkKey := range possibleAuthKeys { + if strings.ToLower(key) == checkKey { + return map[string]string{key: value} + } + } + } + + return nil +} + func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { res, err := c.restClient.ListIndexes(ctx) if err != nil { @@ -446,79 +512,6 @@ func decodeCollection(resBody io.ReadCloser) (*Collection, error) { return toCollection(&collectionModel), nil } -func minOne(x int32) int32 { - if x < 1 { - return 1 - } - return x -} - -func derefOrDefault[T any](ptr *T, defaultValue T) T { - if ptr == nil { - return defaultValue - } - return *ptr -} - -func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) { - clientOptions := []control.ClientOption{} - osApiKey := os.Getenv("PINECONE_API_KEY") - hasApiKey := (ncp.ApiKey != "" || osApiKey != "") - - if !hasApiKey { - return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") - } - - appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey) - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) - if err != nil { - return nil, err - } - clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) - - baseClient := NewClientBaseParams{ncp.Headers, ncp.Host, ncp.RestClient, ncp.SourceTag} - baseClientOptions := baseClient.buildClientBaseOptions() - clientOptions = append(clientOptions, baseClientOptions...) - - return clientOptions, nil -} - -func (ncbp *NewClientBaseParams) buildClientBaseOptions() []control.ClientOption { - clientOptions := []control.ClientOption{} - - envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") - - // build and apply user agent - fmt.Printf("Source tag on its way in: %v\n", ncbp.SourceTag) - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncbp.SourceTag)) - clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) - - // apply headers from parameters if passed, otherwise use environment headers - if ncbp.Headers != nil { - for key, value := range ncbp.Headers { - headerProvider := provider.NewHeaderProvider(key, value) - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) - } - } else if hasEnvAdditionalHeaders { - additionalHeaders := make(map[string]string) - err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders) - if err != nil { - log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err) - } else { - for header, value := range additionalHeaders { - headerProvider := provider.NewHeaderProvider(header, value) - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) - } - } - } - - if ncbp.RestClient != nil { - clientOptions = append(clientOptions, control.WithHTTPClient(ncbp.RestClient)) - } - - return clientOptions -} - func ensureHTTPS(inputURL string) (string, error) { parsedURL, err := url.Parse(inputURL) if err != nil { @@ -539,3 +532,17 @@ func valueOrFallback[T comparable](value, fallback T) T { return fallback } } + +func derefOrDefault[T any](ptr *T, defaultValue T) T { + if ptr == nil { + return defaultValue + } + return *ptr +} + +func minOne(x int32) int32 { + if x < 1 { + return 1 + } + return x +} diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 88637f0..81bf2a0 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "reflect" "strings" "testing" @@ -39,16 +38,13 @@ func (ts *ClientTests) SetupSuite() { require.NotEmpty(ts.T(), ts.serverlessIndex, "TEST_SERVERLESS_INDEX_NAME env variable not set") client, err := NewClient(NewClientParams{ApiKey: apiKey}) - if err != nil { - ts.FailNow(err.Error()) - } + require.NoError(ts.T(), err) + ts.client = *client ts.sourceTag = "test_source_tag" clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: ts.sourceTag}) - if err != nil { - ts.FailNow(err.Error()) - } + require.NoError(ts.T(), err) ts.clientSourceTag = *clientSourceTag // this will clean up the project deleting all indexes and collections that are @@ -61,21 +57,14 @@ func (ts *ClientTests) SetupSuite() { func (ts *ClientTests) TestNewClientParamsSet() { apiKey := "test-api-key" client, err := NewClient(NewClientParams{ApiKey: apiKey}) - if err != nil { - ts.FailNow(err.Error()) - } - if client.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) - } - if client.sourceTag != "" { - ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag)) - } - if client.headers != nil { - ts.FailNow(fmt.Sprintf("Expected client headers to be nil, but got '%v'", client.headers)) - } - if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) - } + + require.NoError(ts.T(), err) + require.Empty(ts.T(), client.sourceTag, "Expected client to have empty sourceTag") + require.NotNil(ts.T(), client.headers, "Expected client headers to not be nil") + apiKeyHeader, ok := client.headers["Api-Key"] + require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") + require.Equal(ts.T(), 2, len(client.restClient.RequestEditors), "Expected client to have correct number of require editors") } func (ts *ClientTests) TestNewClientParamsSetSourceTag() { @@ -85,36 +74,26 @@ func (ts *ClientTests) TestNewClientParamsSetSourceTag() { ApiKey: apiKey, SourceTag: sourceTag, }) - if err != nil { - ts.FailNow(err.Error()) - } - if client.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) - } - if client.sourceTag != sourceTag { - ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)) - } - if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) - } + + require.NoError(ts.T(), err) + apiKeyHeader, ok := client.headers["Api-Key"] + require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") + require.Equal(ts.T(), sourceTag, client.sourceTag, "Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag) + require.Equal(ts.T(), 2, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 2, len(client.restClient.RequestEditors)) } func (ts *ClientTests) TestNewClientParamsSetHeaders() { apiKey := "test-api-key" headers := map[string]string{"test-header": "test-value"} client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers}) - if err != nil { - ts.FailNow(err.Error()) - } - if client.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) - } - if !reflect.DeepEqual(client.headers, headers) { - ts.FailNow(fmt.Sprintf("Expected client to have headers '%+v', but got '%+v'", headers, client.headers)) - } - if len(client.restClient.RequestEditors) != 3 { - ts.FailNow(fmt.Sprintf("Expected client to have '%v' request editors, but got '%v'", 3, len(client.restClient.RequestEditors))) - } + + require.NoError(ts.T(), err) + apiKeyHeader, ok := client.headers["Api-Key"] + require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") + require.Equal(ts.T(), client.headers, headers, "Expected client to have headers '%+v', but got '%+v'", headers, client.headers) + require.Equal(ts.T(), 3, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 3, len(client.restClient.RequestEditors)) } func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { @@ -191,31 +170,6 @@ func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() { os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } -// func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { -// apiKey := "test-api-key" -// headers := map[string]string{"Authorization": "bearer fooo"} - -// httpClient := mocks.CreateMockClient(`{"indexes": []}`) -// client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) -// if err != nil { -// ts.FailNow(err.Error()) -// } -// mockTransport := httpClient.Transport.(*mocks.MockTransport) - -// _, err = client.ListIndexes(context.Background()) -// require.NoError(ts.T(), err) -// require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") - -// apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") -// authHeaderValue := mockTransport.Req.Header.Get("Authorization") -// if authHeaderValue != "bearer fooo" { -// ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) -// } -// if apiKeyHeaderValue != "" { -// ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) -// } -// } - func (ts *ClientTests) TestClientReadsApiKeyFromEnv() { os.Setenv("PINECONE_API_KEY", "test-env-api-key") diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index fb72391..b8cb27f 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -15,14 +15,12 @@ import ( type IndexConnection struct { Namespace string - apiKey string additionalMetadata map[string]string dataClient *data.VectorServiceClient grpcConn *grpc.ClientConn } type newIndexParameters struct { - apiKey string host string namespace string sourceTag string @@ -47,7 +45,7 @@ func newIndexConnection(in newIndexParameters) (*IndexConnection, error) { dataClient := data.NewVectorServiceClient(conn) - idx := IndexConnection{Namespace: in.namespace, apiKey: in.apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata} + idx := IndexConnection{Namespace: in.namespace, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata} return &idx, nil } @@ -370,7 +368,6 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues { func (idx *IndexConnection) akCtx(ctx context.Context) context.Context { newMetadata := []string{} - newMetadata = append(newMetadata, "api-key", idx.apiKey) for key, value := range idx.additionalMetadata { newMetadata = append(newMetadata, key, value) diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 1e5cb46..493bd54 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" "os" - "reflect" "testing" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/protobuf/types/known/structpb" ) @@ -71,16 +71,25 @@ func TestIndexConnection(t *testing.T) { func (ts *IndexConnectionTests) SetupSuite() { assert.NotEmptyf(ts.T(), ts.host, "HOST env variable not set") assert.NotEmptyf(ts.T(), ts.apiKey, "API_KEY env variable not set") + additionalMetadata := map[string]string{"api-key": ts.apiKey} namespace, err := uuid.NewV7() assert.NoError(ts.T(), err) - idxConn, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ""}) + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace.String(), + sourceTag: ""}) assert.NoError(ts.T(), err) ts.idxConn = idxConn ts.sourceTag = "test_source_tag" - idxConnSourceTag, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ts.sourceTag}) + idxConnSourceTag, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace.String(), + sourceTag: ts.sourceTag}) assert.NoError(ts.T(), err) ts.idxConnSourceTag = idxConnSourceTag @@ -101,73 +110,40 @@ func (ts *IndexConnectionTests) TestNewIndexConnection() { apiKey := "test-api-key" namespace := "" sourceTag := "" - idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) - assert.NoError(ts.T(), err) - - if idxConn.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) - } - if idxConn.Namespace != "" { - ts.FailNow(fmt.Sprintf("Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace)) - } - if idxConn.additionalMetadata != nil { - ts.FailNow(fmt.Sprintf("Expected idxConn additionalMetadata to be nil, but got '%+v'", idxConn.additionalMetadata)) - } - if idxConn.dataClient == nil { - ts.FailNow("Expected idxConn to have non-nil dataClient") - } - if idxConn.grpcConn == nil { - ts.FailNow("Expected idxConn to have non-nil grpcConn") - } - if idxConn.additionalMetadata != nil { - ts.FailNow("Expected idxConn to have nil additionalMetadata") - } + additionalMetadata := map[string]string{"api-key": apiKey} + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace, + sourceTag: sourceTag}) + + require.NoError(ts.T(), err) + apiKeyHeader, ok := idxConn.additionalMetadata["api-key"] + require.True(ts.T(), ok, "Expected client to have an 'api-key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'api-key' header to equal %s", apiKey) + require.Empty(ts.T(), idxConn.Namespace, "Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace) + require.NotNil(ts.T(), idxConn.dataClient, "Expected idxConn to have non-nil dataClient") + require.NotNil(ts.T(), idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") } func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { apiKey := "test-api-key" namespace := "test-namespace" sourceTag := "test-source-tag" - idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) - assert.NoError(ts.T(), err) - - if idxConn.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) - } - if idxConn.Namespace != namespace { - ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) - } - if idxConn.dataClient == nil { - ts.FailNow("Expected idxConn to have non-nil dataClient") - } - if idxConn.grpcConn == nil { - ts.FailNow("Expected idxConn to have non-nil grpcConn") - } -} - -func (ts *IndexConnectionTests) TestNewIndexConnectionAdditionalMetadata() { - apiKey := "test-api-key" - namespace := "test-namespace" - sourceTag := "test-source-tag" - additionalMetadata := map[string]string{"test-header": "test-value"} - idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag, additionalMetadata: additionalMetadata}) - assert.NoError(ts.T(), err) - - if idxConn.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) - } - if idxConn.Namespace != namespace { - ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) - } - if !reflect.DeepEqual(idxConn.additionalMetadata, additionalMetadata) { - ts.FailNow(fmt.Sprintf("Expected idxConn to have additionalMetadata '%+v', but got '%+v'", additionalMetadata, idxConn.additionalMetadata)) - } - if idxConn.dataClient == nil { - ts.FailNow("Expected idxConn to have non-nil dataClient") - } - if idxConn.grpcConn == nil { - ts.FailNow("Expected idxConn to have non-nil grpcConn") - } + additionalMetadata := map[string]string{"api-key": apiKey} + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace, + sourceTag: sourceTag}) + + require.NoError(ts.T(), err) + apiKeyHeader, ok := idxConn.additionalMetadata["api-key"] + require.True(ts.T(), ok, "Expected client to have an 'api-key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'api-key' header to equal %s", apiKey) + require.Equal(ts.T(), namespace, idxConn.Namespace, "Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace) + require.NotNil(ts.T(), idxConn.dataClient, "Expected idxConn to have non-nil dataClient") + require.NotNil(ts.T(), idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") } func (ts *IndexConnectionTests) TestFetchVectors() {