Skip to content

Commit

Permalink
refactor into NewClientParams and NewClientBaseParams
Browse files Browse the repository at this point in the history
  • Loading branch information
austin-denoble committed May 14, 2024
1 parent 6eb91a3 commit 9f56ded
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 66 deletions.
105 changes: 64 additions & 41 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/http"
"net/url"
"os"
"strings"

"github.com/deepmap/oapi-codegen/v2/pkg/securityprovider"
"github.com/pinecone-io/go-pinecone/internal/gen/control"
Expand All @@ -24,12 +23,19 @@ type Client struct {
sourceTag string
}

type NewClientBaseParams struct {
Headers map[string]string
Host string
RestClient *http.Client
SourceTag string
}

type NewClientParams struct {
ApiKey string // required unless Authorization header provided
Headers map[string]string // optional
Host string // optional
RestClient *http.Client // optional
SourceTag string // optional
ApiKey string
Headers map[string]string
Host string
RestClient *http.Client
SourceTag string
}

func NewClient(in NewClientParams) (*Client, error) {
Expand All @@ -40,7 +46,7 @@ func NewClient(in NewClientParams) (*Client, error) {

controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST"))
if controlHostOverride != "" {
controlHostOverride, err = ensureHTTP(controlHostOverride)
controlHostOverride, err = ensureHTTPS(controlHostOverride)
if err != nil {
return nil, err
}
Expand All @@ -55,6 +61,27 @@ func NewClient(in NewClientParams) (*Client, error) {
return &c, nil
}

func NewClientBase(in NewClientBaseParams) (*Client, error) {
clientOptions := in.buildClientBaseOptions()
var err error

controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST"))
if controlHostOverride != "" {
controlHostOverride, err = ensureHTTPS(controlHostOverride)
if err != nil {
return nil, err
}
}

client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...)
if err != nil {
return nil, err
}

c := Client{restClient: client, sourceTag: in.SourceTag, headers: in.Headers}
return &c, nil
}

func (c *Client) Index(host string) (*IndexConnection, error) {
return c.IndexWithAdditionalMetadata(host, "", nil)
}
Expand Down Expand Up @@ -435,24 +462,41 @@ func derefOrDefault[T any](ptr *T, defaultValue T) T {

func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) {
clientOptions := []control.ClientOption{}
hasAuthorizationHeader := false
osApiKey := os.Getenv("PINECONE_API_KEY")
hasApiKey := (ncp.ApiKey != "" || osApiKey != "")

if !hasApiKey {
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization")
}

appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey)
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey)
if err != nil {
return nil, err
}
clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept))

baseClient := NewClientBaseParams{ncp.Headers, ncp.Host, ncp.RestClient, ncp.SourceTag}
baseClientOptions := baseClient.buildClientBaseOptions()
clientOptions = append(clientOptions, baseClientOptions...)

return clientOptions, nil
}

func (ncbp *NewClientBaseParams) buildClientBaseOptions() []control.ClientOption {
clientOptions := []control.ClientOption{}

envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS")

// build and apply user agent
userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncp.SourceTag))
fmt.Printf("Source tag on its way in: %v\n", ncbp.SourceTag)
userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncbp.SourceTag))
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept))

// apply headers from parameters if passed, otherwise use environment headers
if ncp.Headers != nil {
for key, value := range ncp.Headers {
if ncbp.Headers != nil {
for key, value := range ncbp.Headers {
headerProvider := provider.NewHeaderProvider(key, value)

if strings.Contains(strings.ToLower(key), "authorization") {
hasAuthorizationHeader = true
}

clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}
} else if hasEnvAdditionalHeaders {
Expand All @@ -463,40 +507,19 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error)
} else {
for header, value := range additionalHeaders {
headerProvider := provider.NewHeaderProvider(header, value)

if strings.Contains(strings.ToLower(header), "authorization") {
hasAuthorizationHeader = true
}

clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}
}
}

// if apiKey is provided and no auth header is set, add the apiKey as a header
// apiKey from parameters takes precedence over apiKey from environment
if hasApiKey && !hasAuthorizationHeader {
appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey)

apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey)
if err != nil {
return nil, err
}
clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept))
if ncbp.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(ncbp.RestClient))
}

if !hasAuthorizationHeader && !hasApiKey {
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization")
}

if ncp.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(ncp.RestClient))
}

return clientOptions, nil
return clientOptions
}

func ensureHTTP(inputURL string) (string, error) {
func ensureHTTPS(inputURL string) (string, error) {
parsedURL, err := url.Parse(inputURL)
if err != nil {
return "", fmt.Errorf("invalid URL: %v", err)
Expand Down
53 changes: 28 additions & 25 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ func (ts *ClientTests) TestNewClientParamsSet() {
func (ts *ClientTests) TestNewClientParamsSetSourceTag() {
apiKey := "test-api-key"
sourceTag := "test-source-tag"
client, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: sourceTag})
client, err := NewClient(NewClientParams{
ApiKey: apiKey,
SourceTag: sourceTag,
})
if err != nil {
ts.FailNow(err.Error())
}
Expand Down Expand Up @@ -188,30 +191,30 @@ func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() {
os.Unsetenv("PINECONE_ADDITIONAL_HEADERS")
}

func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() {
apiKey := "test-api-key"
headers := map[string]string{"Authorization": "bearer fooo"}

httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key")
authHeaderValue := mockTransport.Req.Header.Get("Authorization")
if authHeaderValue != "bearer fooo" {
ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue))
}
if apiKeyHeaderValue != "" {
ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue))
}
}
// func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() {
// apiKey := "test-api-key"
// headers := map[string]string{"Authorization": "bearer fooo"}

// httpClient := mocks.CreateMockClient(`{"indexes": []}`)
// client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient})
// if err != nil {
// ts.FailNow(err.Error())
// }
// mockTransport := httpClient.Transport.(*mocks.MockTransport)

// _, err = client.ListIndexes(context.Background())
// require.NoError(ts.T(), err)
// require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

// apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key")
// authHeaderValue := mockTransport.Req.Header.Get("Authorization")
// if authHeaderValue != "bearer fooo" {
// ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue))
// }
// if apiKeyHeaderValue != "" {
// ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue))
// }
// }

func (ts *ClientTests) TestClientReadsApiKeyFromEnv() {
os.Setenv("PINECONE_API_KEY", "test-env-api-key")
Expand Down

0 comments on commit 9f56ded

Please sign in to comment.