-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optional configs for setting custom headers / metadata for REST / gRPC operations #18
Changes from 6 commits
fcfac96
b311770
233a473
395be59
e5753fb
59a6918
89fc762
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package mocks | ||
|
||
import ( | ||
"bytes" | ||
"io" | ||
"net/http" | ||
) | ||
|
||
type MockTransport struct { | ||
Req *http.Request | ||
Resp *http.Response | ||
Err error | ||
} | ||
|
||
func (m *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||
m.Req = req | ||
return m.Resp, m.Err | ||
} | ||
|
||
func CreateMockClient(jsonBody string) *http.Client { | ||
return &http.Client { | ||
Transport: &MockTransport{ | ||
Resp: &http.Response{ | ||
StatusCode: 200, | ||
Body: io.NopCloser(bytes.NewReader([]byte(jsonBody))), | ||
Header: make(http.Header), | ||
}, | ||
}, | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,51 +4,55 @@ import ( | |
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strings" | ||
|
||
"github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" | ||
"github.com/pinecone-io/go-pinecone/internal/gen/control" | ||
"github.com/pinecone-io/go-pinecone/internal/provider" | ||
"github.com/pinecone-io/go-pinecone/internal/useragent" | ||
"io" | ||
"net/http" | ||
) | ||
|
||
type Client struct { | ||
apiKey string | ||
restClient *control.Client | ||
sourceTag string | ||
headers map[string]string | ||
} | ||
|
||
type NewClientParams struct { | ||
ApiKey string | ||
SourceTag string // optional | ||
ApiKey string // optional unless no Authorization header provided | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggest rephrasing this to "required unless Authorization header provided" |
||
SourceTag string // optional | ||
Headers map[string]string // optional | ||
RestClient *http.Client // optional | ||
} | ||
|
||
func NewClient(in NewClientParams) (*Client, error) { | ||
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) | ||
clientOptions, err := buildClientOptions(in) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) | ||
|
||
client, err := control.NewClient("https://api.pinecone.io", | ||
control.WithRequestEditorFn(apiKeyProvider.Intercept), | ||
control.WithRequestEditorFn(userAgentProvider.Intercept), | ||
) | ||
client, err := control.NewClient("https://api.pinecone.io", clientOptions...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag} | ||
c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag, headers: in.Headers} | ||
return &c, nil | ||
} | ||
|
||
func (c *Client) Index(host string) (*IndexConnection, error) { | ||
return c.IndexWithNamespace(host, "") | ||
return c.IndexWithAdditionalMetadata(host, "", nil) | ||
} | ||
|
||
func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnection, error) { | ||
idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag) | ||
return c.IndexWithAdditionalMetadata(host, namespace, nil) | ||
} | ||
|
||
func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) { | ||
idx, err := newIndexConnection(newIndexParameters{apiKey: c.apiKey, host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
@@ -408,3 +412,40 @@ func minOne(x int32) int32 { | |
} | ||
return x | ||
} | ||
|
||
func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { | ||
clientOptions := []control.ClientOption{} | ||
hasAuthorizationHeader := false | ||
hasApiKey := in.ApiKey != "" | ||
|
||
userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) | ||
clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) | ||
|
||
for key, value := range in.Headers { | ||
headerProvider := provider.NewHeaderProvider(key, value) | ||
|
||
if strings.Contains(key, "Authorization") { | ||
hasAuthorizationHeader = true | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. http headers are typically case in-sensitive so we probably want to relax this check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated to use |
||
|
||
clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there any headers we'd want to prevent them from setting? I'm not sure there, just asking the question. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At the moment I don't think there are, good question though I'm also not sure. |
||
|
||
if !hasAuthorizationHeader { | ||
apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) | ||
if err != nil { | ||
return nil, err | ||
} | ||
clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) | ||
} | ||
|
||
if !hasAuthorizationHeader && !hasApiKey { | ||
return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") | ||
} | ||
|
||
if in.RestClient != nil { | ||
clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) | ||
} | ||
|
||
return clientOptions, nil | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,9 +4,12 @@ import ( | |
"context" | ||
"fmt" | ||
"os" | ||
"reflect" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/google/uuid" | ||
"github.com/pinecone-io/go-pinecone/internal/mocks" | ||
"github.com/stretchr/testify/require" | ||
"github.com/stretchr/testify/suite" | ||
) | ||
|
@@ -66,8 +69,11 @@ func (ts *ClientTests) TestNewClientParamsSet() { | |
if client.sourceTag != "" { | ||
ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag)) | ||
} | ||
if client.headers != nil { | ||
ts.FailNow(fmt.Sprintf("Expected client headers to be nil, but got '%v'", client.headers)) | ||
} | ||
if len(client.restClient.RequestEditors) != 2 { | ||
ts.FailNow("Expected 2 request editors on client") | ||
ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) | ||
} | ||
} | ||
|
||
|
@@ -85,7 +91,81 @@ func (ts *ClientTests) TestNewClientParamsSetSourceTag() { | |
ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)) | ||
} | ||
if len(client.restClient.RequestEditors) != 2 { | ||
ts.FailNow("Expected 2 request editors on client") | ||
ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) | ||
} | ||
} | ||
|
||
func (ts *ClientTests) TestNewClientParamsSetHeaders() { | ||
apiKey := "test-api-key" | ||
headers := map[string]string{"test-header": "test-value"} | ||
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers}) | ||
if err != nil { | ||
ts.FailNow(err.Error()) | ||
} | ||
if client.apiKey != apiKey { | ||
ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this potentially dangerous b/c someone could reveal an API key through this error msg? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (I mean, I know it's a mocked API key in the test, but if this function outputs this error msg IRL, that could be a dangerous vulnerability for users, no?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is test code and thus not compiled into the actual built package, as far as I understand. Given that and the usage of a fake key in the tests where we're doing this, I feel like this is probably safe. Users would need to be developing the client themselves locally or in CI for it to be an issue I think. |
||
} | ||
if !reflect.DeepEqual(client.headers, headers) { | ||
ts.FailNow(fmt.Sprintf("Expected client to have headers '%+v', but got '%+v'", headers, client.headers)) | ||
} | ||
if len(client.restClient.RequestEditors) != 3 { | ||
ts.FailNow(fmt.Sprintf("Expected client to have '%v' request editors, but got '%v'", 3, len(client.restClient.RequestEditors))) | ||
} | ||
} | ||
|
||
func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { | ||
client, err := NewClient(NewClientParams{}) | ||
require.NotNil(ts.T(), err, "Expected error when creating client without an API key or Authorization header") | ||
if !strings.Contains(err.Error(), "no API key provided, please pass an API key for authorization") { | ||
ts.FailNow(fmt.Sprintf("Expected error to contain 'no API key provided, please pass an API key for authorization', but got '%s'", err.Error())) | ||
} | ||
|
||
require.Nil(ts.T(), client, "Expected client to be nil when creating client without an API key or Authorization header") | ||
} | ||
|
||
func (ts *ClientTests) TestHeadersAppliedToRequests() { | ||
apiKey := "test-api-key" | ||
headers := map[string]string{"test-header": "123456"} | ||
|
||
httpClient := mocks.CreateMockClient(`{"indexes": []}`) | ||
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) | ||
if err != nil { | ||
ts.FailNow(err.Error()) | ||
} | ||
mockTransport := httpClient.Transport.(*mocks.MockTransport) | ||
|
||
_, err = client.ListIndexes(context.Background()) | ||
require.NoError(ts.T(), err) | ||
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") | ||
|
||
testHeaderValue := mockTransport.Req.Header.Get("test-header") | ||
if testHeaderValue != "123456" { | ||
ts.FailNow(fmt.Sprintf("Expected request to have header value '123456', but got '%s'", testHeaderValue)) | ||
} | ||
} | ||
|
||
func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { | ||
apiKey := "test-api-key" | ||
headers := map[string]string{"Authorization": "bearer fooo"} | ||
|
||
httpClient := mocks.CreateMockClient(`{"indexes": []}`) | ||
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) | ||
if err != nil { | ||
ts.FailNow(err.Error()) | ||
} | ||
mockTransport := httpClient.Transport.(*mocks.MockTransport) | ||
|
||
_, err = client.ListIndexes(context.Background()) | ||
require.NoError(ts.T(), err) | ||
require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") | ||
|
||
apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") | ||
authHeaderValue := mockTransport.Req.Header.Get("Authorization") | ||
if authHeaderValue != "bearer fooo" { | ||
ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) | ||
} | ||
if apiKeyHeaderValue != "" { | ||
ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,19 +16,28 @@ import ( | |
type IndexConnection struct { | ||
Namespace string | ||
apiKey string | ||
additionalMetadata map[string]string | ||
dataClient *data.VectorServiceClient | ||
grpcConn *grpc.ClientConn | ||
} | ||
|
||
func newIndexConnection(apiKey string, host string, namespace string, sourceTag string) (*IndexConnection, error) { | ||
type newIndexParameters struct { | ||
apiKey string | ||
host string | ||
namespace string | ||
sourceTag string | ||
additionalMetadata map[string]string | ||
} | ||
|
||
func newIndexConnection(in newIndexParameters) (*IndexConnection, error) { | ||
Comment on lines
+24
to
+32
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, I like this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Props to @haruska for the suggestion. |
||
config := &tls.Config{} | ||
target := fmt.Sprintf("%s:443", host) | ||
target := fmt.Sprintf("%s:443", in.host) | ||
conn, err := grpc.Dial( | ||
target, | ||
grpc.WithTransportCredentials(credentials.NewTLS(config)), | ||
grpc.WithAuthority(target), | ||
grpc.WithBlock(), | ||
grpc.WithUserAgent(useragent.BuildUserAgentGRPC(sourceTag)), | ||
grpc.WithUserAgent(useragent.BuildUserAgentGRPC(in.sourceTag)), | ||
) | ||
|
||
if err != nil { | ||
|
@@ -38,7 +47,7 @@ func newIndexConnection(apiKey string, host string, namespace string, sourceTag | |
|
||
dataClient := data.NewVectorServiceClient(conn) | ||
|
||
idx := IndexConnection{Namespace: namespace, apiKey: apiKey, dataClient: &dataClient, grpcConn: conn} | ||
idx := IndexConnection{Namespace: in.namespace, apiKey: in.apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata} | ||
return &idx, nil | ||
} | ||
|
||
|
@@ -360,5 +369,12 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues { | |
} | ||
|
||
func (idx *IndexConnection) akCtx(ctx context.Context) context.Context { | ||
return metadata.AppendToOutgoingContext(ctx, "api-key", idx.apiKey) | ||
newMetadata := []string{} | ||
newMetadata = append(newMetadata, "api-key", idx.apiKey) | ||
|
||
for key, value := range idx.additionalMetadata{ | ||
newMetadata = append(newMetadata, key, value) | ||
} | ||
|
||
return metadata.AppendToOutgoingContext(ctx, newMetadata...) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the additional logging output when running these locally, etc.