diff --git a/pinecone/client.go b/pinecone/client.go index ac2891b..7aa4cad 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" @@ -21,31 +22,17 @@ type Client struct { } type NewClientParams struct { - ApiKey string - SourceTag string // optional - Headers map[string]string // optional - RestClient *http.Client // optional + ApiKey string // optional unless no Authorization header provided + SourceTag string // optional + Headers map[string]string // optional + RestClient *http.Client // optional } func NewClient(in NewClientParams) (*Client, error) { - clientOptions := []control.ClientOption{} - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) + clientOptions, err := buildClientOptions(in) if err != nil { return nil, err } - clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) - - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) - clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) - - for key, value := range in.Headers { - headerProvider := provider.NewHeaderProvider(key, value) - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) - } - - if in.RestClient != nil { - clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) - } client, err := control.NewClient("https://api.pinecone.io", clientOptions...) if err != nil { @@ -425,3 +412,40 @@ func minOne(x int32) int32 { } return x } + +func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { + clientOptions := []control.ClientOption{} + hasAuthorizationHeader := false + hasApiKey := in.ApiKey != "" + + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) + + for key, value := range in.Headers { + headerProvider := provider.NewHeaderProvider(key, value) + + if strings.Contains(key, "Authorization") { + hasAuthorizationHeader = true + } + + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + } + + if !hasAuthorizationHeader { + apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) + if err != nil { + return nil, err + } + clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) + } + + if !hasAuthorizationHeader && !hasApiKey { + return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") + } + + if in.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) + } + + return clientOptions, nil +} diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 59bb678..ad6d225 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "reflect" + "strings" "testing" "github.com/google/uuid" @@ -112,6 +113,16 @@ func (ts *ClientTests) TestNewClientParamsSetHeaders() { } } +func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { + client, err := NewClient(NewClientParams{}) + require.NotNil(ts.T(), err, "Expected error when creating client without an API key or Authorization header") + if !strings.Contains(err.Error(), "no API key provided, please pass an API key for authorization") { + ts.FailNow(fmt.Sprintf("Expected error to contain 'no API key provided, please pass an API key for authorization', but got '%s'", err.Error())) + } + + require.Nil(ts.T(), client, "Expected client to be nil when creating client without an API key or Authorization header") +} + func (ts *ClientTests) TestHeadersAppliedToRequests() { apiKey := "test-api-key" headers := map[string]string{"test-header": "123456"} @@ -133,6 +144,31 @@ func (ts *ClientTests) TestHeadersAppliedToRequests() { } } +func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { + apiKey := "test-api-key" + headers := map[string]string{"Authorization": "bearer abcd123467890"} + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") + authHeaderValue := mockTransport.Req.Header.Get("Authorization") + if authHeaderValue != "bearer abcd123467890" { + ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer abcd123467890', but got '%s'", authHeaderValue)) + } + if apiKeyHeaderValue != "" { + ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) + } +} + func (ts *ClientTests) TestListIndexes() { indexes, err := ts.client.ListIndexes(context.Background()) require.NoError(ts.T(), err)