diff --git a/.env.example b/.env.example index 846feee..de982a6 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,3 @@ -API_KEY="" +PINECONE_API_KEY="" TEST_POD_INDEX_NAME="" TEST_SERVERLESS_INDEX_NAME="" diff --git a/pinecone/client.go b/pinecone/client.go index b34e017..d199f57 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -5,7 +5,9 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" + "os" "strings" "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" @@ -29,7 +31,7 @@ type NewClientParams struct { } func NewClient(in NewClientParams) (*Client, error) { - clientOptions, err := buildClientOptions(in) + clientOptions, err := in.buildClientOptions() if err != nil { return nil, err } @@ -421,26 +423,62 @@ func derefOrDefault[T any](ptr *T, defaultValue T) T { return *ptr } -func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { +func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) { clientOptions := []control.ClientOption{} + osApiKey := os.Getenv("PINECONE_API_KEY") + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") hasAuthorizationHeader := false - hasApiKey := in.ApiKey != "" + hasApiKey := (ncp.ApiKey != "" || osApiKey != "") - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncp.SourceTag)) clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) - for key, value := range in.Headers { - headerProvider := provider.NewHeaderProvider(key, value) + // apply headers from parameters if passed, otherwise use environment headers + if ncp.Headers != nil { + for key, value := range ncp.Headers { + headerProvider := provider.NewHeaderProvider(key, value) - if strings.Contains(strings.ToLower(key), "authorization") { - hasAuthorizationHeader = true + if strings.Contains(strings.ToLower(key), "authorization") { + hasAuthorizationHeader = true + } + + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) } + } else if hasEnvAdditionalHeaders { + additionalHeaders := make(map[string]string) + err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders) + if err != nil { + log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err) + } else { + for header, value := range additionalHeaders { + headerProvider := provider.NewHeaderProvider(header, value) - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + if strings.Contains(strings.ToLower(header), "authorization") { + hasAuthorizationHeader = true + } + + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + } + } } - if !hasAuthorizationHeader { - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) + // if apiKey is provided and no auth header is set, add the apiKey as a header + // apiKey from parameters takes precedence over apiKey from environment + if hasApiKey && !hasAuthorizationHeader { + + fmt.Printf("OS API KEY: %s\n", osApiKey) + fmt.Printf("NCP PARAMS API KEY: %s\n", ncp.ApiKey) + + var appliedApiKey string + if ncp.ApiKey != "" { + appliedApiKey = ncp.ApiKey + fmt.Printf("ncp key applied") + } else { + appliedApiKey = osApiKey + fmt.Printf("os key applied") + } + + apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) if err != nil { return nil, err } @@ -451,8 +489,8 @@ func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") } - if in.RestClient != nil { - clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) + if ncp.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(ncp.RestClient)) } return clientOptions, nil diff --git a/pinecone/client_test.go b/pinecone/client_test.go index ecb8bbb..a2acd66 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/pinecone-io/go-pinecone/internal/mocks" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -28,8 +29,8 @@ func TestClient(t *testing.T) { } func (ts *ClientTests) SetupSuite() { - apiKey := os.Getenv("API_KEY") - require.NotEmpty(ts.T(), apiKey, "API_KEY env variable not set") + apiKey := os.Getenv("INTEGRATION_PINECONE_API_KEY") + require.NotEmpty(ts.T(), apiKey, "INTEGRATION_PINECONE_API_KEY env variable not set") ts.podIndex = os.Getenv("TEST_POD_INDEX_NAME") require.NotEmpty(ts.T(), ts.podIndex, "TEST_POD_INDEX_NAME env variable not set") @@ -139,9 +140,52 @@ func (ts *ClientTests) TestHeadersAppliedToRequests() { require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") testHeaderValue := mockTransport.Req.Header.Get("test-header") - if testHeaderValue != "123456" { - ts.FailNow(fmt.Sprintf("Expected request to have header value '123456', but got '%s'", testHeaderValue)) + assert.Equal(ts.T(), "123456", testHeaderValue, "Expected request to have header value '123456', but got '%s'", testHeaderValue) +} + +func (ts *ClientTests) TestAdditionalHeadersAppliedToRequest() { + os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`) + + apiKey := "test-api-key" + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + testHeaderValue := mockTransport.Req.Header.Get("test-header") + assert.Equal(ts.T(), "environment-header", testHeaderValue, "Expected request to have header value 'environment-header', but got '%s'", testHeaderValue) + + os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") +} + +func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() { + os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`) + + apiKey := "test-api-key" + headers := map[string]string{"test-header": "param-header"} + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + testHeaderValue := mockTransport.Req.Header.Get("test-header") + assert.Equal(ts.T(), "param-header", testHeaderValue, "Expected request to have header value 'param-header', but got '%s'", testHeaderValue) + + os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { @@ -169,6 +213,26 @@ func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { } } +func (ts *ClientTests) TestClientReadsApiKeyFromEnv() { + os.Setenv("PINECONE_API_KEY", "test-env-api-key") + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + testHeaderValue := mockTransport.Req.Header.Get("Api-Key") + assert.Equal(ts.T(), "test-env-api-key", testHeaderValue, "Expected request to have header value 'test-env-api-key', but got '%s'", testHeaderValue) + + os.Unsetenv("PINECONE_API_KEY") +} + func (ts *ClientTests) TestListIndexes() { indexes, err := ts.client.ListIndexes(context.Background()) require.NoError(ts.T(), err) diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 7589535..1e5cb46 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -28,7 +28,7 @@ type IndexConnectionTests struct { // Runs the test suite with `go test` func TestIndexConnection(t *testing.T) { - apiKey := os.Getenv("API_KEY") + apiKey := os.Getenv("INTEGRATION_PINECONE_API_KEY") assert.NotEmptyf(t, apiKey, "API_KEY env variable not set") client, err := NewClient(NewClientParams{ApiKey: apiKey})