Skip to content

Commit

Permalink
refactor NewClientParams into NewClientBaseParams, update NewClient f…
Browse files Browse the repository at this point in the history
…unc and add NewClientBase, update tests to support the new interface for creating clients, remove apiKey from Client and IndexConnection, handle turning apiKey into a header early, and allowing for complete control over headers without using apiKey
  • Loading branch information
austin-denoble committed May 14, 2024
1 parent 9f56ded commit e6ba260
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 231 deletions.
191 changes: 99 additions & 92 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ import (
"net/http"
"net/url"
"os"
"strings"

"github.com/deepmap/oapi-codegen/v2/pkg/securityprovider"
"github.com/pinecone-io/go-pinecone/internal/gen/control"
"github.com/pinecone-io/go-pinecone/internal/provider"
"github.com/pinecone-io/go-pinecone/internal/useragent"
)

type Client struct {
apiKey string
headers map[string]string
restClient *control.Client
sourceTag string
Expand All @@ -39,30 +38,29 @@ type NewClientParams struct {
}

func NewClient(in NewClientParams) (*Client, error) {
clientOptions, err := in.buildClientOptions()
if err != nil {
return nil, err
}
osApiKey := os.Getenv("PINECONE_API_KEY")
hasApiKey := (in.ApiKey != "" || osApiKey != "")

controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST"))
if controlHostOverride != "" {
controlHostOverride, err = ensureHTTPS(controlHostOverride)
if err != nil {
return nil, err
}
if !hasApiKey {
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization")
}

client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...)
if err != nil {
return nil, err
apiKeyHeader := struct{ Key, Value string }{"Api-Key", valueOrFallback(in.ApiKey, osApiKey)}

clientHeaders := in.Headers
if clientHeaders == nil {
clientHeaders = make(map[string]string)
clientHeaders[apiKeyHeader.Key] = apiKeyHeader.Value

} else {
clientHeaders[apiKeyHeader.Key] = apiKeyHeader.Value
}

c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag, headers: in.Headers}
return &c, nil
return NewClientBase(NewClientBaseParams{Headers: clientHeaders, Host: in.Host, RestClient: in.RestClient, SourceTag: in.SourceTag})
}

func NewClientBase(in NewClientBaseParams) (*Client, error) {
clientOptions := in.buildClientBaseOptions()
clientOptions := buildClientBaseOptions(in)
var err error

controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST"))
Expand All @@ -82,6 +80,45 @@ func NewClientBase(in NewClientBaseParams) (*Client, error) {
return &c, nil
}

func buildClientBaseOptions(in NewClientBaseParams) []control.ClientOption {
clientOptions := []control.ClientOption{}

// build and apply user agent
userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag))
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept))

envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS")
additionalHeaders := make(map[string]string)

// add headers from environment
if hasEnvAdditionalHeaders {
err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders)
if err != nil {
log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err)
}
}

// merge headers from parameters if passed
if in.Headers != nil {
for key, value := range in.Headers {
additionalHeaders[key] = value
}
}

// add headers to client options
for key, value := range additionalHeaders {
headerProvider := provider.NewHeaderProvider(key, value)
clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}

// apply custom http client if provided
if in.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient))
}

return clientOptions
}

func (c *Client) Index(host string) (*IndexConnection, error) {
return c.IndexWithAdditionalMetadata(host, "", nil)
}
Expand All @@ -91,13 +128,42 @@ func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnec
}

func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) {
idx, err := newIndexConnection(newIndexParameters{apiKey: c.apiKey, host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata})
authHeader := c.extractAuthHeader()

// merge additionalMetadata with authHeader
if additionalMetadata != nil {
for _, key := range authHeader {
additionalMetadata[key] = authHeader[key]
}
} else {
additionalMetadata = authHeader
}

idx, err := newIndexConnection(newIndexParameters{host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata})
if err != nil {
return nil, err
}
return idx, nil
}

func (c *Client) extractAuthHeader() map[string]string {
possibleAuthKeys := []string{
"api-key",
"authorization",
"access_token",
}

for key, value := range c.headers {
for _, checkKey := range possibleAuthKeys {
if strings.ToLower(key) == checkKey {
return map[string]string{key: value}
}
}
}

return nil
}

func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) {
res, err := c.restClient.ListIndexes(ctx)
if err != nil {
Expand Down Expand Up @@ -446,79 +512,6 @@ func decodeCollection(resBody io.ReadCloser) (*Collection, error) {
return toCollection(&collectionModel), nil
}

func minOne(x int32) int32 {
if x < 1 {
return 1
}
return x
}

func derefOrDefault[T any](ptr *T, defaultValue T) T {
if ptr == nil {
return defaultValue
}
return *ptr
}

func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) {
clientOptions := []control.ClientOption{}
osApiKey := os.Getenv("PINECONE_API_KEY")
hasApiKey := (ncp.ApiKey != "" || osApiKey != "")

if !hasApiKey {
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization")
}

appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey)
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey)
if err != nil {
return nil, err
}
clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept))

baseClient := NewClientBaseParams{ncp.Headers, ncp.Host, ncp.RestClient, ncp.SourceTag}
baseClientOptions := baseClient.buildClientBaseOptions()
clientOptions = append(clientOptions, baseClientOptions...)

return clientOptions, nil
}

func (ncbp *NewClientBaseParams) buildClientBaseOptions() []control.ClientOption {
clientOptions := []control.ClientOption{}

envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS")

// build and apply user agent
fmt.Printf("Source tag on its way in: %v\n", ncbp.SourceTag)
userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncbp.SourceTag))
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept))

// apply headers from parameters if passed, otherwise use environment headers
if ncbp.Headers != nil {
for key, value := range ncbp.Headers {
headerProvider := provider.NewHeaderProvider(key, value)
clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}
} else if hasEnvAdditionalHeaders {
additionalHeaders := make(map[string]string)
err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders)
if err != nil {
log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err)
} else {
for header, value := range additionalHeaders {
headerProvider := provider.NewHeaderProvider(header, value)
clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}
}
}

if ncbp.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(ncbp.RestClient))
}

return clientOptions
}

func ensureHTTPS(inputURL string) (string, error) {
parsedURL, err := url.Parse(inputURL)
if err != nil {
Expand All @@ -539,3 +532,17 @@ func valueOrFallback[T comparable](value, fallback T) T {
return fallback
}
}

func derefOrDefault[T any](ptr *T, defaultValue T) T {
if ptr == nil {
return defaultValue
}
return *ptr
}

func minOne(x int32) int32 {
if x < 1 {
return 1
}
return x
}
96 changes: 25 additions & 71 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"os"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -39,16 +38,13 @@ func (ts *ClientTests) SetupSuite() {
require.NotEmpty(ts.T(), ts.serverlessIndex, "TEST_SERVERLESS_INDEX_NAME env variable not set")

client, err := NewClient(NewClientParams{ApiKey: apiKey})
if err != nil {
ts.FailNow(err.Error())
}
require.NoError(ts.T(), err)

ts.client = *client

ts.sourceTag = "test_source_tag"
clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: ts.sourceTag})
if err != nil {
ts.FailNow(err.Error())
}
require.NoError(ts.T(), err)
ts.clientSourceTag = *clientSourceTag

// this will clean up the project deleting all indexes and collections that are
Expand All @@ -61,21 +57,14 @@ func (ts *ClientTests) SetupSuite() {
func (ts *ClientTests) TestNewClientParamsSet() {
apiKey := "test-api-key"
client, err := NewClient(NewClientParams{ApiKey: apiKey})
if err != nil {
ts.FailNow(err.Error())
}
if client.apiKey != apiKey {
ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey))
}
if client.sourceTag != "" {
ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag))
}
if client.headers != nil {
ts.FailNow(fmt.Sprintf("Expected client headers to be nil, but got '%v'", client.headers))
}
if len(client.restClient.RequestEditors) != 2 {
ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors))
}

require.NoError(ts.T(), err)
require.Empty(ts.T(), client.sourceTag, "Expected client to have empty sourceTag")
require.NotNil(ts.T(), client.headers, "Expected client headers to not be nil")
apiKeyHeader, ok := client.headers["Api-Key"]
require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header")
require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey")
require.Equal(ts.T(), 2, len(client.restClient.RequestEditors), "Expected client to have correct number of require editors")
}

func (ts *ClientTests) TestNewClientParamsSetSourceTag() {
Expand All @@ -85,36 +74,26 @@ func (ts *ClientTests) TestNewClientParamsSetSourceTag() {
ApiKey: apiKey,
SourceTag: sourceTag,
})
if err != nil {
ts.FailNow(err.Error())
}
if client.apiKey != apiKey {
ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey))
}
if client.sourceTag != sourceTag {
ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag))
}
if len(client.restClient.RequestEditors) != 2 {
ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors))
}

require.NoError(ts.T(), err)
apiKeyHeader, ok := client.headers["Api-Key"]
require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header")
require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey")
require.Equal(ts.T(), sourceTag, client.sourceTag, "Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)
require.Equal(ts.T(), 2, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 2, len(client.restClient.RequestEditors))
}

func (ts *ClientTests) TestNewClientParamsSetHeaders() {
apiKey := "test-api-key"
headers := map[string]string{"test-header": "test-value"}
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers})
if err != nil {
ts.FailNow(err.Error())
}
if client.apiKey != apiKey {
ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey))
}
if !reflect.DeepEqual(client.headers, headers) {
ts.FailNow(fmt.Sprintf("Expected client to have headers '%+v', but got '%+v'", headers, client.headers))
}
if len(client.restClient.RequestEditors) != 3 {
ts.FailNow(fmt.Sprintf("Expected client to have '%v' request editors, but got '%v'", 3, len(client.restClient.RequestEditors)))
}

require.NoError(ts.T(), err)
apiKeyHeader, ok := client.headers["Api-Key"]
require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header")
require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey")
require.Equal(ts.T(), client.headers, headers, "Expected client to have headers '%+v', but got '%+v'", headers, client.headers)
require.Equal(ts.T(), 3, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 3, len(client.restClient.RequestEditors))
}

func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() {
Expand Down Expand Up @@ -191,31 +170,6 @@ func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() {
os.Unsetenv("PINECONE_ADDITIONAL_HEADERS")
}

// func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() {
// apiKey := "test-api-key"
// headers := map[string]string{"Authorization": "bearer fooo"}

// httpClient := mocks.CreateMockClient(`{"indexes": []}`)
// client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient})
// if err != nil {
// ts.FailNow(err.Error())
// }
// mockTransport := httpClient.Transport.(*mocks.MockTransport)

// _, err = client.ListIndexes(context.Background())
// require.NoError(ts.T(), err)
// require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

// apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key")
// authHeaderValue := mockTransport.Req.Header.Get("Authorization")
// if authHeaderValue != "bearer fooo" {
// ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue))
// }
// if apiKeyHeaderValue != "" {
// ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue))
// }
// }

func (ts *ClientTests) TestClientReadsApiKeyFromEnv() {
os.Setenv("PINECONE_API_KEY", "test-env-api-key")

Expand Down
Loading

0 comments on commit e6ba260

Please sign in to comment.