Skip to content

Commit

Permalink
add new buildClientOptions function, do not pass the Api-Key header i…
Browse files Browse the repository at this point in the history
…f both Api-Key and Authorization header have been provided, add unit tests
  • Loading branch information
austin-denoble committed May 3, 2024
1 parent e5753fb commit 505a4ff
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 19 deletions.
62 changes: 43 additions & 19 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"strings"

"github.com/deepmap/oapi-codegen/v2/pkg/securityprovider"
"github.com/pinecone-io/go-pinecone/internal/gen/control"
Expand All @@ -21,31 +22,17 @@ type Client struct {
}

type NewClientParams struct {
ApiKey string
SourceTag string // optional
Headers map[string]string // optional
RestClient *http.Client // optional
ApiKey string // optional unless no Authorization header provided
SourceTag string // optional
Headers map[string]string // optional
RestClient *http.Client // optional
}

func NewClient(in NewClientParams) (*Client, error) {
clientOptions := []control.ClientOption{}
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey)
clientOptions, err := buildClientOptions(in)
if err != nil {
return nil, err
}
clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept))

userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag))
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept))

for key, value := range in.Headers {
headerProvider := provider.NewHeaderProvider(key, value)
clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}

if in.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient))
}

client, err := control.NewClient("https://api.pinecone.io", clientOptions...)
if err != nil {
Expand Down Expand Up @@ -425,3 +412,40 @@ func minOne(x int32) int32 {
}
return x
}

func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) {
clientOptions := []control.ClientOption{}
hasAuthorizationHeader := false
hasApiKey := in.ApiKey != ""

userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag))
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept))

for key, value := range in.Headers {
headerProvider := provider.NewHeaderProvider(key, value)

if strings.Contains(key, "Authorization") {
hasAuthorizationHeader = true
}

clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}

if !hasAuthorizationHeader {
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey)
if err != nil {
return nil, err
}
clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept))
}

if !hasAuthorizationHeader && !hasApiKey {
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization")
}

if in.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient))
}

return clientOptions, nil
}
36 changes: 36 additions & 0 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"reflect"
"strings"
"testing"

"github.com/google/uuid"
Expand Down Expand Up @@ -112,6 +113,16 @@ func (ts *ClientTests) TestNewClientParamsSetHeaders() {
}
}

func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() {
client, err := NewClient(NewClientParams{})
require.NotNil(ts.T(), err, "Expected error when creating client without an API key or Authorization header")
if !strings.Contains(err.Error(), "no API key provided, please pass an API key for authorization") {
ts.FailNow(fmt.Sprintf("Expected error to contain 'no API key provided, please pass an API key for authorization', but got '%s'", err.Error()))
}

require.Nil(ts.T(), client, "Expected client to be nil when creating client without an API key or Authorization header")
}

func (ts *ClientTests) TestHeadersAppliedToRequests() {
apiKey := "test-api-key"
headers := map[string]string{"test-header": "123456"}
Expand All @@ -133,6 +144,31 @@ func (ts *ClientTests) TestHeadersAppliedToRequests() {
}
}

func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() {
apiKey := "test-api-key"
headers := map[string]string{"Authorization": "bearer abcd123467890"}

httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key")
authHeaderValue := mockTransport.Req.Header.Get("Authorization")
if authHeaderValue != "bearer abcd123467890" {
ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer abcd123467890', but got '%s'", authHeaderValue))
}
if apiKeyHeaderValue != "" {
ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue))
}
}

func (ts *ClientTests) TestListIndexes() {
indexes, err := ts.client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
Expand Down

0 comments on commit 505a4ff

Please sign in to comment.