From 6eb91a3660e431daccf39ec18eec6f05971cb5f6 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Wed, 8 May 2024 23:05:34 -0400 Subject: [PATCH] add ability to specify host override for control plane operations through NewClientParams or environment variables, add tests --- pinecone/client.go | 49 +++++++++++++++++----- pinecone/client_test.go | 79 ++++++++++++++++++++++++++++++++++++ pinecone/index_connection.go | 1 - 3 files changed, 117 insertions(+), 12 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 1667bdf..014a092 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -7,6 +7,7 @@ import ( "io" "log" "net/http" + "net/url" "os" "strings" @@ -18,16 +19,17 @@ import ( type Client struct { apiKey string + headers map[string]string restClient *control.Client sourceTag string - headers map[string]string } type NewClientParams struct { ApiKey string // required unless Authorization header provided - SourceTag string // optional Headers map[string]string // optional + Host string // optional RestClient *http.Client // optional + SourceTag string // optional } func NewClient(in NewClientParams) (*Client, error) { @@ -36,7 +38,15 @@ func NewClient(in NewClientParams) (*Client, error) { return nil, err } - client, err := control.NewClient("https://api.pinecone.io", clientOptions...) + controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) + if controlHostOverride != "" { + controlHostOverride, err = ensureHTTP(controlHostOverride) + if err != nil { + return nil, err + } + } + + client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...) if err != nil { return nil, err } @@ -425,11 +435,12 @@ func derefOrDefault[T any](ptr *T, defaultValue T) T { func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) { clientOptions := []control.ClientOption{} - osApiKey := os.Getenv("PINECONE_API_KEY") - envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") hasAuthorizationHeader := false + osApiKey := os.Getenv("PINECONE_API_KEY") hasApiKey := (ncp.ApiKey != "" || osApiKey != "") + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") + // build and apply user agent userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncp.SourceTag)) clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) @@ -465,12 +476,7 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) // if apiKey is provided and no auth header is set, add the apiKey as a header // apiKey from parameters takes precedence over apiKey from environment if hasApiKey && !hasAuthorizationHeader { - var appliedApiKey string - if ncp.ApiKey != "" { - appliedApiKey = ncp.ApiKey - } else { - appliedApiKey = osApiKey - } + appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey) apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey) if err != nil { @@ -489,3 +495,24 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) return clientOptions, nil } + +func ensureHTTP(inputURL string) (string, error) { + parsedURL, err := url.Parse(inputURL) + if err != nil { + return "", fmt.Errorf("invalid URL: %v", err) + } + + if parsedURL.Scheme == "" { + return "https://" + inputURL, nil + } + return inputURL, nil +} + +func valueOrFallback[T comparable](value, fallback T) T { + var zero T + if value != zero { + return value + } else { + return fallback + } +} diff --git a/pinecone/client_test.go b/pinecone/client_test.go index a2acd66..c4ce200 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -233,6 +233,85 @@ func (ts *ClientTests) TestClientReadsApiKeyFromEnv() { os.Unsetenv("PINECONE_API_KEY") } +func (ts *ClientTests) TestControllerHostOverride() { + apiKey := "test-api-key" + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Host: "https://test-controller-host.io", RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + assert.Equal(ts.T(), "test-controller-host.io", mockTransport.Req.Host, "Expected request to be made to 'test-controller-host.io', but got '%s'", mockTransport.Req.URL.Host) +} + +func (ts *ClientTests) TestControllerHostOverrideFromEnv() { + os.Setenv("PINECONE_CONTROLLER_HOST", "https://env-controller-host.io") + + apiKey := "test-api-key" + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + assert.Equal(ts.T(), "env-controller-host.io", mockTransport.Req.Host, "Expected request to be made to 'env-controller-host.io', but got '%s'", mockTransport.Req.URL.Host) + + os.Unsetenv("PINECONE_CONTROLLER_HOST") +} + +func (ts *ClientTests) TestControllerHostNormalization() { + tests := []struct { + name string + host string + wantHost string + wantScheme string + }{ + { + name: "Test with https prefix", + host: "https://pinecone-api.io", + wantHost: "pinecone-api.io", + wantScheme: "https", + }, { + name: "Test with http prefix", + host: "http://pinecone-api.io", + wantHost: "pinecone-api.io", + wantScheme: "http", + }, { + name: "Test without prefix", + host: "pinecone-api.io", + wantHost: "pinecone-api.io", + wantScheme: "https", + }, + } + + for _, tt := range tests { + ts.Run(tt.name, func() { + apiKey := "test-api-key" + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Host: tt.host, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + assert.Equal(ts.T(), tt.wantHost, mockTransport.Req.URL.Host, "Expected request to be made to host '%s', but got '%s'", tt.wantHost, mockTransport.Req.URL.Host) + assert.Equal(ts.T(), tt.wantScheme, mockTransport.Req.URL.Scheme, "Expected request to be made to host '%s, but got '%s'", tt.wantScheme, mockTransport.Req.URL.Host) + }) + } +} + func (ts *ClientTests) TestListIndexes() { indexes, err := ts.client.ListIndexes(context.Background()) require.NoError(ts.T(), err) diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index 866d557..fb72391 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -94,7 +94,6 @@ func (idx *IndexConnection) FetchVectors(ctx context.Context, ids []string) (*Fe for id, vector := range res.Vectors { vectors[id] = toVector(vector) } - fmt.Printf("VECTORS: %+v\n", vectors) return &FetchVectorsResponse{ Vectors: vectors,