From 1153dff5ab930860394f1dcce023acf9ddb0286c Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Fri, 18 Oct 2024 18:56:06 -0400 Subject: [PATCH] Rebase & Implement Rerank (#80) ## Problem The `InferenceService` and `Embed` operation were implemented in a previous PR: https://github.com/pinecone-io/go-pinecone/pull/67 As a part of the `2024-10` release we are also adding `Rerank` to the client, which will live under `InferenceService` with `Embed`. ## Solution This work builds on top of external contributions from @Stosan (https://github.com/pinecone-io/go-pinecone/pull/73) Thank you for your diligence in getting this started! - Update `InferenceService` to expose the new `Rerank` method. - Add various request and response types to represent the interface. - Add doc comments and an integration test for calling the reranker. ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [X] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here) ## Test Plan `just test` If you want to use the code yourself, here's an example: ```go ctx := context.Background() clientParams := pinecone.NewClientParams{ ApiKey: "YOUR_API_KEY", } pc, err := pinecone.NewClient(clientParams) if err != nil { log.Fatalf("Failed to create Client: %v", err) } rerankModel := "bge-reranker-v2-m3" topN := 2 retunDocuments := true documents := []pinecone.Document{ { "id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."}, { "id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."}, { "id": "vec3", "text": "Apple Inc.
has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, { "id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."}, } ranking, err := pc.Inference.Rerank(ctx, &pinecone.RerankRequest{ Model: rerankModel, Query: "i love cats", ReturnDocuments: &retunDocuments, TopN: &topN, RankFields: &[]string{"text"}, Documents: documents, }) if err != nil { log.Fatalf("Failed to rerank: %v", err) } fmt.Printf("Rerank result: %+v\n", ranking) ``` --------- Co-authored-by: Sam Ayo --- codegen/apis | 2 +- pinecone/client.go | 136 ++++++++++++++++++++++++++++++++- pinecone/client_test.go | 165 +++++++++++++++++++++++++++++++++++++++- pinecone/models.go | 7 ++ 4 files changed, 307 insertions(+), 3 deletions(-) diff --git a/codegen/apis b/codegen/apis index 3002f1e..9d3a41f 160000 --- a/codegen/apis +++ b/codegen/apis @@ -1 +1 @@ -Subproject commit 3002f1e62b895d2e64fba53e346ad84eb4719934 +Subproject commit 9d3a41f6ae657d9b0b818065dcdc3adf76bdb7fe diff --git a/pinecone/client.go b/pinecone/client.go index 01e1c06..89f8865 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -1261,7 +1261,7 @@ type InferenceService struct { // Parameters: // - ctx: A context.Context object controls the request's lifetime, allowing for the request // to be canceled or to timeout according to the context's deadline. -// - in: A pointer to an EmbedRequest object that contains the model t4o use for embedding generation, the +// - in: A pointer to an EmbedRequest object that contains the model to use for embedding generation, the // list of input strings to generate embeddings for, and any additional parameters to use for generation. // // Returns a pointer to an EmbeddingsList object or an error. @@ -1343,6 +1343,130 @@ func (i *InferenceService) Embed(ctx context.Context, in *EmbedRequest) (*infere return decodeEmbeddingsList(res.Body) } +// Document is a map representing the document to be reranked. 
+type Document map[string]string + +// RerankRequest holds the parameters for calling [InferenceService.Rerank] and reranking documents +// by a specified query and model. +// +// Fields: +// - Model: (Required) The [model] to use for reranking. +// - Query: (Required) The query to rerank Documents against. +// - Documents: (Required) A list of Document objects to be reranked. Documents are ranked by their "text" field by default; +// you can change this with [RerankRequest.RankFields]. +// - RankFields: (Optional) The fields to rank the Documents by. If not provided, the default is "text". +// - ReturnDocuments: (Optional) Whether to include Documents in the response. Defaults to true. +// - TopN: (Optional) How many Documents to return. Defaults to the length of input Documents. +// - Parameters: (Optional) Additional model-specific parameters for the reranker. +// +// [model]: https://docs.pinecone.io/guides/inference/understanding-inference#models +type RerankRequest struct { + Model string + Query string + Documents []Document + RankFields *[]string + ReturnDocuments *bool + TopN *int + Parameters *map[string]string +} + +// RankedDocument represents a ranked document with a relevance score and an index position. +// +// Fields: +// - Document: The [Document]. +// - Index: The index position of the Document from the original request. This can be used +// to locate the position of the document relative to others described in the request. +// - Score: The relevance score of the Document indicating how closely it matches the query. +type RankedDocument struct { + Document *Document `json:"document,omitempty"` + Index int `json:"index"` + Score float32 `json:"score"` +} + +// RerankResponse is the result of a reranking operation. +// +// Fields: +// - Data: A list of Documents which have been reranked. The Documents are sorted in order of relevance, +// with the first being the most relevant. +// - Model: The model used to rerank Documents.
+// - Usage: Usage statistics ([Rerank Units]) for the reranking operation. +// +// [Rerank Units]: https://docs.pinecone.io/guides/organizations/manage-cost/understanding-cost#rerank +type RerankResponse struct { + Data []RankedDocument `json:"data,omitempty"` + Model string `json:"model"` + Usage RerankUsage `json:"usage"` +} + +// Rerank Documents with associated relevance scores that represent the relevance of each Document +// to the provided query using the specified model. +// +// Parameters: +// - ctx: A context.Context object controls the request's lifetime, allowing for the request +// to be canceled or to timeout according to the context's deadline. +// - in: A pointer to a [RerankRequest] object that contains the model, query, and documents to use for reranking. +// +// Example: +// +// ctx := context.Background() +// +// clientParams := pinecone.NewClientParams{ +// ApiKey: "YOUR_API_KEY", +// SourceTag: "your_source_identifier", // optional +// } +// +// pc, err := pinecone.NewClient(clientParams) +// if err != nil { +// log.Fatalf("Failed to create Client: %v", err) +// } +// +// rerankModel := "bge-reranker-v2-m3" +// topN := 2 +// retunDocuments := true +// documents := []pinecone.Document{ +// {"id": "doc1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."}, +// {"id": "doc2", "text": "Many people enjoy eating apples as a healthy snack."}, +// {"id": "doc3", "text": "Apple Inc.
has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, +// {"id": "doc4", "text": "An apple a day keeps the doctor away, as the saying goes."}, +// } +// +// ranking, err := pc.Inference.Rerank(ctx, &pinecone.RerankRequest{ +// Model: rerankModel, +// Query: "i love to eat apples", +// ReturnDocuments: &retunDocuments, +// TopN: &topN, +// RankFields: &[]string{"text"}, +// Documents: documents, +// }) +// if err != nil { +// log.Fatalf("Failed to rerank: %v", err) +// } +// fmt.Printf("Rerank result: %+v\n", ranking) +func (i *InferenceService) Rerank(ctx context.Context, in *RerankRequest) (*RerankResponse, error) { + convertedDocuments := make([]inference.Document, len(in.Documents)) + for i, doc := range in.Documents { + convertedDocuments[i] = inference.Document(doc) + } + req := inference.RerankJSONRequestBody{ + Model: in.Model, + Query: in.Query, + Documents: convertedDocuments, + RankFields: in.RankFields, + ReturnDocuments: in.ReturnDocuments, + TopN: in.TopN, + Parameters: in.Parameters, + } + res, err := i.client.Rerank(ctx, req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, handleErrorResponseBody(res, "failed to rerank: ") + } + return decodeRerankResponse(res.Body) +} + func (c *Client) extractAuthHeader() map[string]string { possibleAuthKeys := []string{ "api-key", @@ -1423,6 +1547,16 @@ func decodeEmbeddingsList(resBody io.ReadCloser) (*inference.EmbeddingsList, err return &embeddingsList, nil } +func decodeRerankResponse(resBody io.ReadCloser) (*RerankResponse, error) { + var rerankResponse RerankResponse + err := json.NewDecoder(resBody).Decode(&rerankResponse) + if err != nil { + return nil, fmt.Errorf("failed to decode rerank response: %w", err) + } + + return &rerankResponse, nil +} + func toCollection(cm *db_control.CollectionModel) *Collection { if cm == nil { return nil diff --git a/pinecone/client_test.go b/pinecone/client_test.go 
index 7771849..eb54fa2 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -13,11 +13,11 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/uuid" "github.com/pinecone-io/go-pinecone/internal/gen" "github.com/pinecone-io/go-pinecone/internal/gen/db_control" "github.com/pinecone-io/go-pinecone/internal/provider" - "github.com/google/uuid" "github.com/pinecone-io/go-pinecone/internal/utils" "github.com/stretchr/testify/assert" @@ -277,6 +277,11 @@ func (ts *IntegrationTests) TestConfigureIndexHitPodLimit() { } func (ts *IntegrationTests) TestGenerateEmbeddings() { + // Run Embed tests once rather than duplicating across serverless & pods + if ts.indexType == "pod" { + ts.T().Skip("Skipping Embed tests for pods") + } + ctx := context.Background() embeddingModel := "multilingual-e5-large" embeddings, err := ts.client.Inference.Embed(ctx, &EmbedRequest{ @@ -313,6 +318,164 @@ func (ts *IntegrationTests) TestGenerateEmbeddingsInvalidInputs() { require.Contains(ts.T(), err.Error(), "TextInputs must contain at least one value") } +func (ts *IntegrationTests) TestRerankDocumentDefaultField() { + // Run Rerank tests once rather than duplicating across serverless & pods + if ts.indexType == "pod" { + ts.T().Skip("Skipping Rerank tests for pods") + } + + ctx := context.Background() + rerankModel := "bge-reranker-v2-m3" + topN := 2 + retunDocuments := true + ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ + Model: rerankModel, + Query: "i love apples", + ReturnDocuments: &retunDocuments, + TopN: &topN, + Documents: []Document{ + {"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."}, + {"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."}, + {"id": "vec3", "text": "Apple Inc. 
has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, + {"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."}, + }}) + + require.NoError(ts.T(), err) + require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil") + require.Equal(ts.T(), topN, len(ranking.Data), "Expected %v rankings", topN) + + doc := *ranking.Data[0].Document + _, exists := doc["text"] + require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "text") + _, exists = doc["id"] + require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id") +} + +func (ts *IntegrationTests) TestRerankDocumentCustomField() { + // Run Rerank tests once rather than duplicating across serverless & pods + if ts.indexType == "pod" { + ts.T().Skip("Skipping Rerank tests for pods") + } + + ctx := context.Background() + rerankModel := "bge-reranker-v2-m3" + topN := 2 + retunDocuments := true + ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ + Model: rerankModel, + Query: "i love apples", + ReturnDocuments: &retunDocuments, + TopN: &topN, + RankFields: &[]string{"customField"}, + Documents: []Document{ + {"id": "vec1", "customField": "Apple is a popular fruit known for its sweetness and crisp texture."}, + {"id": "vec2", "customField": "Many people enjoy eating apples as a healthy snack."}, + {"id": "vec3", "customField": "Apple Inc. 
has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, + {"id": "vec4", "customField": "An apple a day keeps the doctor away, as the saying goes."}, + }}) + + require.NoError(ts.T(), err) + require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil") + require.Equal(ts.T(), topN, len(ranking.Data), "Expected %v rankings", topN) + + doc := *ranking.Data[0].Document + _, exists := doc["customField"] + require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "customField") + _, exists = doc["id"] + require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id") +} + +func (ts *IntegrationTests) TestRerankDocumentAllDefaults() { + // Run Rerank tests once rather than duplicating across serverless & pods + if ts.indexType == "pod" { + ts.T().Skip("Skipping Rerank tests for pods") + } + + ctx := context.Background() + rerankModel := "bge-reranker-v2-m3" + ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ + Model: rerankModel, + Query: "i love apples", + Documents: []Document{ + {"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."}, + {"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."}, + {"id": "vec3", "text": "Apple Inc. 
has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, + {"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."}, + }}) + + require.NoError(ts.T(), err) + require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil") + require.Equal(ts.T(), 4, len(ranking.Data), "Expected %v rankings", 4) + + doc := *ranking.Data[0].Document + _, exists := doc["text"] + require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "text") + _, exists = doc["id"] + require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id") +} + +func (ts *IntegrationTests) TestRerankDocumentsMultipleRankFields() { + // Run Rerank tests once rather than duplicating across serverless & pods + if ts.indexType == "pod" { + ts.T().Skip("Skipping Rerank tests for pods") + } + + ctx := context.Background() + rerankModel := "bge-reranker-v2-m3" + _, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ + Model: rerankModel, + Query: "i love apples", + RankFields: &[]string{"text", "custom-field"}, + Documents: []Document{ + { + "id": "vec1", + "text": "Apple is a popular fruit known for its sweetness and crisp texture.", + "custom-field": "another field", + }, + { + "id": "vec2", + "text": "Many people enjoy eating apples as a healthy snack.", + "custom-field": "another field", + }, + { + "id": "vec3", + "text": "Apple Inc. 
has revolutionized the tech industry with its sleek designs and user-friendly interfaces.", + "custom-field": "another field", + }, + { + "id": "vec4", + "text": "An apple a day keeps the doctor away, as the saying goes.", + "custom-field": "another field", + }, + }}) + + require.Error(ts.T(), err) + require.Contains(ts.T(), err.Error(), "Only one rank field is supported for model") +} + +func (ts *IntegrationTests) TestRerankDocumentFieldError() { + // Run Rerank tests once rather than duplicating across serverless & pods + if ts.indexType == "pod" { + ts.T().Skip("Skipping Rerank tests for pods") + } + + ctx := context.Background() + rerankModel := "bge-reranker-v2-m3" + _, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ + Model: rerankModel, + Query: "i love apples", + RankFields: &[]string{"custom-field"}, + Documents: []Document{ + {"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."}, + {"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."}, + {"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, + {"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."}, + }}) + + require.Error(ts.T(), err) + require.Contains(ts.T(), err.Error(), "field 'custom-field' not found in document") +} + // Unit tests: func TestExtractAuthHeaderUnit(t *testing.T) { globalApiKey := os.Getenv("PINECONE_API_KEY") diff --git a/pinecone/models.go b/pinecone/models.go index a29cf9d..43d71c6 100644 --- a/pinecone/models.go +++ b/pinecone/models.go @@ -158,6 +158,13 @@ type Usage struct { ReadUnits uint32 `json:"read_units"` } +// RerankUsage is the usage stats ([Rerank Units]) for a reranking request. 
+// +// [Rerank Units]: https://docs.pinecone.io/guides/organizations/manage-cost/understanding-cost#rerank +type RerankUsage struct { + RerankUnits *int `json:"rerank_units,omitempty"` +} + // MetadataFilter represents the [metadata filters] attached to a Pinecone request. // These optional metadata filters are applied to query and deletion requests. //