From 2ad000fb94e3eb35ae8dbca8c6ee731e09517243 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Thu, 10 Oct 2024 19:36:31 -0400 Subject: [PATCH] update IndexConnection and Client to support adding the REST implementation for the db_data client as a member on the IndexConnection struct, for users this happens under the hood, regenerate and update submodule --- codegen/apis | 2 +- .../gen/db_data/rest/db_data_2024-10.oas.go | 50 ++++++++----- pinecone/client.go | 61 +++++++++++++--- pinecone/client_test.go | 16 ++--- pinecone/index_connection.go | 70 ++++++++++--------- pinecone/index_connection_test.go | 6 +- 6 files changed, 131 insertions(+), 74 deletions(-) diff --git a/codegen/apis b/codegen/apis index 404bf15..7cedbc1 160000 --- a/codegen/apis +++ b/codegen/apis @@ -1 +1 @@ -Subproject commit 404bf15a7ef14740475d228a77243bdeebb8090d +Subproject commit 7cedbc17518ae76c2cc5f0d6330f739b0e4540b5 diff --git a/internal/gen/db_data/rest/db_data_2024-10.oas.go b/internal/gen/db_data/rest/db_data_2024-10.oas.go index b0ecbeb..5d6535f 100644 --- a/internal/gen/db_data/rest/db_data_2024-10.oas.go +++ b/internal/gen/db_data/rest/db_data_2024-10.oas.go @@ -274,6 +274,12 @@ type StartImportRequest struct { Uri *string `json:"uri,omitempty"` } +// StartImportResponse The response for the `start_import` operation. +type StartImportResponse struct { + // Id Unique identifier for the import operations. + Id *string `json:"id,omitempty"` +} + // UpdateRequest The request for the `update` operation. type UpdateRequest struct { // Id Vector's unique id. @@ -475,8 +481,8 @@ type ClientInterface interface { // CancelBulkImport request CancelBulkImport(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*http.Response, error) - // DescribeImport request - DescribeImport(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*http.Response, error) + // DescribeBulkImport request + DescribeBulkImport(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*http.Response, error) // DescribeIndexStatsWithBody request with any body DescribeIndexStatsWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -558,8 +564,8 @@ func (c *Client) CancelBulkImport(ctx context.Context, id string, reqEditors ... return c.Client.Do(req) } -func (c *Client) DescribeImport(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*http.Response, error) { - req, err := NewDescribeImportRequest(c.Server, id) +func (c *Client) DescribeBulkImport(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewDescribeBulkImportRequest(c.Server, id) if err != nil { return nil, err } @@ -853,8 +859,8 @@ func NewCancelBulkImportRequest(server string, id string) (*http.Request, error) return req, nil } -// NewDescribeImportRequest generates requests for DescribeImport -func NewDescribeImportRequest(server string, id string) (*http.Request, error) { +// NewDescribeBulkImportRequest generates requests for DescribeBulkImport +func NewDescribeBulkImportRequest(server string, id string) (*http.Request, error) { var err error var pathParam0 string @@ -1299,8 +1305,8 @@ type ClientWithResponsesInterface interface { // CancelBulkImportWithResponse request CancelBulkImportWithResponse(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*CancelBulkImportResponse, error) - // DescribeImportWithResponse request - DescribeImportWithResponse(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*DescribeImportResponse, error) + // DescribeBulkImportWithResponse request + DescribeBulkImportWithResponse(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*DescribeBulkImportResponse, error) // DescribeIndexStatsWithBodyWithResponse request with any body DescribeIndexStatsWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*DescribeIndexStatsResponse, error) @@ -1362,6 +1368,7 @@ func (r ListBulkImportsResponse) StatusCode() int { type StartBulkImportResponse struct { Body []byte HTTPResponse *http.Response + JSON200 *StartImportResponse JSON400 *RpcStatus JSON4XX *RpcStatus JSON5XX *RpcStatus @@ -1408,7 +1415,7 @@ func (r CancelBulkImportResponse) StatusCode() int { return 0 } -type DescribeImportResponse struct { +type DescribeBulkImportResponse struct { Body []byte HTTPResponse *http.Response JSON200 *ImportModel @@ -1418,7 +1425,7 @@ type DescribeImportResponse struct { } // Status returns HTTPResponse.Status -func (r DescribeImportResponse) Status() string { +func (r DescribeBulkImportResponse) Status() string { if r.HTTPResponse != nil { return r.HTTPResponse.Status } @@ -1426,7 +1433,7 @@ func (r DescribeImportResponse) Status() string { } // StatusCode returns HTTPResponse.StatusCode -func (r DescribeImportResponse) StatusCode() int { +func (r DescribeBulkImportResponse) StatusCode() int { if r.HTTPResponse != nil { return r.HTTPResponse.StatusCode } @@ -1643,13 +1650,13 @@ func (c *ClientWithResponses) CancelBulkImportWithResponse(ctx context.Context, return ParseCancelBulkImportResponse(rsp) } -// DescribeImportWithResponse request returning *DescribeImportResponse -func (c *ClientWithResponses) DescribeImportWithResponse(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*DescribeImportResponse, error) { - rsp, err := c.DescribeImport(ctx, id, reqEditors...) +// DescribeBulkImportWithResponse request returning *DescribeBulkImportResponse +func (c *ClientWithResponses) DescribeBulkImportWithResponse(ctx context.Context, id string, reqEditors ...RequestEditorFn) (*DescribeBulkImportResponse, error) { + rsp, err := c.DescribeBulkImport(ctx, id, reqEditors...) if err != nil { return nil, err } - return ParseDescribeImportResponse(rsp) + return ParseDescribeBulkImportResponse(rsp) } // DescribeIndexStatsWithBodyWithResponse request with arbitrary body returning *DescribeIndexStatsResponse @@ -1816,6 +1823,13 @@ func ParseStartBulkImportResponse(rsp *http.Response) (*StartBulkImportResponse, } switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest StartImportResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: var dest RpcStatus if err := json.Unmarshal(bodyBytes, &dest); err != nil { @@ -1889,15 +1903,15 @@ func ParseCancelBulkImportResponse(rsp *http.Response) (*CancelBulkImportRespons return response, nil } -// ParseDescribeImportResponse parses an HTTP response from a DescribeImportWithResponse call -func ParseDescribeImportResponse(rsp *http.Response) (*DescribeImportResponse, error) { +// ParseDescribeBulkImportResponse parses an HTTP response from a DescribeBulkImportWithResponse call +func ParseDescribeBulkImportResponse(rsp *http.Response) (*DescribeBulkImportResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) defer func() { _ = rsp.Body.Close() }() if err != nil { return nil, err } - response := &DescribeImportResponse{ + response := &DescribeBulkImportResponse{ Body: bodyBytes, HTTPResponse: rsp, } diff --git a/pinecone/client.go b/pinecone/client.go index 72ce15a..1368b3f 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -16,6 +16,7 @@ import ( "github.com/pinecone-io/go-pinecone/internal/gen" "github.com/pinecone-io/go-pinecone/internal/gen/db_control" + db_data "github.com/pinecone-io/go-pinecone/internal/gen/db_data/rest" "github.com/pinecone-io/go-pinecone/internal/gen/inference" "github.com/pinecone-io/go-pinecone/internal/provider" "github.com/pinecone-io/go-pinecone/internal/useragent" @@ -71,10 +72,11 @@ import ( // [docs.pinecone.io/reference/api]: https://docs.pinecone.io/reference/api/control-plane/list_indexes // [Inference API]: https://docs.pinecone.io/reference/api/2024-07/inference/generate-embeddings type Client struct { - Inference *InferenceService - headers map[string]string + Inference *InferenceService + // headers map[string]string restClient *db_control.Client - sourceTag string + // sourceTag string + baseParams *NewClientBaseParams } // NewClientParams holds the parameters for creating a new Client instance while authenticating via an API key. @@ -210,8 +212,8 @@ func NewClient(in NewClientParams) (*Client, error) { // fmt.Println("Successfully created a new Client object!") // } func NewClientBase(in NewClientBaseParams) (*Client, error) { - clientOptions := buildClientBaseOptions(in) - inference_client_options := buildInferenceBaseOptions(in) + controlOptions := buildClientBaseOptions(in) + inferenceOptions := buildInferenceBaseOptions(in) var err error controlHostOverride := valueOrFallback(in.Host, os.Getenv("PINECONE_CONTROLLER_HOST")) @@ -222,16 +224,22 @@ func NewClientBase(in NewClientBaseParams) (*Client, error) { } } - db_control_client, err := db_control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), clientOptions...) + dbControlClient, err := db_control.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), controlOptions...) if err != nil { return nil, err } - inference_client, err := inference.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), inference_client_options...) + inferenceClient, err := inference.NewClient(valueOrFallback(controlHostOverride, "https://api.pinecone.io"), inferenceOptions...) if err != nil { return nil, err } - c := Client{Inference: &InferenceService{client: inference_client}, restClient: db_control_client, sourceTag: in.SourceTag, headers: in.Headers} + c := Client{ + Inference: &InferenceService{client: inferenceClient}, + restClient: dbControlClient, + // sourceTag: in.SourceTag, + // headers: in.Headers, + baseParams: &in, + } return &c, nil } @@ -304,11 +312,18 @@ func (c *Client) Index(in NewIndexConnParams, dialOpts ...grpc.DialOption) (*Ind in.AdditionalMetadata[key] = value } + dbDataOptions := buildDataClientBaseOptions(*c.baseParams) + dbDataClient, err := db_data.NewClient(ensureHostHasHttps(in.Host), dbDataOptions...) + if err != nil { + return nil, err + } + idx, err := newIndexConnection(newIndexParameters{ host: in.Host, namespace: in.Namespace, - sourceTag: c.sourceTag, + sourceTag: c.baseParams.SourceTag, additionalMetadata: in.AdditionalMetadata, + dbDataClient: dbDataClient, }, dialOpts...) if err != nil { return nil, err @@ -316,6 +331,16 @@ func (c *Client) Index(in NewIndexConnParams, dialOpts ...grpc.DialOption) (*Ind return idx, nil } +func ensureHostHasHttps(host string) string { + if strings.HasPrefix("http://", host) { + return strings.Replace(host, "http://", "https://", 1) + } else if !strings.HasPrefix("https://", host) { + return "https://" + host + } + + return host +} + // ListIndexes retrieves a list of all Indexes in a Pinecone [project]. // // Parameters: @@ -1332,7 +1357,7 @@ func (c *Client) extractAuthHeader() map[string]string { "access_token", } - for key, value := range c.headers { + for key, value := range c.baseParams.Headers { for _, checkKey := range possibleAuthKeys { if strings.ToLower(key) == checkKey { return map[string]string{key: value} @@ -1525,6 +1550,22 @@ func buildInferenceBaseOptions(in NewClientBaseParams) []inference.ClientOption return clientOptions } +func buildDataClientBaseOptions(in NewClientBaseParams) []db_data.ClientOption { + clientOptions := []db_data.ClientOption{} + headerProviders := buildSharedProviderHeaders(in) + + for _, provider := range headerProviders { + clientOptions = append(clientOptions, db_data.WithRequestEditorFn(provider.Intercept)) + } + + // apply custom http client if provided + if in.RestClient != nil { + clientOptions = append(clientOptions, db_data.WithHTTPClient(in.RestClient)) + } + + return clientOptions +} + func buildSharedProviderHeaders(in NewClientBaseParams) []*provider.CustomHeader { providers := []*provider.CustomHeader{} diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 17200e1..7771849 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -388,9 +388,9 @@ func TestNewClientParamsSetUnit(t *testing.T) { client, err := NewClient(NewClientParams{ApiKey: apiKey}) require.NoError(t, err) - require.Empty(t, client.sourceTag, "Expected client to have empty sourceTag") - require.NotNil(t, client.headers, "Expected client headers to not be nil") - apiKeyHeader, ok := client.headers["Api-Key"] + require.Empty(t, client.baseParams.SourceTag, "Expected client to have empty sourceTag") + require.NotNil(t, client.baseParams.Headers, "Expected client headers to not be nil") + apiKeyHeader, ok := client.baseParams.Headers["Api-Key"] require.True(t, ok, "Expected client to have an 'Api-Key' header") require.Equal(t, apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") require.Equal(t, 3, len(client.restClient.RequestEditors), "Expected client to have correct number of request editors") @@ -405,10 +405,10 @@ func TestNewClientParamsSetSourceTagUnit(t *testing.T) { }) require.NoError(t, err) - apiKeyHeader, ok := client.headers["Api-Key"] + apiKeyHeader, ok := client.baseParams.Headers["Api-Key"] require.True(t, ok, "Expected client to have an 'Api-Key' header") require.Equal(t, apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") - require.Equal(t, sourceTag, client.sourceTag, "Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag) + require.Equal(t, sourceTag, client.baseParams.SourceTag, "Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.baseParams.SourceTag) require.Equal(t, 3, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 2, len(client.restClient.RequestEditors)) } @@ -418,10 +418,10 @@ func TestNewClientParamsSetHeadersUnit(t *testing.T) { client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers}) require.NoError(t, err) - apiKeyHeader, ok := client.headers["Api-Key"] + apiKeyHeader, ok := client.baseParams.Headers["Api-Key"] require.True(t, ok, "Expected client to have an 'Api-Key' header") require.Equal(t, apiKey, apiKeyHeader, "Expected 'Api-Key' header to match provided ApiKey") - require.Equal(t, client.headers, headers, "Expected client to have headers '%+v', but got '%+v'", headers, client.headers) + require.Equal(t, client.baseParams.Headers, headers, "Expected client to have headers '%+v', but got '%+v'", headers, client.baseParams.Headers) require.Equal(t, 4, len(client.restClient.RequestEditors), "Expected client to have %s request editors, but got %s", 3, len(client.restClient.RequestEditors)) } @@ -1072,7 +1072,7 @@ func TestNewClientUnit(t *testing.T) { } else { assert.NoError(t, err) assert.NotNil(t, client) - assert.Equal(t, tc.expectedHeaders, client.headers, "Expected headers to be '%v', but got '%v'", tc.expectedHeaders, client.headers) + assert.Equal(t, tc.expectedHeaders, client.baseParams.Headers, "Expected headers to be '%v', but got '%v'", tc.expectedHeaders, client.baseParams.Headers) } }) } diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index 23b3f1f..843ecf6 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -8,7 +8,8 @@ import ( "net/url" "strings" - db_data "github.com/pinecone-io/go-pinecone/internal/gen/db_data/grpc" + dbDataGrpc "github.com/pinecone-io/go-pinecone/internal/gen/db_data/grpc" + dbDataRest "github.com/pinecone-io/go-pinecone/internal/gen/db_data/rest" "github.com/pinecone-io/go-pinecone/internal/useragent" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -26,7 +27,7 @@ import ( type IndexConnection struct { Namespace string additionalMetadata map[string]string - dataClient *db_data.VectorServiceClient + grpcClient *dbDataGrpc.VectorServiceClient grpcConn *grpc.ClientConn } @@ -35,6 +36,7 @@ type newIndexParameters struct { namespace string sourceTag string additionalMetadata map[string]string + dbDataClient *dbDataRest.Client } func newIndexConnection(in newIndexParameters, dialOpts ...grpc.DialOption) (*IndexConnection, error) { @@ -65,11 +67,11 @@ func newIndexConnection(in newIndexParameters, dialOpts ...grpc.DialOption) (*In return nil, err } - dataClient := db_data.NewVectorServiceClient(conn) + dataClient := dbDataGrpc.NewVectorServiceClient(conn) idx := IndexConnection{ Namespace: in.namespace, - dataClient: &dataClient, + grpcClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata, } @@ -185,17 +187,17 @@ func (idx *IndexConnection) Close() error { // log.Fatalf("Successfully upserted %d vector(s)!\n", count) // } func (idx *IndexConnection) UpsertVectors(ctx context.Context, in []*Vector) (uint32, error) { - vectors := make([]*db_data.Vector, len(in)) + vectors := make([]*dbDataGrpc.Vector, len(in)) for i, v := range in { vectors[i] = vecToGrpc(v) } - req := &db_data.UpsertRequest{ + req := &dbDataGrpc.UpsertRequest{ Vectors: vectors, Namespace: idx.Namespace, } - res, err := (*idx.dataClient).Upsert(idx.akCtx(ctx), req) + res, err := (*idx.grpcClient).Upsert(idx.akCtx(ctx), req) if err != nil { return 0, err } @@ -263,12 +265,12 @@ type FetchVectorsResponse struct { // fmt.Println("No vectors found") // } func (idx *IndexConnection) FetchVectors(ctx context.Context, ids []string) (*FetchVectorsResponse, error) { - req := &db_data.FetchRequest{ + req := &dbDataGrpc.FetchRequest{ Ids: ids, Namespace: idx.Namespace, } - res, err := (*idx.dataClient).Fetch(idx.akCtx(ctx), req) + res, err := (*idx.grpcClient).Fetch(idx.akCtx(ctx), req) if err != nil { return nil, err } @@ -371,13 +373,13 @@ type ListVectorsResponse struct { // fmt.Printf("Found %d vector(s)\n", len(res.VectorIds)) // } func (idx *IndexConnection) ListVectors(ctx context.Context, in *ListVectorsRequest) (*ListVectorsResponse, error) { - req := &db_data.ListRequest{ + req := &dbDataGrpc.ListRequest{ Prefix: in.Prefix, Limit: in.Limit, PaginationToken: in.PaginationToken, Namespace: idx.Namespace, } - res, err := (*idx.dataClient).List(idx.akCtx(ctx), req) + res, err := (*idx.grpcClient).List(idx.akCtx(ctx), req) if err != nil { return nil, err } @@ -501,7 +503,7 @@ type QueryVectorsResponse struct { // } // } func (idx *IndexConnection) QueryByVectorValues(ctx context.Context, in *QueryByVectorValuesRequest) (*QueryVectorsResponse, error) { - req := &db_data.QueryRequest{ + req := &dbDataGrpc.QueryRequest{ Namespace: idx.Namespace, TopK: in.TopK, Filter: in.MetadataFilter, @@ -591,7 +593,7 @@ type QueryByVectorIdRequest struct { // } // } func (idx *IndexConnection) QueryByVectorId(ctx context.Context, in *QueryByVectorIdRequest) (*QueryVectorsResponse, error) { - req := &db_data.QueryRequest{ + req := &dbDataGrpc.QueryRequest{ Id: in.VectorId, Namespace: idx.Namespace, TopK: in.TopK, @@ -652,7 +654,7 @@ func (idx *IndexConnection) QueryByVectorId(ctx context.Context, in *QueryByVect // log.Fatalf("Failed to delete vector with ID: %s. Error: %s\n", vectorId, err) // } func (idx *IndexConnection) DeleteVectorsById(ctx context.Context, ids []string) error { - req := db_data.DeleteRequest{ + req := dbDataGrpc.DeleteRequest{ Ids: ids, Namespace: idx.Namespace, } @@ -716,7 +718,7 @@ func (idx *IndexConnection) DeleteVectorsById(ctx context.Context, ids []string) // log.Fatalf("Failed to delete vector(s) with filter: %+v. Error: %s\n", filter, err) // } func (idx *IndexConnection) DeleteVectorsByFilter(ctx context.Context, metadataFilter *MetadataFilter) error { - req := db_data.DeleteRequest{ + req := dbDataGrpc.DeleteRequest{ Filter: metadataFilter, Namespace: idx.Namespace, } @@ -768,7 +770,7 @@ func (idx *IndexConnection) DeleteVectorsByFilter(ctx context.Context, metadataF // log.Fatalf("Failed to delete vectors in namespace: \"%s\". Error: %s", idxConnection.Namespace, err) // } func (idx *IndexConnection) DeleteAllVectorsInNamespace(ctx context.Context) error { - req := db_data.DeleteRequest{ + req := dbDataGrpc.DeleteRequest{ Namespace: idx.Namespace, DeleteAll: true, } @@ -842,7 +844,7 @@ func (idx *IndexConnection) UpdateVector(ctx context.Context, in *UpdateVectorRe return fmt.Errorf("a vector ID plus at least one of Values, SparseValues, or Metadata must be provided to update a vector") } - req := &db_data.UpdateRequest{ + req := &dbDataGrpc.UpdateRequest{ Id: in.Id, Values: in.Values, SparseValues: sparseValToGrpc(in.SparseValues), @@ -850,7 +852,7 @@ func (idx *IndexConnection) UpdateVector(ctx context.Context, in *UpdateVectorRe Namespace: idx.Namespace, } - _, err := (*idx.dataClient).Update(idx.akCtx(ctx), req) + _, err := (*idx.grpcClient).Update(idx.akCtx(ctx), req) return err } @@ -973,10 +975,10 @@ func (idx *IndexConnection) DescribeIndexStats(ctx context.Context) (*DescribeIn // } // } func (idx *IndexConnection) DescribeIndexStatsFiltered(ctx context.Context, metadataFilter *MetadataFilter) (*DescribeIndexStatsResponse, error) { - req := &db_data.DescribeIndexStatsRequest{ + req := &dbDataGrpc.DescribeIndexStatsRequest{ Filter: metadataFilter, } - res, err := (*idx.dataClient).DescribeIndexStats(idx.akCtx(ctx), req) + res, err := (*idx.grpcClient).DescribeIndexStats(idx.akCtx(ctx), req) if err != nil { return nil, err } @@ -996,8 +998,8 @@ func (idx *IndexConnection) DescribeIndexStatsFiltered(ctx context.Context, meta }, nil } -func (idx *IndexConnection) query(ctx context.Context, req *db_data.QueryRequest) (*QueryVectorsResponse, error) { - res, err := (*idx.dataClient).Query(idx.akCtx(ctx), req) +func (idx *IndexConnection) query(ctx context.Context, req *dbDataGrpc.QueryRequest) (*QueryVectorsResponse, error) { + res, err := (*idx.grpcClient).Query(idx.akCtx(ctx), req) if err != nil { return nil, err } @@ -1014,8 +1016,8 @@ func (idx *IndexConnection) query(ctx context.Context, req *db_data.QueryRequest }, nil } -func (idx *IndexConnection) delete(ctx context.Context, req *db_data.DeleteRequest) error { - _, err := (*idx.dataClient).Delete(idx.akCtx(ctx), req) +func (idx *IndexConnection) delete(ctx context.Context, req *dbDataGrpc.DeleteRequest) error { + _, err := (*idx.grpcClient).Delete(idx.akCtx(ctx), req) return err } @@ -1029,7 +1031,7 @@ func (idx *IndexConnection) akCtx(ctx context.Context) context.Context { return metadata.AppendToOutgoingContext(ctx, newMetadata...) } -func toVector(vector *db_data.Vector) *Vector { +func toVector(vector *dbDataGrpc.Vector) *Vector { if vector == nil { return nil } @@ -1041,11 +1043,11 @@ func toVector(vector *db_data.Vector) *Vector { } } -func toScoredVector(sv *db_data.ScoredVector) *ScoredVector { +func toScoredVector(sv *dbDataGrpc.ScoredVector) *ScoredVector { if sv == nil { return nil } - v := toVector(&db_data.Vector{ + v := toVector(&dbDataGrpc.Vector{ Id: sv.Id, Values: sv.Values, SparseValues: sv.SparseValues, @@ -1057,7 +1059,7 @@ func toScoredVector(sv *db_data.ScoredVector) *ScoredVector { } } -func toSparseValues(sv *db_data.SparseValues) *SparseValues { +func toSparseValues(sv *dbDataGrpc.SparseValues) *SparseValues { if sv == nil { return nil } @@ -1067,7 +1069,7 @@ func toSparseValues(sv *db_data.SparseValues) *SparseValues { } } -func toUsage(u *db_data.Usage) *Usage { +func toUsage(u *dbDataGrpc.Usage) *Usage { if u == nil { return nil } @@ -1076,18 +1078,18 @@ func toUsage(u *db_data.Usage) *Usage { } } -func toPaginationToken(p *db_data.Pagination) *string { +func toPaginationToken(p *dbDataGrpc.Pagination) *string { if p == nil { return nil } return &p.Next } -func vecToGrpc(v *Vector) *db_data.Vector { +func vecToGrpc(v *Vector) *dbDataGrpc.Vector { if v == nil { return nil } - return &db_data.Vector{ + return &dbDataGrpc.Vector{ Id: v.Id, Values: v.Values, Metadata: v.Metadata, @@ -1095,11 +1097,11 @@ func vecToGrpc(v *Vector) *db_data.Vector { } } -func sparseValToGrpc(sv *SparseValues) *db_data.SparseValues { +func sparseValToGrpc(sv *SparseValues) *dbDataGrpc.SparseValues { if sv == nil { return nil } - return &db_data.SparseValues{ + return &dbDataGrpc.SparseValues{ Indices: sv.Indices, Values: sv.Values, } diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index c1951ad..7da6b99 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -177,7 +177,7 @@ func (ts *IntegrationTests) TestMetadataAppliedToRequests() { require.True(ts.T(), ok, "Expected client to have an 'api-key' header") require.Equal(ts.T(), apiKey, apiKeyHeader, "Expected 'api-key' header to equal %s", apiKey) require.Equal(ts.T(), namespace, idxConn.Namespace, "Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace) - require.NotNil(ts.T(), idxConn.dataClient, "Expected idxConn to have non-nil dataClient") + require.NotNil(ts.T(), idxConn.grpcClient, "Expected idxConn to have non-nil dataClient") require.NotNil(ts.T(), idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") // initiate request to trigger the MetadataInterceptor @@ -299,7 +299,7 @@ func TestNewIndexConnection(t *testing.T) { require.True(t, ok, "Expected client to have an 'api-key' header") require.Equal(t, apiKey, apiKeyHeader, "Expected 'api-key' header to equal %s", apiKey) require.Empty(t, idxConn.Namespace, "Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace) - require.NotNil(t, idxConn.dataClient, "Expected idxConn to have non-nil dataClient") + require.NotNil(t, idxConn.grpcClient, "Expected idxConn to have non-nil dataClient") require.NotNil(t, idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") } @@ -320,7 +320,7 @@ func TestNewIndexConnectionNamespace(t *testing.T) { require.True(t, ok, "Expected client to have an 'api-key' header") require.Equal(t, apiKey, apiKeyHeader, "Expected 'api-key' header to equal %s", apiKey) require.Equal(t, namespace, idxConn.Namespace, "Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace) - require.NotNil(t, idxConn.dataClient, "Expected idxConn to have non-nil dataClient") + require.NotNil(t, idxConn.grpcClient, "Expected idxConn to have non-nil dataClient") require.NotNil(t, idxConn.grpcConn, "Expected idxConn to have non-nil grpcConn") }