diff --git a/.env.example b/.env.example index 846feee..de982a6 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,3 @@ -API_KEY="" +PINECONE_API_KEY="" TEST_POD_INDEX_NAME="" TEST_SERVERLESS_INDEX_NAME="" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6470966..60c697e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,7 +12,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v4 with: - go-version: '1.21.x' + go-version: "1.21.x" - name: Install dependencies run: | go get ./pinecone @@ -21,4 +21,4 @@ jobs: env: TEST_POD_INDEX_NAME: ${{ secrets.TEST_POD_INDEX_NAME }} TEST_SERVERLESS_INDEX_NAME: ${{ secrets.TEST_SERVERLESS_INDEX_NAME }} - API_KEY: ${{ secrets.API_KEY }} \ No newline at end of file + PINECONE_API_KEY: ${{ secrets.API_KEY }} diff --git a/pinecone/client.go b/pinecone/client.go index b34e017..2461344 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -5,41 +5,78 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" + "net/url" + "os" "strings" - "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" "github.com/pinecone-io/go-pinecone/internal/provider" "github.com/pinecone-io/go-pinecone/internal/useragent" ) type Client struct { - apiKey string + headers map[string]string restClient *control.Client sourceTag string - headers map[string]string } type NewClientParams struct { - ApiKey string // required unless Authorization header provided - SourceTag string // optional + ApiKey string // required - provide through NewClientParams or environment variable PINECONE_API_KEY Headers map[string]string // optional + Host string // optional RestClient *http.Client // optional + SourceTag string // optional +} + +type NewClientBaseParams struct { + Headers map[string]string + Host string + RestClient *http.Client + SourceTag string } func NewClient(in NewClientParams) (*Client, error) { - clientOptions, err := buildClientOptions(in) - if err != nil { - return nil, err + osApiKey := os.Getenv("PINECONE_API_KEY") + hasApiKey := (valueOrFallback(in.ApiKey, osApiKey) != "") + + if !hasApiKey { + return nil, fmt.Errorf("no API key provided, please pass an API key for authorization through NewClientParams or set the PINECONE_API_KEY environment variable") + } + + apiKeyHeader := struct{ Key, Value string }{"Api-Key", valueOrFallback(in.ApiKey, osApiKey)} + + clientHeaders := in.Headers + if clientHeaders == nil { + clientHeaders = make(map[string]string) + clientHeaders[apiKeyHeader.Key] = apiKeyHeader.Value + + } else { + clientHeaders[apiKeyHeader.Key] = apiKeyHeader.Value } - client, err := control.NewClient("https://api.pinecone.io", clientOptions...) + return NewClientBase(NewClientBaseParams{Headers: clientHeaders, Host: in.Host, RestClient: in.RestClient, SourceTag: in.SourceTag}) +} + +func NewClientBase(in NewClientBaseParams) (*Client, error) { + clientOptions := buildClientBaseOptions(in) + var err error + + controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) + if controlHostOverride != "" { + controlHostOverride, err = ensureURLScheme(controlHostOverride) + if err != nil { + return nil, err + } + } + + client, err := control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...) if err != nil { return nil, err } - c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag, headers: in.Headers} + c := Client{restClient: client, sourceTag: in.SourceTag, headers: in.Headers} return &c, nil } @@ -52,13 +89,42 @@ func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnec } func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) { - idx, err := newIndexConnection(newIndexParameters{apiKey: c.apiKey, host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata}) + authHeader := c.extractAuthHeader() + + // merge additionalMetadata with authHeader + if additionalMetadata != nil { + for _, key := range authHeader { + additionalMetadata[key] = authHeader[key] + } + } else { + additionalMetadata = authHeader + } + + idx, err := newIndexConnection(newIndexParameters{host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata}) if err != nil { return nil, err } return idx, nil } +func (c *Client) extractAuthHeader() map[string]string { + possibleAuthKeys := []string{ + "api-key", + "authorization", + "access_token", + } + + for key, value := range c.headers { + for _, checkKey := range possibleAuthKeys { + if strings.ToLower(key) == checkKey { + return map[string]string{key: value} + } + } + } + + return nil +} + func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { res, err := c.restClient.ListIndexes(ctx) if err != nil { @@ -407,53 +473,76 @@ func decodeCollection(resBody io.ReadCloser) (*Collection, error) { return toCollection(&collectionModel), nil } -func minOne(x int32) int32 { - if x < 1 { - return 1 - } - return x -} - -func derefOrDefault[T any](ptr *T, defaultValue T) T { - if ptr == nil { - return defaultValue - } - return *ptr -} - -func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { +func buildClientBaseOptions(in NewClientBaseParams) []control.ClientOption { clientOptions := []control.ClientOption{} - hasAuthorizationHeader := false - hasApiKey := in.ApiKey != "" + // build and apply user agent userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) - for key, value := range in.Headers { - headerProvider := provider.NewHeaderProvider(key, value) + envAdditionalHeaders, hasEnvAdditionalHeaders := os.LookupEnv("PINECONE_ADDITIONAL_HEADERS") + additionalHeaders := make(map[string]string) - if strings.Contains(strings.ToLower(key), "authorization") { - hasAuthorizationHeader = true + // add headers from environment + if hasEnvAdditionalHeaders { + err := json.Unmarshal([]byte(envAdditionalHeaders), &additionalHeaders) + if err != nil { + log.Printf("failed to parse PINECONE_ADDITIONAL_HEADERS: %v", err) } - - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) } - if !hasAuthorizationHeader { - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) - if err != nil { - return nil, err + // merge headers from parameters if passed + if in.Headers != nil { + for key, value := range in.Headers { + additionalHeaders[key] = value } - clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) } - if !hasAuthorizationHeader && !hasApiKey { - return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") + // add headers to client options + for key, value := range additionalHeaders { + headerProvider := provider.NewHeaderProvider(key, value) + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) } + // apply custom http client if provided if in.RestClient != nil { clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) } - return clientOptions, nil + return clientOptions +} + +func ensureURLScheme(inputURL string) (string, error) { + parsedURL, err := url.Parse(inputURL) + if err != nil { + return "", fmt.Errorf("invalid URL: %v", err) + } + + if parsedURL.Scheme == "" { + return "https://" + inputURL, nil + } + return inputURL, nil +} + +func valueOrFallback[T comparable](value, fallback T) T { + var zero T + if value != zero { + return value + } else { + return fallback + } +} + +func derefOrDefault[T any](ptr *T, defaultValue T) T { + if ptr == nil { + return defaultValue + } + return *ptr +} + +func minOne(x int32) int32 { + if x < 1 { + return 1 + } + return x } diff --git a/pinecone/client_test.go b/pinecone/client_test.go index ecb8bbb..c5533d1 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -4,12 +4,12 @@ import ( "context" "fmt" "os" - "reflect" "strings" "testing" "github.com/google/uuid" "github.com/pinecone-io/go-pinecone/internal/mocks" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -28,8 +28,8 @@ func TestClient(t *testing.T) { } func (ts *ClientTests) SetupSuite() { - apiKey := os.Getenv("API_KEY") - require.NotEmpty(ts.T(), apiKey, "API_KEY env variable not set") + apiKey := os.Getenv("PINECONE_API_KEY") + require.NotEmpty(ts.T(), apiKey, "PINECONE_API_KEY env variable not set") ts.podIndex = os.Getenv("TEST_POD_INDEX_NAME") require.NotEmpty(ts.T(), ts.podIndex, "TEST_POD_INDEX_NAME env variable not set") @@ -37,17 +37,14 @@ func (ts *ClientTests) SetupSuite() { ts.serverlessIndex = os.Getenv("TEST_SERVERLESS_INDEX_NAME") require.NotEmpty(ts.T(), ts.serverlessIndex, "TEST_SERVERLESS_INDEX_NAME env variable not set") - client, err := NewClient(NewClientParams{ApiKey: apiKey}) - if err != nil { - ts.FailNow(err.Error()) - } + client, err := NewClient(NewClientParams{}) + require.NoError(ts.T(), err) + ts.client = *client ts.sourceTag = "test_source_tag" clientSourceTag, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: ts.sourceTag}) - if err != nil { - ts.FailNow(err.Error()) - } + require.NoError(ts.T(), err) ts.clientSourceTag = *clientSourceTag // this will clean up the project deleting all indexes and collections that are @@ -60,60 +57,49 @@ func (ts *ClientTests) SetupSuite() { func (ts *ClientTests) TestNewClientParamsSet() { apiKey := "test-api-key" client, err := NewClient(NewClientParams{ApiKey: apiKey}) - if err != nil { - ts.FailNow(err.Error()) - } - if client.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) - } - if client.sourceTag != "" { - ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag)) - } - if client.headers != nil { - ts.FailNow(fmt.Sprintf("Expected client headers to be nil, but got '%v'", client.headers)) - } - if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) - } + + require.NoError(ts.T(), err) + require.Empty(ts.T(), client.sourceTag, "Expected client to have empty sourceTag") + require.NotNil(ts.T(), client.headers, "Expected client headers to not be nil") + apiKeyHeader, ok := client.headers["Api-Key"] + require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") + require.Equal(ts.T(), 2, len(client.restClient.RequestEditors), "Expected client to have correct number of require editors") } func (ts *ClientTests) TestNewClientParamsSetSourceTag() { apiKey := "test-api-key" sourceTag := "test-source-tag" - client, err := NewClient(NewClientParams{ApiKey: apiKey, SourceTag: sourceTag}) - if err != nil { - ts.FailNow(err.Error()) - } - if client.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) - } - if client.sourceTag != sourceTag { - ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)) - } - if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) - } + client, err := NewClient(NewClientParams{ + ApiKey: apiKey, + SourceTag: sourceTag, + }) + + require.NoError(ts.T(), err) + apiKeyHeader, ok := client.headers["Api-Key"] + require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") + require.Equal(ts.T(), sourceTag, client.sourceTag, "Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag) + require.Equal(ts.T(), 2, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 2, len(client.restClient.RequestEditors)) } func (ts *ClientTests) TestNewClientParamsSetHeaders() { apiKey := "test-api-key" headers := map[string]string{"test-header": "test-value"} client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers}) - if err != nil { - ts.FailNow(err.Error()) - } - if client.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) - } - if !reflect.DeepEqual(client.headers, headers) { - ts.FailNow(fmt.Sprintf("Expected client to have headers '%+v', but got '%+v'", headers, client.headers)) - } - if len(client.restClient.RequestEditors) != 3 { - ts.FailNow(fmt.Sprintf("Expected client to have '%v' request editors, but got '%v'", 3, len(client.restClient.RequestEditors))) - } + + require.NoError(ts.T(), err) + apiKeyHeader, ok := client.headers["Api-Key"] + require.True(ts.T(), ok, "Expected client to have an 'Api-Key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") + require.Equal(ts.T(), client.headers, headers, "Expected client to have headers '%+v', but got '%+v'", headers, client.headers) + require.Equal(ts.T(), 3, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 3, len(client.restClient.RequestEditors)) } func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { + apiKey := os.Getenv("PINECONE_API_KEY") + os.Unsetenv("PINECONE_API_KEY") + client, err := NewClient(NewClientParams{}) require.NotNil(ts.T(), err, "Expected error when creating client without an API key or Authorization header") if !strings.Contains(err.Error(), "no API key provided, please pass an API key for authorization") { @@ -121,6 +107,8 @@ func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { } require.Nil(ts.T(), client, "Expected client to be nil when creating client without an API key or Authorization header") + + os.Setenv("PINECONE_API_KEY", apiKey) } func (ts *ClientTests) TestHeadersAppliedToRequests() { @@ -139,14 +127,36 @@ func (ts *ClientTests) TestHeadersAppliedToRequests() { require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") testHeaderValue := mockTransport.Req.Header.Get("test-header") - if testHeaderValue != "123456" { - ts.FailNow(fmt.Sprintf("Expected request to have header value '123456', but got '%s'", testHeaderValue)) + assert.Equal(ts.T(), "123456", testHeaderValue, "Expected request to have header value '123456', but got '%s'", testHeaderValue) +} + +func (ts *ClientTests) TestAdditionalHeadersAppliedToRequest() { + os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`) + + apiKey := "test-api-key" + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + testHeaderValue := mockTransport.Req.Header.Get("test-header") + assert.Equal(ts.T(), "environment-header", testHeaderValue, "Expected request to have header value 'environment-header', but got '%s'", testHeaderValue) + + os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") } -func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { +func (ts *ClientTests) TestHeadersOverrideAdditionalHeaders() { + os.Setenv("PINECONE_ADDITIONAL_HEADERS", `{"test-header": "environment-header"}`) + apiKey := "test-api-key" - headers := map[string]string{"Authorization": "bearer fooo"} + headers := map[string]string{"test-header": "param-header"} httpClient := mocks.CreateMockClient(`{"indexes": []}`) client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) @@ -159,13 +169,88 @@ func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { require.NoError(ts.T(), err) require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") - apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") - authHeaderValue := mockTransport.Req.Header.Get("Authorization") - if authHeaderValue != "bearer fooo" { - ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) + testHeaderValue := mockTransport.Req.Header.Get("test-header") + assert.Equal(ts.T(), "param-header", testHeaderValue, "Expected request to have header value 'param-header', but got '%s'", testHeaderValue) + + os.Unsetenv("PINECONE_ADDITIONAL_HEADERS") +} + +func (ts *ClientTests) TestControllerHostOverride() { + apiKey := "test-api-key" + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Host: "https://test-controller-host.io", RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + assert.Equal(ts.T(), "test-controller-host.io", mockTransport.Req.Host, "Expected request to be made to 'test-controller-host.io', but got '%s'", mockTransport.Req.URL.Host) +} + +func (ts *ClientTests) TestControllerHostOverrideFromEnv() { + os.Setenv("PINECONE_CONTROLLER_HOST", "https://env-controller-host.io") + + apiKey := "test-api-key" + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) } - if apiKeyHeaderValue != "" { - ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + assert.Equal(ts.T(), "env-controller-host.io", mockTransport.Req.Host, "Expected request to be made to 'env-controller-host.io', but got '%s'", mockTransport.Req.URL.Host) + + os.Unsetenv("PINECONE_CONTROLLER_HOST") +} + +func (ts *ClientTests) TestControllerHostNormalization() { + tests := []struct { + name string + host string + wantHost string + wantScheme string + }{ + { + name: "Test with https prefix", + host: "https://pinecone-api.io", + wantHost: "pinecone-api.io", + wantScheme: "https", + }, { + name: "Test with http prefix", + host: "http://pinecone-api.io", + wantHost: "pinecone-api.io", + wantScheme: "http", + }, { + name: "Test without prefix", + host: "pinecone-api.io", + wantHost: "pinecone-api.io", + wantScheme: "https", + }, + } + + for _, tt := range tests { + ts.Run(tt.name, func() { + apiKey := "test-api-key" + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Host: tt.host, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + assert.Equal(ts.T(), tt.wantHost, mockTransport.Req.URL.Host, "Expected request to be made to host '%s', but got '%s'", tt.wantHost, mockTransport.Req.URL.Host) + assert.Equal(ts.T(), tt.wantScheme, mockTransport.Req.URL.Scheme, "Expected request to be made to host '%s, but got '%s'", tt.wantScheme, mockTransport.Req.URL.Host) + }) } } diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index 866d557..b8cb27f 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -15,14 +15,12 @@ import ( type IndexConnection struct { Namespace string - apiKey string additionalMetadata map[string]string dataClient *data.VectorServiceClient grpcConn *grpc.ClientConn } type newIndexParameters struct { - apiKey string host string namespace string sourceTag string @@ -47,7 +45,7 @@ func newIndexConnection(in newIndexParameters) (*IndexConnection, error) { dataClient := data.NewVectorServiceClient(conn) - idx := IndexConnection{Namespace: in.namespace, apiKey: in.apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata} + idx := IndexConnection{Namespace: in.namespace, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata} return &idx, nil } @@ -94,7 +92,6 @@ func (idx *IndexConnection) FetchVectors(ctx context.Context, ids []string) (*Fe for id, vector := range res.Vectors { vectors[id] = toVector(vector) } - fmt.Printf("VECTORS: %+v\n", vectors) return &FetchVectorsResponse{ Vectors: vectors, @@ -371,7 +368,6 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues { func (idx *IndexConnection) akCtx(ctx context.Context) context.Context { newMetadata := []string{} - newMetadata = append(newMetadata, "api-key", idx.apiKey) for key, value := range idx.additionalMetadata { newMetadata = append(newMetadata, key, value) diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 7589535..0920564 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" "os" - "reflect" "testing" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/protobuf/types/known/structpb" ) @@ -28,8 +28,8 @@ type IndexConnectionTests struct { // Runs the test suite with `go test` func TestIndexConnection(t *testing.T) { - apiKey := os.Getenv("API_KEY") - assert.NotEmptyf(t, apiKey, "API_KEY env variable not set") + apiKey := os.Getenv("PINECONE_API_KEY") + assert.NotEmptyf(t, apiKey, "PINECONE_API_KEY env variable not set") client, err := NewClient(NewClientParams{ApiKey: apiKey}) if err != nil { @@ -71,16 +71,25 @@ func TestIndexConnection(t *testing.T) { func (ts *IndexConnectionTests) SetupSuite() { assert.NotEmptyf(ts.T(), ts.host, "HOST env variable not set") assert.NotEmptyf(ts.T(), ts.apiKey, "API_KEY env variable not set") + additionalMetadata := map[string]string{"api-key": ts.apiKey} namespace, err := uuid.NewV7() assert.NoError(ts.T(), err) - idxConn, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ""}) + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace.String(), + sourceTag: ""}) assert.NoError(ts.T(), err) ts.idxConn = idxConn ts.sourceTag = "test_source_tag" - idxConnSourceTag, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ts.sourceTag}) + idxConnSourceTag, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace.String(), + sourceTag: ts.sourceTag}) assert.NoError(ts.T(), err) ts.idxConnSourceTag = idxConnSourceTag @@ -101,73 +110,40 @@ func (ts *IndexConnectionTests) TestNewIndexConnection() { apiKey := "test-api-key" namespace := "" sourceTag := "" - idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) - assert.NoError(ts.T(), err) - - if idxConn.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) - } - if idxConn.Namespace != "" { - ts.FailNow(fmt.Sprintf("Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace)) - } - if idxConn.additionalMetadata != nil { - ts.FailNow(fmt.Sprintf("Expected idxConn additionalMetadata to be nil, but got '%+v'", idxConn.additionalMetadata)) - } - if idxConn.dataClient == nil { - ts.FailNow("Expected idxConn to have non-nil dataClient") - } - if idxConn.grpcConn == nil { - ts.FailNow("Expected idxConn to have non-nil grpcConn") - } - if idxConn.additionalMetadata != nil { - ts.FailNow("Expected idxConn to have nil additionalMetadata") - } + additionalMetadata := map[string]string{"api-key": apiKey} + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace, + sourceTag: sourceTag}) + + require.NoError(ts.T(), err) + apiKeyHeader, ok := idxConn.additionalMetadata["api-key"] + require.True(ts.T(), ok, "Expected client to have an 'api-key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'api-key' header to equal %s", apiKey) + require.Empty(ts.T(), idxConn.Namespace, "Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace) + require.NotNil(ts.T(), idxConn.dataClient, "Expected idxConn to have non-nil dataClient") + require.NotNil(ts.T(), idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") } func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { apiKey := "test-api-key" namespace := "test-namespace" sourceTag := "test-source-tag" - idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) - assert.NoError(ts.T(), err) - - if idxConn.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) - } - if idxConn.Namespace != namespace { - ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) - } - if idxConn.dataClient == nil { - ts.FailNow("Expected idxConn to have non-nil dataClient") - } - if idxConn.grpcConn == nil { - ts.FailNow("Expected idxConn to have non-nil grpcConn") - } -} - -func (ts *IndexConnectionTests) TestNewIndexConnectionAdditionalMetadata() { - apiKey := "test-api-key" - namespace := "test-namespace" - sourceTag := "test-source-tag" - additionalMetadata := map[string]string{"test-header": "test-value"} - idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag, additionalMetadata: additionalMetadata}) - assert.NoError(ts.T(), err) - - if idxConn.apiKey != apiKey { - ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) - } - if idxConn.Namespace != namespace { - ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) - } - if !reflect.DeepEqual(idxConn.additionalMetadata, additionalMetadata) { - ts.FailNow(fmt.Sprintf("Expected idxConn to have additionalMetadata '%+v', but got '%+v'", additionalMetadata, idxConn.additionalMetadata)) - } - if idxConn.dataClient == nil { - ts.FailNow("Expected idxConn to have non-nil dataClient") - } - if idxConn.grpcConn == nil { - ts.FailNow("Expected idxConn to have non-nil grpcConn") - } + additionalMetadata := map[string]string{"api-key": apiKey} + idxConn, err := newIndexConnection(newIndexParameters{ + additionalMetadata: additionalMetadata, + host: ts.host, + namespace: namespace, + sourceTag: sourceTag}) + + require.NoError(ts.T(), err) + apiKeyHeader, ok := idxConn.additionalMetadata["api-key"] + require.True(ts.T(), ok, "Expected client to have an 'api-key' header") + require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'api-key' header to equal %s", apiKey) + require.Equal(ts.T(), namespace, idxConn.Namespace, "Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace) + require.NotNil(ts.T(), idxConn.dataClient, "Expected idxConn to have non-nil dataClient") + require.NotNil(ts.T(), idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") } func (ts *IndexConnectionTests) TestFetchVectors() {