Skip to content

Commit

Permalink
add ability to specify host override for control plane operations thr…
Browse files Browse the repository at this point in the history
…ough NewClientParams or environment variables, add tests
  • Loading branch information
austin-denoble committed May 9, 2024
1 parent 995829d commit 6eb91a3
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 12 deletions.
49 changes: 38 additions & 11 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"strings"

Expand All @@ -18,16 +19,17 @@ import (

type Client struct {
apiKey string
headers map[string]string
restClient *control.Client
sourceTag string
headers map[string]string
}

type NewClientParams struct {
ApiKey string // required unless Authorization header provided
SourceTag string // optional
Headers map[string]string // optional
Host string // optional
RestClient *http.Client // optional
SourceTag string // optional
}

func NewClient(in NewClientParams) (*Client, error) {
Expand All @@ -36,7 +38,15 @@ func NewClient(in NewClientParams) (*Client, error) {
return nil, err
}

client, err := control.NewClient("https://api.pinecone.io", clientOptions...)
controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST"))
if controlHostOverride != "" {
controlHostOverride, err = ensureHTTP(controlHostOverride)
if err != nil {
return nil, err
}
}

client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -425,11 +435,12 @@ func derefOrDefault[T any](ptr *T, defaultValue T) T {

func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error) {
clientOptions := []control.ClientOption{}
osApiKey := os.Getenv("PINECONE_API_KEY")
envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS")
hasAuthorizationHeader := false
osApiKey := os.Getenv("PINECONE_API_KEY")
hasApiKey := (ncp.ApiKey != "" || osApiKey != "")
envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS")

// build and apply user agent
userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(ncp.SourceTag))
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept))

Expand Down Expand Up @@ -465,12 +476,7 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error)
// if apiKey is provided and no auth header is set, add the apiKey as a header
// apiKey from parameters takes precedence over apiKey from environment
if hasApiKey && !hasAuthorizationHeader {
var appliedApiKey string
if ncp.ApiKey != "" {
appliedApiKey = ncp.ApiKey
} else {
appliedApiKey = osApiKey
}
appliedApiKey := valueOrFallback(ncp.ApiKey, osApiKey)

apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", appliedApiKey)
if err != nil {
Expand All @@ -489,3 +495,24 @@ func (ncp *NewClientParams) buildClientOptions() ([]control.ClientOption, error)

return clientOptions, nil
}

func ensureHTTP(inputURL string) (string, error) {
parsedURL, err := url.Parse(inputURL)
if err != nil {
return "", fmt.Errorf("invalid URL: %v", err)
}

if parsedURL.Scheme == "" {
return "https://" + inputURL, nil
}
return inputURL, nil
}

func valueOrFallback[T comparable](value, fallback T) T {
var zero T
if value != zero {
return value
} else {
return fallback
}
}
79 changes: 79 additions & 0 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,85 @@ func (ts *ClientTests) TestClientReadsApiKeyFromEnv() {
os.Unsetenv("PINECONE_API_KEY")
}

func (ts *ClientTests) TestControllerHostOverride() {
apiKey := "test-api-key"
httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, Host: "https://test-controller-host.io", RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")
assert.Equal(ts.T(), "test-controller-host.io", mockTransport.Req.Host, "Expected request to be made to 'test-controller-host.io', but got '%s'", mockTransport.Req.URL.Host)
}

func (ts *ClientTests) TestControllerHostOverrideFromEnv() {
os.Setenv("PINECONE_CONTROLLER_HOST", "https://env-controller-host.io")

apiKey := "test-api-key"
httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")
assert.Equal(ts.T(), "env-controller-host.io", mockTransport.Req.Host, "Expected request to be made to 'env-controller-host.io', but got '%s'", mockTransport.Req.URL.Host)

os.Unsetenv("PINECONE_CONTROLLER_HOST")
}

func (ts *ClientTests) TestControllerHostNormalization() {
tests := []struct {
name string
host string
wantHost string
wantScheme string
}{
{
name: "Test with https prefix",
host: "https://pinecone-api.io",
wantHost: "pinecone-api.io",
wantScheme: "https",
}, {
name: "Test with http prefix",
host: "http://pinecone-api.io",
wantHost: "pinecone-api.io",
wantScheme: "http",
}, {
name: "Test without prefix",
host: "pinecone-api.io",
wantHost: "pinecone-api.io",
wantScheme: "https",
},
}

for _, tt := range tests {
ts.Run(tt.name, func() {
apiKey := "test-api-key"
httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, Host: tt.host, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

assert.Equal(ts.T(), tt.wantHost, mockTransport.Req.URL.Host, "Expected request to be made to host '%s', but got '%s'", tt.wantHost, mockTransport.Req.URL.Host)
assert.Equal(ts.T(), tt.wantScheme, mockTransport.Req.URL.Scheme, "Expected request to be made to host '%s, but got '%s'", tt.wantScheme, mockTransport.Req.URL.Host)
})
}
}

func (ts *ClientTests) TestListIndexes() {
indexes, err := ts.client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
Expand Down
1 change: 0 additions & 1 deletion pinecone/index_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ func (idx *IndexConnection) FetchVectors(ctx context.Context, ids []string) (*Fe
for id, vector := range res.Vectors {
vectors[id] = toVector(vector)
}
fmt.Printf("VECTORS: %+v\n", vectors)

return &FetchVectorsResponse{
Vectors: vectors,
Expand Down

0 comments on commit 6eb91a3

Please sign in to comment.