Skip to content

Commit

Permalink
Optional configs for setting custom headers / metadata for REST / gRP…
Browse files Browse the repository at this point in the history
…C operations (#18)

## Problem
Currently, the Go SDK does not support specifying custom headers or
metadata for control & data plane operations (REST and gRPC). This is a
useful feature to aid in debugging or tracking a specific request, and
we'd like to enable this functionality in the Go SDK.

## Solution
- Update `NewClientParams` and `Client` structs to support `headers`,
also allow passing a custom `RestClient` to the new client. This is
primarily for allowing mocking in unit tests, but we have a similar
approach in other clients allowing users to customize the HTTP module.
- Add new `buildClientOptions` function.
- Update `NewClient` to support appending `Headers` and `RestClient` to
`clientOptions`.
- Add error message for empty `ApiKey` without any `Authorization`
header provided. If the user passes both, avoid applying the `Api-Key`
header in favor of the `Authorization` header.
- Add new `Client.IndexWithAdditionalMetadata` constructor.
- Update `IndexConnection` struct to include `additionalMetadata`, and
update `newIndexConnection` to accept `additionalMetadata` as an
argument.
- Update `IndexConnection.akCtx` to handle appending the `api-key` and
any `additionalMetadata` to requests.
- Update unit tests, add new `mocks` package and `mock_transport` to
facilitate validating requests made via `Client` including.

I spent some time trying to figure out how to mock / test
`IndexConnection` but I think it'll be a bit trickier than `Client`, so
somewhat saving for a future PR atm.

## Type of Change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [X] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [ ] None of the above: (explain here)

## Test Plan
`just test` to run the test suite locally. Validate tests pass in CI.
  • Loading branch information
austin-denoble authored May 6, 2024
1 parent 32391bc commit 5a9897e
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 27 deletions.
30 changes: 30 additions & 0 deletions internal/mocks/mock_transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package mocks

import (
"bytes"
"io"
"net/http"
)

type MockTransport struct {
Req *http.Request
Resp *http.Response
Err error
}

func (m *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
m.Req = req
return m.Resp, m.Err
}

func CreateMockClient(jsonBody string) *http.Client {
return &http.Client {
Transport: &MockTransport{
Resp: &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader([]byte(jsonBody))),
Header: make(http.Header),
},
},
}
}
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ test:
set -o allexport
source .env
set +o allexport
go test -count=1 ./pinecone
go test -count=1 -v ./pinecone
bootstrap:
go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
Expand Down
69 changes: 55 additions & 14 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,55 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/deepmap/oapi-codegen/v2/pkg/securityprovider"
"github.com/pinecone-io/go-pinecone/internal/gen/control"
"github.com/pinecone-io/go-pinecone/internal/provider"
"github.com/pinecone-io/go-pinecone/internal/useragent"
"io"
"net/http"
)

type Client struct {
apiKey string
restClient *control.Client
sourceTag string
headers map[string]string
}

type NewClientParams struct {
ApiKey string
SourceTag string // optional
ApiKey string // required unless Authorization header provided
SourceTag string // optional
Headers map[string]string // optional
RestClient *http.Client // optional
}

func NewClient(in NewClientParams) (*Client, error) {
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey)
clientOptions, err := buildClientOptions(in)
if err != nil {
return nil, err
}

userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag))

client, err := control.NewClient("https://api.pinecone.io",
control.WithRequestEditorFn(apiKeyProvider.Intercept),
control.WithRequestEditorFn(userAgentProvider.Intercept),
)
client, err := control.NewClient("https://api.pinecone.io", clientOptions...)
if err != nil {
return nil, err
}

c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag}
c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag, headers: in.Headers}
return &c, nil
}

func (c *Client) Index(host string) (*IndexConnection, error) {
return c.IndexWithNamespace(host, "")
return c.IndexWithAdditionalMetadata(host, "", nil)
}

func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnection, error) {
idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag)
return c.IndexWithAdditionalMetadata(host, namespace, nil)
}

func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) {
idx, err := newIndexConnection(newIndexParameters{apiKey: c.apiKey, host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -408,3 +412,40 @@ func minOne(x int32) int32 {
}
return x
}

func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) {
clientOptions := []control.ClientOption{}
hasAuthorizationHeader := false
hasApiKey := in.ApiKey != ""

userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag))
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept))

for key, value := range in.Headers {
headerProvider := provider.NewHeaderProvider(key, value)

if strings.Contains(strings.ToLower(key), "authorization") {
hasAuthorizationHeader = true
}

clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}

if !hasAuthorizationHeader {
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey)
if err != nil {
return nil, err
}
clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept))
}

if !hasAuthorizationHeader && !hasApiKey {
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization")
}

if in.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient))
}

return clientOptions, nil
}
84 changes: 82 additions & 2 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"context"
"fmt"
"os"
"reflect"
"strings"
"testing"

"github.com/google/uuid"
"github.com/pinecone-io/go-pinecone/internal/mocks"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
Expand Down Expand Up @@ -66,8 +69,11 @@ func (ts *ClientTests) TestNewClientParamsSet() {
if client.sourceTag != "" {
ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag))
}
if client.headers != nil {
ts.FailNow(fmt.Sprintf("Expected client headers to be nil, but got '%v'", client.headers))
}
if len(client.restClient.RequestEditors) != 2 {
ts.FailNow("Expected 2 request editors on client")
ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors))
}
}

Expand All @@ -85,7 +91,81 @@ func (ts *ClientTests) TestNewClientParamsSetSourceTag() {
ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag))
}
if len(client.restClient.RequestEditors) != 2 {
ts.FailNow("Expected 2 request editors on client")
ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors))
}
}

func (ts *ClientTests) TestNewClientParamsSetHeaders() {
apiKey := "test-api-key"
headers := map[string]string{"test-header": "test-value"}
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers})
if err != nil {
ts.FailNow(err.Error())
}
if client.apiKey != apiKey {
ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey))
}
if !reflect.DeepEqual(client.headers, headers) {
ts.FailNow(fmt.Sprintf("Expected client to have headers '%+v', but got '%+v'", headers, client.headers))
}
if len(client.restClient.RequestEditors) != 3 {
ts.FailNow(fmt.Sprintf("Expected client to have '%v' request editors, but got '%v'", 3, len(client.restClient.RequestEditors)))
}
}

func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() {
client, err := NewClient(NewClientParams{})
require.NotNil(ts.T(), err, "Expected error when creating client without an API key or Authorization header")
if !strings.Contains(err.Error(), "no API key provided, please pass an API key for authorization") {
ts.FailNow(fmt.Sprintf("Expected error to contain 'no API key provided, please pass an API key for authorization', but got '%s'", err.Error()))
}

require.Nil(ts.T(), client, "Expected client to be nil when creating client without an API key or Authorization header")
}

func (ts *ClientTests) TestHeadersAppliedToRequests() {
apiKey := "test-api-key"
headers := map[string]string{"test-header": "123456"}

httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

testHeaderValue := mockTransport.Req.Header.Get("test-header")
if testHeaderValue != "123456" {
ts.FailNow(fmt.Sprintf("Expected request to have header value '123456', but got '%s'", testHeaderValue))
}
}

func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() {
apiKey := "test-api-key"
headers := map[string]string{"Authorization": "bearer fooo"}

httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key")
authHeaderValue := mockTransport.Req.Header.Get("Authorization")
if authHeaderValue != "bearer fooo" {
ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue))
}
if apiKeyHeaderValue != "" {
ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue))
}
}

Expand Down
26 changes: 21 additions & 5 deletions pinecone/index_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,28 @@ import (
type IndexConnection struct {
Namespace string
apiKey string
additionalMetadata map[string]string
dataClient *data.VectorServiceClient
grpcConn *grpc.ClientConn
}

func newIndexConnection(apiKey string, host string, namespace string, sourceTag string) (*IndexConnection, error) {
type newIndexParameters struct {
apiKey string
host string
namespace string
sourceTag string
additionalMetadata map[string]string
}

func newIndexConnection(in newIndexParameters) (*IndexConnection, error) {
config := &tls.Config{}
target := fmt.Sprintf("%s:443", host)
target := fmt.Sprintf("%s:443", in.host)
conn, err := grpc.Dial(
target,
grpc.WithTransportCredentials(credentials.NewTLS(config)),
grpc.WithAuthority(target),
grpc.WithBlock(),
grpc.WithUserAgent(useragent.BuildUserAgentGRPC(sourceTag)),
grpc.WithUserAgent(useragent.BuildUserAgentGRPC(in.sourceTag)),
)

if err != nil {
Expand All @@ -38,7 +47,7 @@ func newIndexConnection(apiKey string, host string, namespace string, sourceTag

dataClient := data.NewVectorServiceClient(conn)

idx := IndexConnection{Namespace: namespace, apiKey: apiKey, dataClient: &dataClient, grpcConn: conn}
idx := IndexConnection{Namespace: in.namespace, apiKey: in.apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata}
return &idx, nil
}

Expand Down Expand Up @@ -360,5 +369,12 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues {
}

func (idx *IndexConnection) akCtx(ctx context.Context) context.Context {
return metadata.AppendToOutgoingContext(ctx, "api-key", idx.apiKey)
newMetadata := []string{}
newMetadata = append(newMetadata, "api-key", idx.apiKey)

for key, value := range idx.additionalMetadata{
newMetadata = append(newMetadata, key, value)
}

return metadata.AppendToOutgoingContext(ctx, newMetadata...)
}
Loading

0 comments on commit 5a9897e

Please sign in to comment.