Skip to content

Commit

Permalink
update IndexConnection and Client to support adding the REST implemen…
Browse files Browse the repository at this point in the history
…tation for the db_data client as a member on the IndexConnection struct, for users this happens under the hood, regenerate and update submodule
  • Loading branch information
austin-denoble committed Oct 10, 2024
1 parent e225c56 commit 2ad000f
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 74 deletions.
2 changes: 1 addition & 1 deletion codegen/apis
Submodule apis updated from 404bf1 to 7cedbc
50 changes: 32 additions & 18 deletions internal/gen/db_data/rest/db_data_2024-10.oas.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 51 additions & 10 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/pinecone-io/go-pinecone/internal/gen"
"github.com/pinecone-io/go-pinecone/internal/gen/db_control"
db_data "github.com/pinecone-io/go-pinecone/internal/gen/db_data/rest"
"github.com/pinecone-io/go-pinecone/internal/gen/inference"
"github.com/pinecone-io/go-pinecone/internal/provider"
"github.com/pinecone-io/go-pinecone/internal/useragent"
Expand Down Expand Up @@ -71,10 +72,11 @@ import (
// [docs.pinecone.io/reference/api]: https://docs.pinecone.io/reference/api/control-plane/list_indexes
// [Inference API]: https://docs.pinecone.io/reference/api/2024-07/inference/generate-embeddings
type Client struct {
Inference *InferenceService
headers map[string]string
Inference *InferenceService
// headers map[string]string
restClient *db_control.Client
sourceTag string
// sourceTag string
baseParams *NewClientBaseParams
}

// NewClientParams holds the parameters for creating a new Client instance while authenticating via an API key.
Expand Down Expand Up @@ -210,8 +212,8 @@ func NewClient(in NewClientParams) (*Client, error) {
// fmt.Println("Successfully created a new Client object!")
// }
func NewClientBase(in NewClientBaseParams) (*Client, error) {
clientOptions := buildClientBaseOptions(in)
inference_client_options := buildInferenceBaseOptions(in)
controlOptions := buildClientBaseOptions(in)
inferenceOptions := buildInferenceBaseOptions(in)
var err error

controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST"))
Expand All @@ -222,16 +224,22 @@ func NewClientBase(in NewClientBaseParams) (*Client, error) {
}
}

db_control_client, err := db_control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...)
dbControlClient, err := db_control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), controlOptions...)
if err != nil {
return nil, err
}
inference_client, err := inference.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), inference_client_options...)
inferenceClient, err := inference.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), inferenceOptions...)
if err != nil {
return nil, err
}

c := Client{Inference: &InferenceService{client: inference_client}, restClient: db_control_client, sourceTag: in.SourceTag, headers: in.Headers}
c := Client{
Inference: &InferenceService{client: inferenceClient},
restClient: dbControlClient,
// sourceTag: in.SourceTag,
// headers: in.Headers,
baseParams: &in,
}
return &c, nil
}

Expand Down Expand Up @@ -304,18 +312,35 @@ func (c *Client) Index(in NewIndexConnParams, dialOpts ...grpc.DialOption) (*Ind
in.AdditionalMetadata[key] = value
}

dbDataOptions := buildDataClientBaseOptions(*c.baseParams)
dbDataClient, err := db_data.NewClient(ensureHostHasHttps(in.Host), dbDataOptions...)
if err != nil {
return nil, err
}

idx, err := newIndexConnection(newIndexParameters{
host: in.Host,
namespace: in.Namespace,
sourceTag: c.sourceTag,
sourceTag: c.baseParams.SourceTag,
additionalMetadata: in.AdditionalMetadata,
dbDataClient: dbDataClient,
}, dialOpts...)
if err != nil {
return nil, err
}
return idx, nil
}

func ensureHostHasHttps(host string) string {
if strings.HasPrefix("http://", host) {
return strings.Replace(host, "http://", "https://", 1)
} else if !strings.HasPrefix("https://", host) {
return "https://" + host
}

return host
}

// ListIndexes retrieves a list of all Indexes in a Pinecone [project].
//
// Parameters:
Expand Down Expand Up @@ -1332,7 +1357,7 @@ func (c *Client) extractAuthHeader() map[string]string {
"access_token",
}

for key, value := range c.headers {
for key, value := range c.baseParams.Headers {
for _, checkKey := range possibleAuthKeys {
if strings.ToLower(key) == checkKey {
return map[string]string{key: value}
Expand Down Expand Up @@ -1525,6 +1550,22 @@ func buildInferenceBaseOptions(in NewClientBaseParams) []inference.ClientOption
return clientOptions
}

func buildDataClientBaseOptions(in NewClientBaseParams) []db_data.ClientOption {
clientOptions := []db_data.ClientOption{}
headerProviders := buildSharedProviderHeaders(in)

for _, provider := range headerProviders {
clientOptions = append(clientOptions, db_data.WithRequestEditorFn(provider.Intercept))
}

// apply custom http client if provided
if in.RestClient != nil {
clientOptions = append(clientOptions, db_data.WithHTTPClient(in.RestClient))
}

return clientOptions
}

func buildSharedProviderHeaders(in NewClientBaseParams) []*provider.CustomHeader {
providers := []*provider.CustomHeader{}

Expand Down
16 changes: 8 additions & 8 deletions pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,9 @@ func TestNewClientParamsSetUnit(t *testing.T) {
client, err := NewClient(NewClientParams{ApiKey: apiKey})

require.NoError(t, err)
require.Empty(t, client.sourceTag, "Expected client to have empty sourceTag")
require.NotNil(t, client.headers, "Expected client headers to not be nil")
apiKeyHeader, ok := client.headers["Api-Key"]
require.Empty(t, client.baseParams.SourceTag, "Expected client to have empty sourceTag")
require.NotNil(t, client.baseParams.Headers, "Expected client headers to not be nil")
apiKeyHeader, ok := client.baseParams.Headers["Api-Key"]
require.True(t, ok, "Expected client to have an 'Api-Key' header")
require.Equal(t, apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey")
require.Equal(t, 3, len(client.restClient.RequestEditors), "Expected client to have correct number of request editors")
Expand All @@ -405,10 +405,10 @@ func TestNewClientParamsSetSourceTagUnit(t *testing.T) {
})

require.NoError(t, err)
apiKeyHeader, ok := client.headers["Api-Key"]
apiKeyHeader, ok := client.baseParams.Headers["Api-Key"]
require.True(t, ok, "Expected client to have an 'Api-Key' header")
require.Equal(t, apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey")
require.Equal(t, sourceTag, client.sourceTag, "Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)
require.Equal(t, sourceTag, client.baseParams.SourceTag, "Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.baseParams.SourceTag)
require.Equal(t, 3, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 2, len(client.restClient.RequestEditors))
}

Expand All @@ -418,10 +418,10 @@ func TestNewClientParamsSetHeadersUnit(t *testing.T) {
client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers})

require.NoError(t, err)
apiKeyHeader, ok := client.headers["Api-Key"]
apiKeyHeader, ok := client.baseParams.Headers["Api-Key"]
require.True(t, ok, "Expected client to have an 'Api-Key' header")
require.Equal(t, apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey")
require.Equal(t, client.headers, headers, "Expected client to have headers '%+v', but got '%+v'", headers, client.headers)
require.Equal(t, client.baseParams.Headers, headers, "Expected client to have headers '%+v', but got '%+v'", headers, client.baseParams.Headers)
require.Equal(t, 4, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 3, len(client.restClient.RequestEditors))
}

Expand Down Expand Up @@ -1072,7 +1072,7 @@ func TestNewClientUnit(t *testing.T) {
} else {
assert.NoError(t, err)
assert.NotNil(t, client)
assert.Equal(t, tc.expectedHeaders, client.headers, "Expected headers to be '%v', but got '%v'", tc.expectedHeaders, client.headers)
assert.Equal(t, tc.expectedHeaders, client.baseParams.Headers, "Expected headers to be '%v', but got '%v'", tc.expectedHeaders, client.baseParams.Headers)
}
})
}
Expand Down
Loading

0 comments on commit 2ad000f

Please sign in to comment.