Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional configs for setting custom headers / metadata for REST / gRPC operations #18

Merged
merged 7 commits into from
May 6, 2024
30 changes: 30 additions & 0 deletions internal/mocks/mock_transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package mocks

import (
"bytes"
"io"
"net/http"
)

type MockTransport struct {
Req *http.Request
Resp *http.Response
Err error
}

func (m *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
m.Req = req
return m.Resp, m.Err
}

func CreateMockClient(jsonBody string) *http.Client {
return &http.Client {
Transport: &MockTransport{
Resp: &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader([]byte(jsonBody))),
Header: make(http.Header),
},
},
}
}
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ test:
set -o allexport
source .env
set +o allexport
go test -count=1 ./pinecone
go test -count=1 -v ./pinecone
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the additional logging output when running these locally, etc.

bootstrap:
go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
Expand Down
69 changes: 55 additions & 14 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,55 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/deepmap/oapi-codegen/v2/pkg/securityprovider"
"github.com/pinecone-io/go-pinecone/internal/gen/control"
"github.com/pinecone-io/go-pinecone/internal/provider"
"github.com/pinecone-io/go-pinecone/internal/useragent"
"io"
"net/http"
)

type Client struct {
apiKey string
restClient *control.Client
sourceTag string
headers map[string]string
}

type NewClientParams struct {
ApiKey string
SourceTag string // optional
ApiKey string // required unless Authorization header provided
SourceTag string // optional
Headers map[string]string // optional
RestClient *http.Client // optional
}

func NewClient(in NewClientParams) (*Client, error) {
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey)
clientOptions, err := buildClientOptions(in)
if err != nil {
return nil, err
}

userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag))

client, err := control.NewClient("https://api.pinecone.io",
control.WithRequestEditorFn(apiKeyProvider.Intercept),
control.WithRequestEditorFn(userAgentProvider.Intercept),
)
client, err := control.NewClient("https://api.pinecone.io", clientOptions...)
if err != nil {
return nil, err
}

c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag}
c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag, headers: in.Headers}
return &c, nil
}

func (c *Client) Index(host string) (*IndexConnection, error) {
return c.IndexWithNamespace(host, "")
return c.IndexWithAdditionalMetadata(host, "", nil)
}

func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnection, error) {
idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag)
return c.IndexWithAdditionalMetadata(host, namespace, nil)
}

func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) {
idx, err := newIndexConnection(newIndexParameters{apiKey: c.apiKey, host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -408,3 +412,40 @@ func minOne(x int32) int32 {
}
return x
}

func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) {
clientOptions := []control.ClientOption{}
hasAuthorizationHeader := false
hasApiKey := in.ApiKey != ""

userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag))
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept))

for key, value := range in.Headers {
headerProvider := provider.NewHeaderProvider(key, value)

if strings.Contains(strings.ToLower(key), "authorization") {
hasAuthorizationHeader = true
}

clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept))
}

if !hasAuthorizationHeader {
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey)
if err != nil {
return nil, err
}
clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept))
}

if !hasAuthorizationHeader && !hasApiKey {
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization")
}

if in.RestClient != nil {
clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient))
}

return clientOptions, nil
}
84 changes: 82 additions & 2 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"context"
"fmt"
"os"
"reflect"
"strings"
"testing"

"github.com/google/uuid"
"github.com/pinecone-io/go-pinecone/internal/mocks"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
Expand Down Expand Up @@ -66,8 +69,11 @@ func (ts *ClientTests) TestNewClientParamsSet() {
if client.sourceTag != "" {
ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag))
}
if client.headers != nil {
ts.FailNow(fmt.Sprintf("Expected client headers to be nil, but got '%v'", client.headers))
}
if len(client.restClient.RequestEditors) != 2 {
ts.FailNow("Expected 2 request editors on client")
ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors))
}
}

Expand All @@ -85,7 +91,81 @@ func (ts *ClientTests) TestNewClientParamsSetSourceTag() {
ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag))
}
if len(client.restClient.RequestEditors) != 2 {
ts.FailNow("Expected 2 request editors on client")
ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors))
}
}

func (ts *ClientTests) TestNewClientParamsSetHeaders() {
apiKey := "test-api-key"
headers := map[string]string{"test-header": "test-value"}
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers})
if err != nil {
ts.FailNow(err.Error())
}
if client.apiKey != apiKey {
ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this potentially dangerous b/c someone could reveal an API key through this error msg?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I mean, I know it's a mocked API key in the test, but if this function outputs this error msg IRL, that could be a dangerous vulnerability for users, no?)

Copy link
Contributor Author

@austin-denoble austin-denoble Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is test code and thus not compiled into the actual built package, as far as I understand. Given that and the usage of a fake key in the tests where we're doing this, I feel like this is probably safe. Users would need to be developing the client themselves locally or in CI for it to be an issue I think.

}
if !reflect.DeepEqual(client.headers, headers) {
ts.FailNow(fmt.Sprintf("Expected client to have headers '%+v', but got '%+v'", headers, client.headers))
}
if len(client.restClient.RequestEditors) != 3 {
ts.FailNow(fmt.Sprintf("Expected client to have '%v' request editors, but got '%v'", 3, len(client.restClient.RequestEditors)))
}
}

func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() {
client, err := NewClient(NewClientParams{})
require.NotNil(ts.T(), err, "Expected error when creating client without an API key or Authorization header")
if !strings.Contains(err.Error(), "no API key provided, please pass an API key for authorization") {
ts.FailNow(fmt.Sprintf("Expected error to contain 'no API key provided, please pass an API key for authorization', but got '%s'", err.Error()))
}

require.Nil(ts.T(), client, "Expected client to be nil when creating client without an API key or Authorization header")
}

func (ts *ClientTests) TestHeadersAppliedToRequests() {
apiKey := "test-api-key"
headers := map[string]string{"test-header": "123456"}

httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

testHeaderValue := mockTransport.Req.Header.Get("test-header")
if testHeaderValue != "123456" {
ts.FailNow(fmt.Sprintf("Expected request to have header value '123456', but got '%s'", testHeaderValue))
}
}

func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() {
apiKey := "test-api-key"
headers := map[string]string{"Authorization": "bearer fooo"}

httpClient := mocks.CreateMockClient(`{"indexes": []}`)
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient})
if err != nil {
ts.FailNow(err.Error())
}
mockTransport := httpClient.Transport.(*mocks.MockTransport)

_, err = client.ListIndexes(context.Background())
require.NoError(ts.T(), err)
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made")

apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key")
authHeaderValue := mockTransport.Req.Header.Get("Authorization")
if authHeaderValue != "bearer fooo" {
ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue))
}
if apiKeyHeaderValue != "" {
ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue))
}
}

Expand Down
26 changes: 21 additions & 5 deletions pinecone/index_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,28 @@ import (
type IndexConnection struct {
Namespace string
apiKey string
additionalMetadata map[string]string
dataClient *data.VectorServiceClient
grpcConn *grpc.ClientConn
}

func newIndexConnection(apiKey string, host string, namespace string, sourceTag string) (*IndexConnection, error) {
type newIndexParameters struct {
apiKey string
host string
namespace string
sourceTag string
additionalMetadata map[string]string
}

func newIndexConnection(in newIndexParameters) (*IndexConnection, error) {
Comment on lines +24 to +32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I like this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Props to @haruska for the suggestion.

config := &tls.Config{}
target := fmt.Sprintf("%s:443", host)
target := fmt.Sprintf("%s:443", in.host)
conn, err := grpc.Dial(
target,
grpc.WithTransportCredentials(credentials.NewTLS(config)),
grpc.WithAuthority(target),
grpc.WithBlock(),
grpc.WithUserAgent(useragent.BuildUserAgentGRPC(sourceTag)),
grpc.WithUserAgent(useragent.BuildUserAgentGRPC(in.sourceTag)),
)

if err != nil {
Expand All @@ -38,7 +47,7 @@ func newIndexConnection(apiKey string, host string, namespace string, sourceTag

dataClient := data.NewVectorServiceClient(conn)

idx := IndexConnection{Namespace: namespace, apiKey: apiKey, dataClient: &dataClient, grpcConn: conn}
idx := IndexConnection{Namespace: in.namespace, apiKey: in.apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata}
return &idx, nil
}

Expand Down Expand Up @@ -360,5 +369,12 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues {
}

func (idx *IndexConnection) akCtx(ctx context.Context) context.Context {
return metadata.AppendToOutgoingContext(ctx, "api-key", idx.apiKey)
newMetadata := []string{}
newMetadata = append(newMetadata, "api-key", idx.apiKey)

for key, value := range idx.additionalMetadata{
newMetadata = append(newMetadata, key, value)
}

return metadata.AppendToOutgoingContext(ctx, newMetadata...)
}
Loading
Loading