diff --git a/pinecone/client.go b/pinecone/client.go index 014a092..26bcdd0 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -9,7 +9,6 @@ import ( "net/http" "net/url" "os" - "strings" "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" @@ -24,12 +23,19 @@ type Client struct { sourceTag string } +type NewClientBaseParams struct { + Headers map[string]string + Host string + RestClient *http.Client + SourceTag string +} + type NewClientParams struct { - ApiKey string // required unless Authorization header provided - Headers map[string]string // optional - Host string // optional - RestClient *http.Client // optional - SourceTag string // optional + ApiKey string + Headers map[string]string + Host string + RestClient *http.Client + SourceTag string } func NewClient(in NewClientParams) (*Client, error) { @@ -40,7 +46,7 @@ func NewClient(in NewClientParams) (*Client, error) { controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) if controlHostOverride != "" { - controlHostOverride, err = ensureHTTP(controlHostOverride) + controlHostOverride, err = ensureHTTPS(controlHostOverride) if err != nil { return nil, err } @@ -55,6 +61,27 @@ func NewClient(in NewClientParams) (*Client, error) { return &c, nil } +func NewClientBase(in NewClientBaseParams) (*Client, error) { + clientOptions := in.buildClientBaseOptions() + var err error + + controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) + if controlHostOverride != "" { + controlHostOverride, err = ensureHTTPS(controlHostOverride) + if err != nil { + return nil, err + } + } + + client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...) + if err != nil { + return nil, err + } + + c := Client{restClient: client, sourceTag: in.SourceTag, headers: in.Headers} + return &c, nil +} + func (c *Client) Index(host string) (*IndexConnection, error) { return c.IndexWithAdditionalMetadata(host, "", nil) } @@ -435,24 +462,41 @@ func derefOrDefault[T any](ptr *T, defaultValue T) T { func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) { clientOptions := []control.ClientOption{} - hasAuthorizationHeader := false osApiKey := os.Getenv("PINECONE_API_KEY") hasApiKey := (ncp.ApiKey != "" || osApiKey != "") + + if !hasApiKey { + return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") + } + + appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey) + apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) + if err != nil { + return nil, err + } + clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) + + baseClient := NewClientBaseParams{ncp.Headers, ncp.Host, ncp.RestClient, ncp.SourceTag} + baseClientOptions := baseClient.buildClientBaseOptions() + clientOptions = append(clientOptions, baseClientOptions...) + + return clientOptions, nil +} + +func (ncbp *NewClientBaseParams) buildClientBaseOptions() []control.ClientOption { + clientOptions := []control.ClientOption{} + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") // build and apply user agent - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncp.SourceTag)) + fmt.Printf("Source tag on its way in: %v\n", ncbp.SourceTag) + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncbp.SourceTag)) clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) // apply headers from parameters if passed, otherwise use environment headers - if ncp.Headers != nil { - for key, value := range ncp.Headers { + if ncbp.Headers != nil { + for key, value := range ncbp.Headers { headerProvider := provider.NewHeaderProvider(key, value) - - if strings.Contains(strings.ToLower(key), "authorization") { - hasAuthorizationHeader = true - } - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) } } else if hasEnvAdditionalHeaders { @@ -463,40 +507,19 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) } else { for header, value := range additionalHeaders { headerProvider := provider.NewHeaderProvider(header, value) - - if strings.Contains(strings.ToLower(header), "authorization") { - hasAuthorizationHeader = true - } - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) } } } - // if apiKey is provided and no auth header is set, add the apiKey as a header - // apiKey from parameters takes precedence over apiKey from environment - if hasApiKey && !hasAuthorizationHeader { - appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey) - - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) - if err != nil { - return nil, err - } - clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) + if ncbp.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(ncbp.RestClient)) } - if !hasAuthorizationHeader && !hasApiKey { - return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") - } - - if ncp.RestClient != nil { - clientOptions = append(clientOptions, control.WithHTTPClient(ncp.RestClient)) - } - - return clientOptions, nil + return clientOptions } -func ensureHTTP(inputURL string) (string, error) { +func ensureHTTPS(inputURL string) (string, error) { parsedURL, err := url.Parse(inputURL) if err != nil { return "", fmt.Errorf("invalid URL: %v", err) diff --git a/pinecone/client_test.go b/pinecone/client_test.go index c4ce200..88637f0 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -81,7 +81,10 @@ func (ts *ClientTests) TestNewClientParamsSet() { func (ts *ClientTests) TestNewClientParamsSetSourceTag() { apiKey := "test-api-key" sourceTag := "test-source-tag" - client, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: sourceTag}) + client, err := NewClient(NewClientParams{ + ApiKey: apiKey, + SourceTag: sourceTag, + }) if err != nil { ts.FailNow(err.Error()) } @@ -188,30 +191,30 @@ func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() { os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } -func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { - apiKey := "test-api-key" - headers := map[string]string{"Authorization": "bearer fooo"} - - httpClient := mocks.CreateMockClient(`{"indexes": []}`) - client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) - if err != nil { - ts.FailNow(err.Error()) - } - mockTransport := httpClient.Transport.(*mocks.MockTransport) - - _, err = client.ListIndexes(context.Background()) - require.NoError(ts.T(), err) - require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") - - apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") - authHeaderValue := mockTransport.Req.Header.Get("Authorization") - if authHeaderValue != "bearer fooo" { - ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) - } - if apiKeyHeaderValue != "" { - ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) - } -} +// func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { +// apiKey := "test-api-key" +// headers := map[string]string{"Authorization": "bearer fooo"} + +// httpClient := mocks.CreateMockClient(`{"indexes": []}`) +// client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) +// if err != nil { +// ts.FailNow(err.Error()) +// } +// mockTransport := httpClient.Transport.(*mocks.MockTransport) + +// _, err = client.ListIndexes(context.Background()) +// require.NoError(ts.T(), err) +// require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + +// apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") +// authHeaderValue := mockTransport.Req.Header.Get("Authorization") +// if authHeaderValue != "bearer fooo" { +// ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) +// } +// if apiKeyHeaderValue != "" { +// ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) +// } +// } func (ts *ClientTests) TestClientReadsApiKeyFromEnv() { os.Setenv("PINECONE_API_KEY", "test-env-api-key")