Skip to content

Commit

Permalink
add ability to provide apiKey and additional headers through environm…
Browse files Browse the repository at this point in the history
…ent variables, add logic for using passed params over environment variables where applicable, add new unit tests
  • Loading branch information
austin-denoble committed May 8, 2024
1 parent cbb7f75 commit 8eac891
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
API_KEY="<Project API Key>"
PINECONE_API_KEY="<Project API Key>"
TEST_POD_INDEX_NAME="<Pod based Index name>"
TEST_SERVERLESS_INDEX_NAME="<Serverless based Index name>"
64 changes: 51 additions & 13 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"

"github.com/deepmap/oapi-codegen/v2/pkg/securityprovider"
Expand All @@ -29,7 +31,7 @@ type NewClientParams struct {
}

func NewClient(in NewClientParams) (*Client, error) {
clientOptions, err := buildClientOptions(in)
clientOptions, err := in.buildClientOptions()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -421,26 +423,62 @@ func derefOrDefault[T any](ptr *T, defaultValue T) T {
return *ptr
}

func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) {
func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) {
clientOptions := []control.ClientOption{}
osApiKey := os.Getenv("PINECONE_API_KEY")
envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS")
hasAuthorizationHeader := false
hasApiKey := in.ApiKey != ""
hasApiKey := (ncp.ApiKey != "" || osApiKey != "")

userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag))
userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncp.SourceTag))
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept))

for key, value := range in.Headers {
headerProvider := provider.NewHeaderProvider(key, value)
// apply headers from parameters if passed, otherwise use environment headers
if ncp.Headers != nil {
for key, value := range ncp.Headers {
headerProvider := provider.NewHeaderProvider(key, value)

if strings.Contains(strings.ToLower(key), "authorization") {
hasAuthorizationHeader = true
if strings.Contains(strings.ToLower(key), "authorization") {
hasAuthorizationHeader = true
}

clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}
} else if hasEnvAdditionalHeaders {
additionalHeaders := make(map[string]string)
err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders)
if err != nil {
log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err)
} else {
for header, value := range additionalHeaders {
headerProvider := provider.NewHeaderProvider(header, value)

clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
if strings.Contains(strings.ToLower(header), "authorization") {
hasAuthorizationHeader = true
}

clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}
}
}

if !hasAuthorizationHeader {
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey)
// if apiKey is provided and no auth header is set, add the apiKey as a header
// apiKey from parameters takes precedence over apiKey from environment
if hasApiKey && !hasAuthorizationHeader {

fmt.Printf("OS API KEY: %s\n", osApiKey)
fmt.Printf("NCP PARAMS API KEY: %s\n", ncp.ApiKey)

var appliedApiKey string
if ncp.ApiKey != "" {
appliedApiKey = ncp.ApiKey
fmt.Printf("ncp key applied")
} else {
appliedApiKey = osApiKey
fmt.Printf("os key applied")
}

apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey)
if err != nil {
return nil, err
}
Expand All @@ -451,8 +489,8 @@ func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) {
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization")
}

if in.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient))
if ncp.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(ncp.RestClient))
}

return clientOptions, nil
Expand Down
72 changes: 68 additions & 4 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/google/uuid"
"github.com/pinecone-io/go-pinecone/internal/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
Expand All @@ -28,8 +29,8 @@ func TestClient(t *testing.T) {
}

func (ts *ClientTests) SetupSuite() {
apiKey := os.Getenv("API_KEY")
require.NotEmpty(ts.T(), apiKey, "API_KEY env variable not set")
apiKey := os.Getenv("INTEGRATION_PINECONE_API_KEY")
require.NotEmpty(ts.T(), apiKey, "INTEGRATION_PINECONE_API_KEY env variable not set")

ts.podIndex = os.Getenv("TEST_POD_INDEX_NAME")
require.NotEmpty(ts.T(), ts.podIndex, "TEST_POD_INDEX_NAME env variable not set")
Expand Down Expand Up @@ -139,9 +140,52 @@ func (ts *ClientTests) TestHeadersAppliedToRequests() {
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

testHeaderValue := mockTransport.Req.Header.Get("test-header")
if testHeaderValue != "123456" {
ts.FailNow(fmt.Sprintf("Expected request to have header value '123456', but got '%s'", testHeaderValue))
assert.Equal(ts.T(), "123456", testHeaderValue, "Expected request to have header value '123456', but got '%s'", testHeaderValue)
}

func (ts *ClientTests) TestAdditionalHeadersAppliedToRequest() {
os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`)

apiKey := "test-api-key"

httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

testHeaderValue := mockTransport.Req.Header.Get("test-header")
assert.Equal(ts.T(), "environment-header", testHeaderValue, "Expected request to have header value 'environment-header', but got '%s'", testHeaderValue)

os.Unsetenv("PINECONE_ADDITIONAL_HEADERS")
}

func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() {
os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`)

apiKey := "test-api-key"
headers := map[string]string{"test-header": "param-header"}

httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

testHeaderValue := mockTransport.Req.Header.Get("test-header")
assert.Equal(ts.T(), "param-header", testHeaderValue, "Expected request to have header value 'param-header', but got '%s'", testHeaderValue)

os.Unsetenv("PINECONE_ADDITIONAL_HEADERS")
}

func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() {
Expand Down Expand Up @@ -169,6 +213,26 @@ func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() {
}
}

func (ts *ClientTests) TestClientReadsApiKeyFromEnv() {
os.Setenv("PINECONE_API_KEY", "test-env-api-key")

httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

testHeaderValue := mockTransport.Req.Header.Get("Api-Key")
assert.Equal(ts.T(), "test-env-api-key", testHeaderValue, "Expected request to have header value 'test-env-api-key', but got '%s'", testHeaderValue)

os.Unsetenv("PINECONE_API_KEY")
}

func (ts *ClientTests) TestListIndexes() {
indexes, err := ts.client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
Expand Down
2 changes: 1 addition & 1 deletion pinecone/index_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type IndexConnectionTests struct {

// Runs the test suite with `go test`
func TestIndexConnection(t *testing.T) {
apiKey := os.Getenv("API_KEY")
apiKey := os.Getenv("INTEGRATION_PINECONE_API_KEY")
assert.NotEmptyf(t, apiKey, "API_KEY env variable not set")

client, err := NewClient(NewClientParams{ApiKey: apiKey})
Expand Down

0 comments on commit 8eac891

Please sign in to comment.