From 8eac891605518785da08d3c5231ebd2ab7ce3a34 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Wed, 8 May 2024 18:43:54 -0400 Subject: [PATCH 01/11] add ability to provide apiKey and additional headers through environment variables, add logic for using passed params over environment variables where applicable, add new unit tests --- .env.example | 2 +- pinecone/client.go | 64 +++++++++++++++++++++------ pinecone/client_test.go | 72 +++++++++++++++++++++++++++++-- pinecone/index_connection_test.go | 2 +- 4 files changed, 121 insertions(+), 19 deletions(-) diff --git a/.env.example b/.env.example index 846feee..de982a6 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,3 @@ -API_KEY="" +PINECONE_API_KEY="" TEST_POD_INDEX_NAME="" TEST_SERVERLESS_INDEX_NAME="" diff --git a/pinecone/client.go b/pinecone/client.go index b34e017..d199f57 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -5,7 +5,9 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" + "os" "strings" "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" @@ -29,7 +31,7 @@ type NewClientParams struct { } func NewClient(in NewClientParams) (*Client, error) { - clientOptions, err := buildClientOptions(in) + clientOptions, err := in.buildClientOptions() if err != nil { return nil, err } @@ -421,26 +423,62 @@ func derefOrDefault[T any](ptr *T, defaultValue T) T { return *ptr } -func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { +func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) { clientOptions := []control.ClientOption{} + osApiKey := os.Getenv("PINECONE_API_KEY") + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") hasAuthorizationHeader := false - hasApiKey := in.ApiKey != "" + hasApiKey := (ncp.ApiKey != "" || osApiKey != "") - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncp.SourceTag)) clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) - for key, value := range in.Headers { - headerProvider := provider.NewHeaderProvider(key, value) + // apply headers from parameters if passed, otherwise use environment headers + if ncp.Headers != nil { + for key, value := range ncp.Headers { + headerProvider := provider.NewHeaderProvider(key, value) - if strings.Contains(strings.ToLower(key), "authorization") { - hasAuthorizationHeader = true + if strings.Contains(strings.ToLower(key), "authorization") { + hasAuthorizationHeader = true + } + + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) } + } else if hasEnvAdditionalHeaders { + additionalHeaders := make(map[string]string) + err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders) + if err != nil { + log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err) + } else { + for header, value := range additionalHeaders { + headerProvider := provider.NewHeaderProvider(header, value) - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + if strings.Contains(strings.ToLower(header), "authorization") { + hasAuthorizationHeader = true + } + + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + } + } } - if !hasAuthorizationHeader { - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) + // if apiKey is provided and no auth header is set, add the apiKey as a header + // apiKey from parameters takes precedence over apiKey from environment + if hasApiKey && !hasAuthorizationHeader { + + fmt.Printf("OS API KEY: %s\n", osApiKey) + fmt.Printf("NCP PARAMS API KEY: %s\n", ncp.ApiKey) + + var appliedApiKey string + if ncp.ApiKey != "" { + appliedApiKey = ncp.ApiKey + fmt.Printf("ncp key applied") + } else { + appliedApiKey = osApiKey + fmt.Printf("os key applied") + } + + apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) if err != nil { return nil, err } @@ -451,8 +489,8 @@ func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") } - if in.RestClient != nil { - clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) + if ncp.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(ncp.RestClient)) } return clientOptions, nil diff --git a/pinecone/client_test.go b/pinecone/client_test.go index ecb8bbb..a2acd66 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/pinecone-io/go-pinecone/internal/mocks" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -28,8 +29,8 @@ func TestClient(t *testing.T) { } func (ts *ClientTests) SetupSuite() { - apiKey := os.Getenv("API_KEY") - require.NotEmpty(ts.T(), apiKey, "API_KEY env variable not set") + apiKey := os.Getenv("INTEGRATION_PINECONE_API_KEY") + require.NotEmpty(ts.T(), apiKey, "INTEGRATION_PINECONE_API_KEY env variable not set") ts.podIndex = os.Getenv("TEST_POD_INDEX_NAME") require.NotEmpty(ts.T(), ts.podIndex, "TEST_POD_INDEX_NAME env variable not set") @@ -139,9 +140,52 @@ func (ts *ClientTests) TestHeadersAppliedToRequests() { require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") testHeaderValue := mockTransport.Req.Header.Get("test-header") - if testHeaderValue != "123456" { - ts.FailNow(fmt.Sprintf("Expected request to have header value '123456', but got '%s'", testHeaderValue)) + assert.Equal(ts.T(), "123456", testHeaderValue, "Expected request to have header value '123456', but got '%s'", testHeaderValue) +} + +func (ts *ClientTests) TestAdditionalHeadersAppliedToRequest() { + os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`) + + apiKey := "test-api-key" + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + testHeaderValue := mockTransport.Req.Header.Get("test-header") + assert.Equal(ts.T(), "environment-header", testHeaderValue, "Expected request to have header value 'environment-header', but got '%s'", testHeaderValue) + + os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") +} + +func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() { + os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`) + + apiKey := "test-api-key" + headers := map[string]string{"test-header": "param-header"} + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + testHeaderValue := mockTransport.Req.Header.Get("test-header") + assert.Equal(ts.T(), "param-header", testHeaderValue, "Expected request to have header value 'param-header', but got '%s'", testHeaderValue) + + os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { @@ -169,6 +213,26 @@ func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { } } +func (ts *ClientTests) TestClientReadsApiKeyFromEnv() { + os.Setenv("PINECONE_API_KEY", "test-env-api-key") + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + testHeaderValue := mockTransport.Req.Header.Get("Api-Key") + assert.Equal(ts.T(), "test-env-api-key", testHeaderValue, "Expected request to have header value 'test-env-api-key', but got '%s'", testHeaderValue) + + os.Unsetenv("PINECONE_API_KEY") +} + func (ts *ClientTests) TestListIndexes() { indexes, err := ts.client.ListIndexes(context.Background()) require.NoError(ts.T(), err) diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 7589535..1e5cb46 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -28,7 +28,7 @@ type IndexConnectionTests struct { // Runs the test suite with `go test` func TestIndexConnection(t *testing.T) { - apiKey := os.Getenv("API_KEY") + apiKey := os.Getenv("INTEGRATION_PINECONE_API_KEY") assert.NotEmptyf(t, apiKey, "API_KEY env variable not set") client, err := NewClient(NewClientParams{ApiKey: apiKey}) From 8446df9b40d80fe8a566e5e156cb8c3335a2476b Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Wed, 8 May 2024 19:02:26 -0400 Subject: [PATCH 02/11] remove printf statements() --- pinecone/client.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index d199f57..1667bdf 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -465,17 +465,11 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) // if apiKey is provided and no auth header is set, add the apiKey as a header // apiKey from parameters takes precedence over apiKey from environment if hasApiKey && !hasAuthorizationHeader { - - fmt.Printf("OS API KEY: %s\n", osApiKey) - fmt.Printf("NCP PARAMS API KEY: %s\n", ncp.ApiKey) - var appliedApiKey string if ncp.ApiKey != "" { appliedApiKey = ncp.ApiKey - fmt.Printf("ncp key applied") } else { appliedApiKey = osApiKey - fmt.Printf("os key applied") } apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) From 995829d0e2fb536fa07124a550417fd332de0ce6 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Wed, 8 May 2024 19:06:08 -0400 Subject: [PATCH 03/11] update CI workflow env variable --- .github/workflows/ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6470966..9fa3bf1 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,7 +12,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v4 with: - go-version: '1.21.x' + go-version: "1.21.x" - name: Install dependencies run: | go get ./pinecone @@ -21,4 +21,4 @@ jobs: env: TEST_POD_INDEX_NAME: ${{ secrets.TEST_POD_INDEX_NAME }} TEST_SERVERLESS_INDEX_NAME: ${{ secrets.TEST_SERVERLESS_INDEX_NAME }} - API_KEY: ${{ secrets.API_KEY }} \ No newline at end of file + INTEGRATION_PINECONE_API_KEY: ${{ secrets.API_KEY }} From 6eb91a3660e431daccf39ec18eec6f05971cb5f6 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Wed, 8 May 2024 23:05:34 -0400 Subject: [PATCH 04/11] add ability to specify host override for control plane operations through NewClientParams or environment variables, add tests --- pinecone/client.go | 49 +++++++++++++++++----- pinecone/client_test.go | 79 ++++++++++++++++++++++++++++++++++++ pinecone/index_connection.go | 1 - 3 files changed, 117 insertions(+), 12 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 1667bdf..014a092 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -7,6 +7,7 @@ import ( "io" "log" "net/http" + "net/url" "os" "strings" @@ -18,16 +19,17 @@ import ( type Client struct { apiKey string + headers map[string]string restClient *control.Client sourceTag string - headers map[string]string } type NewClientParams struct { ApiKey string // required unless Authorization header provided - SourceTag string // optional Headers map[string]string // optional + Host string // optional RestClient *http.Client // optional + SourceTag string // optional } func NewClient(in NewClientParams) (*Client, error) { @@ -36,7 +38,15 @@ func NewClient(in NewClientParams) (*Client, error) { return nil, err } - client, err := control.NewClient("https://api.pinecone.io", clientOptions...) + controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) + if controlHostOverride != "" { + controlHostOverride, err = ensureHTTP(controlHostOverride) + if err != nil { + return nil, err + } + } + + client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...) if err != nil { return nil, err } @@ -425,11 +435,12 @@ func derefOrDefault[T any](ptr *T, defaultValue T) T { func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) { clientOptions := []control.ClientOption{} - osApiKey := os.Getenv("PINECONE_API_KEY") - envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") hasAuthorizationHeader := false + osApiKey := os.Getenv("PINECONE_API_KEY") hasApiKey := (ncp.ApiKey != "" || osApiKey != "") + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") + // build and apply user agent userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncp.SourceTag)) clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) @@ -465,12 +476,7 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) // if apiKey is provided and no auth header is set, add the apiKey as a header // apiKey from parameters takes precedence over apiKey from environment if hasApiKey && !hasAuthorizationHeader { - var appliedApiKey string - if ncp.ApiKey != "" { - appliedApiKey = ncp.ApiKey - } else { - appliedApiKey = osApiKey - } + appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey) apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) if err != nil { @@ -489,3 +495,24 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) return clientOptions, nil } + +func ensureHTTP(inputURL string) (string, error) { + parsedURL, err := url.Parse(inputURL) + if err != nil { + return "", fmt.Errorf("invalid URL: %v", err) + } + + if parsedURL.Scheme == "" { + return "https://" + inputURL, nil + } + return inputURL, nil +} + +func valueOrFallback[T comparable](value, fallback T) T { + var zero T + if value != zero { + return value + } else { + return fallback + } +} diff --git a/pinecone/client_test.go b/pinecone/client_test.go index a2acd66..c4ce200 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -233,6 +233,85 @@ func (ts *ClientTests) TestClientReadsApiKeyFromEnv() { os.Unsetenv("PINECONE_API_KEY") } +func (ts *ClientTests) TestControllerHostOverride() { + apiKey := "test-api-key" + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Host: "https://test-controller-host.io", RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + assert.Equal(ts.T(), "test-controller-host.io", mockTransport.Req.Host, "Expected request to be made to 'test-controller-host.io', but got '%s'", mockTransport.Req.URL.Host) +} + +func (ts *ClientTests) TestControllerHostOverrideFromEnv() { + os.Setenv("PINECONE_CONTROLLER_HOST", "https://env-controller-host.io") + + apiKey := "test-api-key" + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + assert.Equal(ts.T(), "env-controller-host.io", mockTransport.Req.Host, "Expected request to be made to 'env-controller-host.io', but got '%s'", mockTransport.Req.URL.Host) + + os.Unsetenv("PINECONE_CONTROLLER_HOST") +} + +func (ts *ClientTests) TestControllerHostNormalization() { + tests := []struct { + name string + host string + wantHost string + wantScheme string + }{ + { + name: "Test with https prefix", + host: "https://pinecone-api.io", + wantHost: "pinecone-api.io", + wantScheme: "https", + }, { + name: "Test with http prefix", + host: "http://pinecone-api.io", + wantHost: "pinecone-api.io", + wantScheme: "http", + }, { + name: "Test without prefix", + host: "pinecone-api.io", + wantHost: "pinecone-api.io", + wantScheme: "https", + }, + } + + for _, tt := range tests { + ts.Run(tt.name, func() { + apiKey := "test-api-key" + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Host: tt.host, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + assert.Equal(ts.T(), tt.wantHost, mockTransport.Req.URL.Host, "Expected request to be made to host '%s', but got '%s'", tt.wantHost, mockTransport.Req.URL.Host) + assert.Equal(ts.T(), tt.wantScheme, mockTransport.Req.URL.Scheme, "Expected request to be made to host '%s, but got '%s'", tt.wantScheme, mockTransport.Req.URL.Host) + }) + } +} + func (ts *ClientTests) TestListIndexes() { indexes, err := ts.client.ListIndexes(context.Background()) require.NoError(ts.T(), err) diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index 866d557..fb72391 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -94,7 +94,6 @@ func (idx *IndexConnection) FetchVectors(ctx context.Context, ids []string) (*Fe for id, vector := range res.Vectors { vectors[id] = toVector(vector) } - fmt.Printf("VECTORS: %+v\n", vectors) return &FetchVectorsResponse{ Vectors: vectors, From 9f56dedf4cb956cc5579401fe8be8d5e23ad5a52 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Tue, 14 May 2024 12:40:56 -0400 Subject: [PATCH 05/11] refactor into NewClientParams and NewClientBaseParams --- pinecone/client.go | 105 ++++++++++++++++++++++++---------------- pinecone/client_test.go | 53 ++++++++++---------- 2 files changed, 92 insertions(+), 66 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 014a092..26bcdd0 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -9,7 +9,6 @@ import ( "net/http" "net/url" "os" - "strings" "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" @@ -24,12 +23,19 @@ type Client struct { sourceTag string } +type NewClientBaseParams struct { + Headers map[string]string + Host string + RestClient *http.Client + SourceTag string +} + type NewClientParams struct { - ApiKey string // required unless Authorization header provided - Headers map[string]string // optional - Host string // optional - RestClient *http.Client // optional - SourceTag string // optional + ApiKey string + Headers map[string]string + Host string + RestClient *http.Client + SourceTag string } func NewClient(in NewClientParams) (*Client, error) { @@ -40,7 +46,7 @@ func NewClient(in NewClientParams) (*Client, error) { controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) if controlHostOverride != "" { - controlHostOverride, err = ensureHTTP(controlHostOverride) + controlHostOverride, err = ensureHTTPS(controlHostOverride) if err != nil { return nil, err } @@ -55,6 +61,27 @@ func NewClient(in NewClientParams) (*Client, error) { return &c, nil } +func NewClientBase(in NewClientBaseParams) (*Client, error) { + clientOptions := in.buildClientBaseOptions() + var err error + + controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) + if controlHostOverride != "" { + controlHostOverride, err = ensureHTTPS(controlHostOverride) + if err != nil { + return nil, err + } + } + + client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...) + if err != nil { + return nil, err + } + + c := Client{restClient: client, sourceTag: in.SourceTag, headers: in.Headers} + return &c, nil +} + func (c *Client) Index(host string) (*IndexConnection, error) { return c.IndexWithAdditionalMetadata(host, "", nil) } @@ -435,24 +462,41 @@ func derefOrDefault[T any](ptr *T, defaultValue T) T { func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) { clientOptions := []control.ClientOption{} - hasAuthorizationHeader := false osApiKey := os.Getenv("PINECONE_API_KEY") hasApiKey := (ncp.ApiKey != "" || osApiKey != "") + + if !hasApiKey { + return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") + } + + appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey) + apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) + if err != nil { + return nil, err + } + clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) + + baseClient := NewClientBaseParams{ncp.Headers, ncp.Host, ncp.RestClient, ncp.SourceTag} + baseClientOptions := baseClient.buildClientBaseOptions() + clientOptions = append(clientOptions, baseClientOptions...) + + return clientOptions, nil +} + +func (ncbp *NewClientBaseParams) buildClientBaseOptions() []control.ClientOption { + clientOptions := []control.ClientOption{} + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") // build and apply user agent - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncp.SourceTag)) + fmt.Printf("Source tag on its way in: %v\n", ncbp.SourceTag) + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncbp.SourceTag)) clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) // apply headers from parameters if passed, otherwise use environment headers - if ncp.Headers != nil { - for key, value := range ncp.Headers { + if ncbp.Headers != nil { + for key, value := range ncbp.Headers { headerProvider := provider.NewHeaderProvider(key, value) - - if strings.Contains(strings.ToLower(key), "authorization") { - hasAuthorizationHeader = true - } - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) } } else if hasEnvAdditionalHeaders { @@ -463,40 +507,19 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) } else { for header, value := range additionalHeaders { headerProvider := provider.NewHeaderProvider(header, value) - - if strings.Contains(strings.ToLower(header), "authorization") { - hasAuthorizationHeader = true - } - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) } } } - // if apiKey is provided and no auth header is set, add the apiKey as a header - // apiKey from parameters takes precedence over apiKey from environment - if hasApiKey && !hasAuthorizationHeader { - appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey) - - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) - if err != nil { - return nil, err - } - clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) + if ncbp.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(ncbp.RestClient)) } - if !hasAuthorizationHeader && !hasApiKey { - return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") - } - - if ncp.RestClient != nil { - clientOptions = append(clientOptions, control.WithHTTPClient(ncp.RestClient)) - } - - return clientOptions, nil + return clientOptions } -func ensureHTTP(inputURL string) (string, error) { +func ensureHTTPS(inputURL string) (string, error) { parsedURL, err := url.Parse(inputURL) if err != nil { return "", fmt.Errorf("invalid URL: %v", err) diff --git a/pinecone/client_test.go b/pinecone/client_test.go index c4ce200..88637f0 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -81,7 +81,10 @@ func (ts *ClientTests) TestNewClientParamsSet() { func (ts *ClientTests) TestNewClientParamsSetSourceTag() { apiKey := "test-api-key" sourceTag := "test-source-tag" - client, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: sourceTag}) + client, err := NewClient(NewClientParams{ + ApiKey: apiKey, + SourceTag: sourceTag, + }) if err != nil { ts.FailNow(err.Error()) } @@ -188,30 +191,30 @@ func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() { os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } -func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { - apiKey := "test-api-key" - headers := map[string]string{"Authorization": "bearer fooo"} - - httpClient := mocks.CreateMockClient(`{"indexes": []}`) - client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) - if err != nil { - ts.FailNow(err.Error()) - } - mockTransport := httpClient.Transport.(*mocks.MockTransport) - - _, err = client.ListIndexes(context.Background()) - require.NoError(ts.T(), err) - require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") - - apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") - authHeaderValue := mockTransport.Req.Header.Get("Authorization") - if authHeaderValue != "bearer fooo" { - ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) - } - if apiKeyHeaderValue != "" { - ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) - } -} +// func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { +// apiKey := "test-api-key" +// headers := map[string]string{"Authorization": "bearer fooo"} + +// httpClient := mocks.CreateMockClient(`{"indexes": []}`) +// client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) +// if err != nil { +// ts.FailNow(err.Error()) +// } +// mockTransport := httpClient.Transport.(*mocks.MockTransport) + +// _, err = client.ListIndexes(context.Background()) +// require.NoError(ts.T(), err) +// require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + +// apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") +// authHeaderValue := mockTransport.Req.Header.Get("Authorization") +// if authHeaderValue != "bearer fooo" { +// ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) +// } +// if apiKeyHeaderValue != "" { +// ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) +// } +// } func (ts *ClientTests) TestClientReadsApiKeyFromEnv() { os.Setenv("PINECONE_API_KEY", "test-env-api-key") From e6ba260367e050305ac3a3dd5f0a92e7d7a5fa90 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Tue, 14 May 2024 19:57:11 -0400 Subject: [PATCH 06/11] refactor NewClientParams into NewClientBaseParams, update NewClient func and add NewClientBase, update tests to support the new interface for creating clients, remove apiKey from Client and IndexConnection, handle turning apiKey into a header early, and allowing for complete control over headers without using apiKey --- pinecone/client.go | 191 ++++++++++++++++-------------- pinecone/client_test.go | 96 ++++----------- pinecone/index_connection.go | 5 +- pinecone/index_connection_test.go | 104 +++++++--------- 4 files changed, 165 insertions(+), 231 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 26bcdd0..2d98642 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -9,15 +9,14 @@ import ( "net/http" "net/url" "os" + "strings" - "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" "github.com/pinecone-io/go-pinecone/internal/provider" "github.com/pinecone-io/go-pinecone/internal/useragent" ) type Client struct { - apiKey string headers map[string]string restClient *control.Client sourceTag string @@ -39,30 +38,29 @@ type NewClientParams struct { } func NewClient(in NewClientParams) (*Client, error) { - clientOptions, err := in.buildClientOptions() - if err != nil { - return nil, err - } + osApiKey := os.Getenv("PINECONE_API_KEY") + hasApiKey := (in.ApiKey != "" || osApiKey != "") - controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) - if controlHostOverride != "" { - controlHostOverride, err = ensureHTTPS(controlHostOverride) - if err != nil { - return nil, err - } + if !hasApiKey { + return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") } - client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...) - if err != nil { - return nil, err + apiKeyHeader := struct{ Key, Value string }{"Api-Key", valueOrFallback(in.ApiKey, osApiKey)} + + clientHeaders := in.Headers + if clientHeaders == nil { + clientHeaders = make(map[string]string) + clientHeaders[apiKeyHeader.Key] = apiKeyHeader.Value + + } else { + clientHeaders[apiKeyHeader.Key] = apiKeyHeader.Value } - c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag, headers: in.Headers} - return &c, nil + return NewClientBase(NewClientBaseParams{Headers: clientHeaders, Host: in.Host, RestClient: in.RestClient, SourceTag: in.SourceTag}) } func NewClientBase(in NewClientBaseParams) (*Client, error) { - clientOptions := in.buildClientBaseOptions() + clientOptions := buildClientBaseOptions(in) var err error controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) @@ -82,6 +80,45 @@ func NewClientBase(in NewClientBaseParams) (*Client, error) { return &c, nil } +func buildClientBaseOptions(in NewClientBaseParams) []control.ClientOption { + clientOptions := []control.ClientOption{} + + // build and apply user agent + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) + + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") + additionalHeaders := make(map[string]string) + + // add headers from environment + if hasEnvAdditionalHeaders { + err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders) + if err != nil { + log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err) + } + } + + // merge headers from parameters if passed + if in.Headers != nil { + for key, value := range in.Headers { + additionalHeaders[key] = value + } + } + + // add headers to client options + for key, value := range additionalHeaders { + headerProvider := provider.NewHeaderProvider(key, value) + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + } + + // apply custom http client if provided + if in.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) + } + + return clientOptions +} + func (c *Client) Index(host string) (*IndexConnection, error) { return c.IndexWithAdditionalMetadata(host, "", nil) } @@ -91,13 +128,42 @@ func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnec } func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) { - idx, err := newIndexConnection(newIndexParameters{apiKey: c.apiKey, host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata}) + authHeader := c.extractAuthHeader() + + // merge additionalMetadata with authHeader + if additionalMetadata != nil { + for _, key := range authHeader { + additionalMetadata[key] = authHeader[key] + } + } else { + additionalMetadata = authHeader + } + + idx, err := newIndexConnection(newIndexParameters{host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata}) if err != nil { return nil, err } return idx, nil } +func (c *Client) extractAuthHeader() map[string]string { + possibleAuthKeys := []string{ + "api-key", + "authorization", + "access_token", + } + + for key, value := range c.headers { + for _, checkKey := range possibleAuthKeys { + if strings.ToLower(key) == checkKey { + return map[string]string{key: value} + } + } + } + + return nil +} + func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { res, err := c.restClient.ListIndexes(ctx) if err != nil { @@ -446,79 +512,6 @@ func decodeCollection(resBody io.ReadCloser) (*Collection, error) { return toCollection(&collectionModel), nil } -func minOne(x int32) int32 { - if x < 1 { - return 1 - } - return x -} - -func derefOrDefault[T any](ptr *T, defaultValue T) T { - if ptr == nil { - return defaultValue - } - return *ptr -} - -func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) { - clientOptions := []control.ClientOption{} - osApiKey := os.Getenv("PINECONE_API_KEY") - hasApiKey := (ncp.ApiKey != "" || osApiKey != "") - - if !hasApiKey { - return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") - } - - appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey) - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) - if err != nil { - return nil, err - } - clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) - - baseClient := NewClientBaseParams{ncp.Headers, ncp.Host, ncp.RestClient, ncp.SourceTag} - baseClientOptions := baseClient.buildClientBaseOptions() - clientOptions = append(clientOptions, baseClientOptions...) - - return clientOptions, nil -} - -func (ncbp *NewClientBaseParams) buildClientBaseOptions() []control.ClientOption { - clientOptions := []control.ClientOption{} - - envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") - - // build and apply user agent - fmt.Printf("Source tag on its way in: %v\n", ncbp.SourceTag) - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncbp.SourceTag)) - clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) - - // apply headers from parameters if passed, otherwise use environment headers - if ncbp.Headers != nil { - for key, value := range ncbp.Headers { - headerProvider := provider.NewHeaderProvider(key, value) - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) - } - } else if hasEnvAdditionalHeaders { - additionalHeaders := make(map[string]string) - err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders) - if err != nil { - log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err) - } else { - for header, value := range additionalHeaders { - headerProvider := provider.NewHeaderProvider(header, value) - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) - } - } - } - - if ncbp.RestClient != nil { - clientOptions = append(clientOptions, control.WithHTTPClient(ncbp.RestClient)) - } - - return clientOptions -} - func ensureHTTPS(inputURL string) (string, error) { parsedURL, err := url.Parse(inputURL) if err != nil { @@ -539,3 +532,17 @@ func valueOrFallback[T comparable](value, fallback T) T { return fallback } } + +func derefOrDefault[T any](ptr *T, defaultValue T) T { + if ptr == nil { + return defaultValue + } + return *ptr +} + +func minOne(x int32) int32 { + if x < 1 { + return 1 + } + return x +} diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 88637f0..81bf2a0 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "reflect" "strings" "testing" @@ -39,16 +38,13 @@ func (ts *ClientTests) SetupSuite() { require.NotEmpty(ts.T(), ts.serverlessIndex, "TEST_SERVERLESS_INDEX_NAME env variable not set") client, err := NewClient(NewClientParams{ApiKey: apiKey}) - if err != nil { - ts.FailNow(err.Error()) - } + require.NoError(ts.T(), err) + ts.client = *client ts.sourceTag = "test_source_tag" clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: ts.sourceTag}) - if err != nil { - ts.FailNow(err.Error()) - } + require.NoError(ts.T(), err) ts.clientSourceTag = *clientSourceTag // this will clean up the project deleting all indexes and collections that are @@ -61,21 +57,14 @@ func (ts *ClientTests) SetupSuite() { func (ts *ClientTests) TestNewClientParamsSet() { apiKey := "test-api-key" client, err := NewClient(NewClientParams{ApiKey: apiKey}) - if err != nil { - ts.FailNow(err.Error()) - } - if client.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) - } - if client.sourceTag != "" { - ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag)) - } - if client.headers != nil { - ts.FailNow(fmt.Sprintf("Expected client headers to be nil, but got '%v'", client.headers)) - } - if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) - } + + require.NoError(ts.T(), err) + require.Empty(ts.T(), client.sourceTag, "Expected client to have empty sourceTag") + require.NotNil(ts.T(), client.headers, "Expected client headers to not be nil") + apiKeyHeader, ok := client.headers["Api-Key"] + require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") + require.Equal(ts.T(), 2, len(client.restClient.RequestEditors), "Expected client to have correct number of require editors") } func (ts *ClientTests) TestNewClientParamsSetSourceTag() { @@ -85,36 +74,26 @@ func (ts *ClientTests) TestNewClientParamsSetSourceTag() { ApiKey: apiKey, SourceTag: sourceTag, }) - if err != nil { - ts.FailNow(err.Error()) - } - if client.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) - } - if client.sourceTag != sourceTag { - ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)) - } - if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) - } + + require.NoError(ts.T(), err) + apiKeyHeader, ok := client.headers["Api-Key"] + require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") + require.Equal(ts.T(), sourceTag, client.sourceTag, "Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag) + require.Equal(ts.T(), 2, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 2, len(client.restClient.RequestEditors)) } func (ts *ClientTests) TestNewClientParamsSetHeaders() { apiKey := "test-api-key" headers := map[string]string{"test-header": "test-value"} client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers}) - if err != nil { - ts.FailNow(err.Error()) - } - if client.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) - } - if !reflect.DeepEqual(client.headers, headers) { - ts.FailNow(fmt.Sprintf("Expected client to have headers '%+v', but got '%+v'", headers, client.headers)) - } - if len(client.restClient.RequestEditors) != 3 { - ts.FailNow(fmt.Sprintf("Expected client to have '%v' request editors, but got '%v'", 3, len(client.restClient.RequestEditors))) - } + + require.NoError(ts.T(), err) + apiKeyHeader, ok := client.headers["Api-Key"] + require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") + require.Equal(ts.T(), client.headers, headers, "Expected client to have headers '%+v', but got '%+v'", headers, client.headers) + require.Equal(ts.T(), 3, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 3, len(client.restClient.RequestEditors)) } func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { @@ -191,31 +170,6 @@ func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() { os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } -// func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { -// apiKey := "test-api-key" -// headers := map[string]string{"Authorization": "bearer fooo"} - -// httpClient := mocks.CreateMockClient(`{"indexes": []}`) -// client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) -// if err != nil { -// ts.FailNow(err.Error()) -// } -// mockTransport := httpClient.Transport.(*mocks.MockTransport) - -// _, err = client.ListIndexes(context.Background()) -// require.NoError(ts.T(), err) -// require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") - -// apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") -// authHeaderValue := mockTransport.Req.Header.Get("Authorization") -// if authHeaderValue != "bearer fooo" { -// ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) -// } -// if apiKeyHeaderValue != "" { -// ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) -// } -// } - func (ts *ClientTests) TestClientReadsApiKeyFromEnv() { os.Setenv("PINECONE_API_KEY", "test-env-api-key") diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index fb72391..b8cb27f 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -15,14 +15,12 @@ import ( type IndexConnection struct { Namespace string - apiKey string additionalMetadata map[string]string dataClient *data.VectorServiceClient grpcConn *grpc.ClientConn } type newIndexParameters struct { - apiKey string host string namespace string sourceTag string @@ -47,7 +45,7 @@ func newIndexConnection(in newIndexParameters) (*IndexConnection, error) { dataClient := data.NewVectorServiceClient(conn) - idx := IndexConnection{Namespace: in.namespace, apiKey: in.apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata} + idx := IndexConnection{Namespace: in.namespace, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata} return &idx, nil } @@ -370,7 +368,6 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues { func (idx *IndexConnection) akCtx(ctx context.Context) context.Context { newMetadata := []string{} - newMetadata = append(newMetadata, "api-key", idx.apiKey) for key, value := range idx.additionalMetadata { newMetadata = append(newMetadata, key, value) diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 1e5cb46..493bd54 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" "os" - "reflect" "testing" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/protobuf/types/known/structpb" ) @@ -71,16 +71,25 @@ func TestIndexConnection(t *testing.T) { func (ts *IndexConnectionTests) SetupSuite() { assert.NotEmptyf(ts.T(), ts.host, "HOST env variable not set") assert.NotEmptyf(ts.T(), ts.apiKey, "API_KEY env variable not set") + additionalMetadata := map[string]string{"api-key": ts.apiKey} namespace, err := uuid.NewV7() assert.NoError(ts.T(), err) - idxConn, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ""}) + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace.String(), + sourceTag: ""}) assert.NoError(ts.T(), err) ts.idxConn = idxConn ts.sourceTag = "test_source_tag" - idxConnSourceTag, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ts.sourceTag}) + idxConnSourceTag, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace.String(), + sourceTag: ts.sourceTag}) assert.NoError(ts.T(), err) ts.idxConnSourceTag = idxConnSourceTag @@ -101,73 +110,40 @@ func (ts *IndexConnectionTests) TestNewIndexConnection() { apiKey := "test-api-key" namespace := "" sourceTag := "" - idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) - assert.NoError(ts.T(), err) - - if idxConn.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) - } - if idxConn.Namespace != "" { - ts.FailNow(fmt.Sprintf("Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace)) - } - if idxConn.additionalMetadata != nil { - ts.FailNow(fmt.Sprintf("Expected idxConn additionalMetadata to be nil, but got '%+v'", idxConn.additionalMetadata)) - } - if idxConn.dataClient == nil { - ts.FailNow("Expected idxConn to have non-nil dataClient") - } - if idxConn.grpcConn == nil { - ts.FailNow("Expected idxConn to have non-nil grpcConn") - } - if idxConn.additionalMetadata != nil { - ts.FailNow("Expected idxConn to have nil additionalMetadata") - } + additionalMetadata := map[string]string{"api-key": apiKey} + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace, + sourceTag: sourceTag}) + + require.NoError(ts.T(), err) + apiKeyHeader, ok := idxConn.additionalMetadata["api-key"] + require.True(ts.T(), ok, "Expected client to have an 'api-key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'api-key' header to equal %s", apiKey) + require.Empty(ts.T(), idxConn.Namespace, "Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace) + require.NotNil(ts.T(), idxConn.dataClient, "Expected idxConn to have non-nil dataClient") + require.NotNil(ts.T(), idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") } func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { apiKey := "test-api-key" namespace := "test-namespace" sourceTag := "test-source-tag" - idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) - assert.NoError(ts.T(), err) - - if idxConn.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) - } - if idxConn.Namespace != namespace { - ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) - } - if idxConn.dataClient == nil { - ts.FailNow("Expected idxConn to have non-nil dataClient") - } - if idxConn.grpcConn == nil { - ts.FailNow("Expected idxConn to have non-nil grpcConn") - } -} - -func (ts *IndexConnectionTests) TestNewIndexConnectionAdditionalMetadata() { - apiKey := "test-api-key" - namespace := "test-namespace" - sourceTag := "test-source-tag" - additionalMetadata := map[string]string{"test-header": "test-value"} - idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag, additionalMetadata: additionalMetadata}) - assert.NoError(ts.T(), err) - - if idxConn.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) - } - if idxConn.Namespace != namespace { - ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) - } - if !reflect.DeepEqual(idxConn.additionalMetadata, additionalMetadata) { - ts.FailNow(fmt.Sprintf("Expected idxConn to have additionalMetadata '%+v', but got '%+v'", additionalMetadata, idxConn.additionalMetadata)) - } - if idxConn.dataClient == nil { - ts.FailNow("Expected idxConn to have non-nil dataClient") - } - if idxConn.grpcConn == nil { - ts.FailNow("Expected idxConn to have non-nil grpcConn") - } + additionalMetadata := map[string]string{"api-key": apiKey} + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace, + sourceTag: sourceTag}) + + require.NoError(ts.T(), err) + apiKeyHeader, ok := idxConn.additionalMetadata["api-key"] + require.True(ts.T(), ok, "Expected client to have an 'api-key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'api-key' header to equal %s", apiKey) + require.Equal(ts.T(), namespace, idxConn.Namespace, "Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace) + require.NotNil(ts.T(), idxConn.dataClient, "Expected idxConn to have non-nil dataClient") + require.NotNil(ts.T(), idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") } func (ts *IndexConnectionTests) TestFetchVectors() { From fc13ea13acd3717cd574df0df3da51774606b993 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Wed, 15 May 2024 09:36:44 -0400 Subject: [PATCH 07/11] add comments back to NewClientParams struct --- pinecone/client.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 2d98642..f8dd502 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -22,15 +22,15 @@ type Client struct { sourceTag string } -type NewClientBaseParams struct { - Headers map[string]string - Host string - RestClient *http.Client - SourceTag string +type NewClientParams struct { + ApiKey string // required + Headers map[string]string // optional + Host string // optional + RestClient *http.Client // optional + SourceTag string // optional } -type NewClientParams struct { - ApiKey string +type NewClientBaseParams struct { Headers map[string]string Host string RestClient *http.Client From db00a6d3c2f046245213667cf85b051d84995c03 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Wed, 15 May 2024 09:45:22 -0400 Subject: [PATCH 08/11] update INTEGREATION_PINECONE_API_KEY to TEST_ to match with the associated env values --- .env.example | 1 + .github/workflows/ci.yaml | 2 +- pinecone/index_connection_test.go | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.env.example b/.env.example index de982a6..47297ec 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,4 @@ PINECONE_API_KEY="" +TEST_PINECONE_API_KEY="" TEST_POD_INDEX_NAME="" TEST_SERVERLESS_INDEX_NAME="" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9fa3bf1..68db2d7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,4 +21,4 @@ jobs: env: TEST_POD_INDEX_NAME: ${{ secrets.TEST_POD_INDEX_NAME }} TEST_SERVERLESS_INDEX_NAME: ${{ secrets.TEST_SERVERLESS_INDEX_NAME }} - INTEGRATION_PINECONE_API_KEY: ${{ secrets.API_KEY }} + TEST_PINECONE_API_KEY: ${{ secrets.API_KEY }} diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 493bd54..9136926 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -28,7 +28,7 @@ type IndexConnectionTests struct { // Runs the test suite with `go test` func TestIndexConnection(t *testing.T) { - apiKey := os.Getenv("INTEGRATION_PINECONE_API_KEY") + apiKey := os.Getenv("TEST_PINECONE_API_KEY") assert.NotEmptyf(t, apiKey, "API_KEY env variable not set") client, err := NewClient(NewClientParams{ApiKey: apiKey}) From 99e6972d53cd3e9d43d64c799fe85b11e96defaf Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Wed, 15 May 2024 09:47:29 -0400 Subject: [PATCH 09/11] update client_test --- pinecone/client_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 81bf2a0..c08aec8 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -28,8 +28,8 @@ func TestClient(t *testing.T) { } func (ts *ClientTests) SetupSuite() { - apiKey := os.Getenv("INTEGRATION_PINECONE_API_KEY") - require.NotEmpty(ts.T(), apiKey, "INTEGRATION_PINECONE_API_KEY env variable not set") + apiKey := os.Getenv("TEST_PINECONE_API_KEY") + require.NotEmpty(ts.T(), apiKey, "TEST_PINECONE_API_KEY env variable not set") ts.podIndex = os.Getenv("TEST_POD_INDEX_NAME") require.NotEmpty(ts.T(), ts.podIndex, "TEST_POD_INDEX_NAME env variable not set") From 7dea9947206e2914156385ce3c95dbd2e4ebfd60 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Thu, 16 May 2024 02:57:41 -0400 Subject: [PATCH 10/11] more review feedback: rename URL scheme checker, remove extra API key env value, remove unnecessary test --- .env.example | 1 - .github/workflows/ci.yaml | 2 +- pinecone/client.go | 10 +++++----- pinecone/client_test.go | 31 ++++++++----------------------- pinecone/index_connection_test.go | 4 ++-- 5 files changed, 16 insertions(+), 32 deletions(-) diff --git a/.env.example b/.env.example index 47297ec..de982a6 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,3 @@ PINECONE_API_KEY="" -TEST_PINECONE_API_KEY="" TEST_POD_INDEX_NAME="" TEST_SERVERLESS_INDEX_NAME="" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 68db2d7..60c697e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,4 +21,4 @@ jobs: env: TEST_POD_INDEX_NAME: ${{ secrets.TEST_POD_INDEX_NAME }} TEST_SERVERLESS_INDEX_NAME: ${{ secrets.TEST_SERVERLESS_INDEX_NAME }} - TEST_PINECONE_API_KEY: ${{ secrets.API_KEY }} + PINECONE_API_KEY: ${{ secrets.API_KEY }} diff --git a/pinecone/client.go b/pinecone/client.go index f8dd502..8b379b2 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -23,7 +23,7 @@ type Client struct { } type NewClientParams struct { - ApiKey string // required + ApiKey string // required - provide through NewClientParams or environment variable PINECONE_API_KEY Headers map[string]string // optional Host string // optional RestClient *http.Client // optional @@ -39,10 +39,10 @@ type NewClientBaseParams struct { func NewClient(in NewClientParams) (*Client, error) { osApiKey := os.Getenv("PINECONE_API_KEY") - hasApiKey := (in.ApiKey != "" || osApiKey != "") + hasApiKey := (valueOrFallback(in.ApiKey, osApiKey) != "") if !hasApiKey { - return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") + return nil, fmt.Errorf("no API key provided, please pass an API key for authorization through NewClientParams or set the PINECONE_API_KEY environment variable") } apiKeyHeader := struct{ Key, Value string }{"Api-Key", valueOrFallback(in.ApiKey, osApiKey)} @@ -65,7 +65,7 @@ func NewClientBase(in NewClientBaseParams) (*Client, error) { controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) if controlHostOverride != "" { - controlHostOverride, err = ensureHTTPS(controlHostOverride) + controlHostOverride, err = ensureURLScheme(controlHostOverride) if err != nil { return nil, err } @@ -512,7 +512,7 @@ func decodeCollection(resBody io.ReadCloser) (*Collection, error) { return toCollection(&collectionModel), nil } -func ensureHTTPS(inputURL string) (string, error) { +func ensureURLScheme(inputURL string) (string, error) { parsedURL, err := url.Parse(inputURL) if err != nil { return "", fmt.Errorf("invalid URL: %v", err) diff --git a/pinecone/client_test.go b/pinecone/client_test.go index c08aec8..c5533d1 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -28,8 +28,8 @@ func TestClient(t *testing.T) { } func (ts *ClientTests) SetupSuite() { - apiKey := os.Getenv("TEST_PINECONE_API_KEY") - require.NotEmpty(ts.T(), apiKey, "TEST_PINECONE_API_KEY env variable not set") + apiKey := os.Getenv("PINECONE_API_KEY") + require.NotEmpty(ts.T(), apiKey, "PINECONE_API_KEY env variable not set") ts.podIndex = os.Getenv("TEST_POD_INDEX_NAME") require.NotEmpty(ts.T(), ts.podIndex, "TEST_POD_INDEX_NAME env variable not set") @@ -37,7 +37,7 @@ func (ts *ClientTests) SetupSuite() { ts.serverlessIndex = os.Getenv("TEST_SERVERLESS_INDEX_NAME") require.NotEmpty(ts.T(), ts.serverlessIndex, "TEST_SERVERLESS_INDEX_NAME env variable not set") - client, err := NewClient(NewClientParams{ApiKey: apiKey}) + client, err := NewClient(NewClientParams{}) require.NoError(ts.T(), err) ts.client = *client @@ -97,6 +97,9 @@ func (ts *ClientTests) TestNewClientParamsSetHeaders() { } func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { + apiKey := os.Getenv("PINECONE_API_KEY") + os.Unsetenv("PINECONE_API_KEY") + client, err := NewClient(NewClientParams{}) require.NotNil(ts.T(), err, "Expected error when creating client without an API key or Authorization header") if !strings.Contains(err.Error(), "no API key provided, please pass an API key for authorization") { @@ -104,6 +107,8 @@ func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { } require.Nil(ts.T(), client, "Expected client to be nil when creating client without an API key or Authorization header") + + os.Setenv("PINECONE_API_KEY", apiKey) } func (ts *ClientTests) TestHeadersAppliedToRequests() { @@ -170,26 +175,6 @@ func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() { os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } -func (ts *ClientTests) TestClientReadsApiKeyFromEnv() { - os.Setenv("PINECONE_API_KEY", "test-env-api-key") - - httpClient := mocks.CreateMockClient(`{"indexes": []}`) - client, err := NewClient(NewClientParams{RestClient: httpClient}) - if err != nil { - ts.FailNow(err.Error()) - } - mockTransport := httpClient.Transport.(*mocks.MockTransport) - - _, err = client.ListIndexes(context.Background()) - require.NoError(ts.T(), err) - require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") - - testHeaderValue := mockTransport.Req.Header.Get("Api-Key") - assert.Equal(ts.T(), "test-env-api-key", testHeaderValue, "Expected request to have header value 'test-env-api-key', but got '%s'", testHeaderValue) - - os.Unsetenv("PINECONE_API_KEY") -} - func (ts *ClientTests) TestControllerHostOverride() { apiKey := "test-api-key" httpClient := mocks.CreateMockClient(`{"indexes": []}`) diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 9136926..0920564 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -28,8 +28,8 @@ type IndexConnectionTests struct { // Runs the test suite with `go test` func TestIndexConnection(t *testing.T) { - apiKey := os.Getenv("TEST_PINECONE_API_KEY") - assert.NotEmptyf(t, apiKey, "API_KEY env variable not set") + apiKey := os.Getenv("PINECONE_API_KEY") + assert.NotEmptyf(t, apiKey, "PINECONE_API_KEY env variable not set") client, err := NewClient(NewClientParams{ApiKey: apiKey}) if err != nil { From 2e394d842542c36be54f391349c3e1f28033b164 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Thu, 16 May 2024 03:01:35 -0400 Subject: [PATCH 11/11] move buildClientBaseOptions back down --- pinecone/client.go | 78 +++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 8b379b2..2461344 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -80,45 +80,6 @@ func NewClientBase(in NewClientBaseParams) (*Client, error) { return &c, nil } -func buildClientBaseOptions(in NewClientBaseParams) []control.ClientOption { - clientOptions := []control.ClientOption{} - - // build and apply user agent - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) - clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) - - envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") - additionalHeaders := make(map[string]string) - - // add headers from environment - if hasEnvAdditionalHeaders { - err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders) - if err != nil { - log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err) - } - } - - // merge headers from parameters if passed - if in.Headers != nil { - for key, value := range in.Headers { - additionalHeaders[key] = value - } - } - - // add headers to client options - for key, value := range additionalHeaders { - headerProvider := provider.NewHeaderProvider(key, value) - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) - } - - // apply custom http client if provided - if in.RestClient != nil { - clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) - } - - return clientOptions -} - func (c *Client) Index(host string) (*IndexConnection, error) { return c.IndexWithAdditionalMetadata(host, "", nil) } @@ -512,6 +473,45 @@ func decodeCollection(resBody io.ReadCloser) (*Collection, error) { return toCollection(&collectionModel), nil } +func buildClientBaseOptions(in NewClientBaseParams) []control.ClientOption { + clientOptions := []control.ClientOption{} + + // build and apply user agent + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) + + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") + additionalHeaders := make(map[string]string) + + // add headers from environment + if hasEnvAdditionalHeaders { + err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders) + if err != nil { + log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err) + } + } + + // merge headers from parameters if passed + if in.Headers != nil { + for key, value := range in.Headers { + additionalHeaders[key] = value + } + } + + // add headers to client options + for key, value := range additionalHeaders { + headerProvider := provider.NewHeaderProvider(key, value) + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + } + + // apply custom http client if provided + if in.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) + } + + return clientOptions +} + func ensureURLScheme(inputURL string) (string, error) { parsedURL, err := url.Parse(inputURL) if err != nil {