Rebase & Implement Rerank #80

Merged
merged 7 commits into from
Oct 18, 2024
Changes from 6 commits
2 changes: 1 addition & 1 deletion codegen/apis
Submodule apis updated from 3002f1 to 9d3a41
136 changes: 135 additions & 1 deletion pinecone/client.go
@@ -1261,7 +1261,7 @@ type InferenceService struct {
// Parameters:
// - ctx: A context.Context object controls the request's lifetime, allowing for the request
// to be canceled or to timeout according to the context's deadline.
// - in: A pointer to an EmbedRequest object that contains the model t4o use for embedding generation, the
// - in: A pointer to an EmbedRequest object that contains the model to use for embedding generation, the
Review comment (Contributor): good catch

// list of input strings to generate embeddings for, and any additional parameters to use for generation.
//
// Returns a pointer to an EmbeddingsList object or an error.
@@ -1343,6 +1343,130 @@ func (i *InferenceService) Embed(ctx context.Context, in *EmbedRequest) (*infere
return decodeEmbeddingsList(res.Body)
}

// Document is a map representing the document to be reranked.
type Document map[string]string

// RerankRequest holds the parameters for calling [InferenceService.Rerank] and reranking documents
// by a specified query and model.
//
// Fields:
// - Model: The [model] to use for reranking.
// - Query: (Required) The query to rerank Documents against.
// - Documents: (Required) A list of Document objects to be reranked. By default, Documents are ranked
// on their "text" field; use [RerankRequest.RankFields] to rank on a different field.
// - RankFields: (Optional) The fields to rank the Documents by. If not provided, the default is "text".
// - ReturnDocuments: (Optional) Whether to include Documents in the response. Defaults to true.
// - TopN: (Optional) How many Documents to return. Defaults to the length of input Documents.
// - Parameters: (Optional) Additional model-specific parameters for the reranker.
//
// [model]: https://docs.pinecone.io/guides/inference/understanding-inference#models
Review comment (Contributor): Nice

type RerankRequest struct {
Model string
Query string
Documents []Document
RankFields *[]string
ReturnDocuments *bool
TopN *int
Parameters *map[string]string
}
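
The optional fields of RerankRequest are pointers, so callers need an addressable value for each one they set. A minimal sketch of one way to set them inline, assuming Go 1.18+ generics. The `ptr` helper is hypothetical and not part of go-pinecone, and the types below are local copies of the ones in this diff so the snippet compiles on its own:

```go
package main

import "fmt"

// Document and RerankRequest are local copies of the types in this diff,
// declared here only so the sketch is self-contained.
type Document map[string]string

type RerankRequest struct {
	Model           string
	Query           string
	Documents       []Document
	RankFields      *[]string
	ReturnDocuments *bool
	TopN            *int
	Parameters      *map[string]string
}

// ptr is a hypothetical helper (not part of go-pinecone). It returns a
// pointer to a copy of v so optional fields can be populated inline.
func ptr[T any](v T) *T {
	return &v
}

func main() {
	req := RerankRequest{
		Model:           "bge-reranker-v2-m3",
		Query:           "i love apples",
		Documents:       []Document{{"id": "vec1", "text": "Many people enjoy eating apples as a healthy snack."}},
		ReturnDocuments: ptr(true),
		TopN:            ptr(2),
		// RankFields and Parameters are left nil, so they stay unset.
	}
	fmt.Println(*req.TopN, *req.ReturnDocuments, req.RankFields == nil)
}
```

Leaving a pointer field nil is what lets the request distinguish "not set" from an explicit zero value such as `false` or `0`.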

// RankedDocument represents a ranked document with a relevance score and an index position.
//
// Fields:
// - Document: The [Document].
// - Index: The index position of the Document from the original request. This can be used
// to locate the position of the document relative to others described in the request.
// - Score: The relevance score of the Document indicating how closely it matches the query.
type RankedDocument struct {
Review comment (Contributor, @aulorbe, Oct 17, 2024):

I'd specify in the docstring here (or somewhere) that Docs with 2+ fields in them will only be ranked on 1 field: either text, if they do not pass in a custom field to rank on, or their customField, which has to be present in the Doc objs.

Here's how I did it in TS: https://github.com/pinecone-io/pinecone-ts-client/pull/303/files#diff-77967717b18045071d22a72f646f1c4f3cbddc394751cda28af4dfa8732cbd7fR114 (the backend throws the error, so we just have to catch it)

Document *Document `json:"document,omitempty"`
Index int `json:"index"`
Score float32 `json:"score"`
}
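
The review comment above notes that a Document is ranked on a single field and that the backend rejects invalid requests. A minimal sketch of what an equivalent client-side pre-check could look like. `validateRankFields` is hypothetical and not part of this PR, its error strings are illustrative rather than the backend's, and the single-field limit is assumed from the behavior the integration tests exercise for bge-reranker-v2-m3:

```go
package main

import (
	"errors"
	"fmt"
)

// Document is a local copy of the type in this diff.
type Document map[string]string

// validateRankFields is a hypothetical pre-check (not part of go-pinecone).
// It rejects more than one rank field and verifies that the chosen field,
// or "text" by default, is present in every document.
func validateRankFields(docs []Document, rankFields *[]string) error {
	field := "text" // default when RankFields is not provided
	if rankFields != nil && len(*rankFields) > 0 {
		if len(*rankFields) > 1 {
			return errors.New("only one rank field is supported")
		}
		field = (*rankFields)[0]
	}
	for idx, doc := range docs {
		if _, ok := doc[field]; !ok {
			return fmt.Errorf("field %q not found in document at index %d", field, idx)
		}
	}
	return nil
}

func main() {
	docs := []Document{
		{"id": "vec1", "text": "Apple is a popular fruit."},
		{"id": "vec2", "text": "Many people enjoy eating apples."},
	}
	fmt.Println(validateRankFields(docs, nil))
	fmt.Println(validateRankFields(docs, &[]string{"custom-field"}))
	fmt.Println(validateRankFields(docs, &[]string{"text", "custom-field"}))
}
```

The PR itself leaves this validation to the backend, which keeps the client free of model-specific rules; a local check like this only saves a round trip.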

// RerankResponse is the result of a reranking operation.
//
// Fields:
// - Data: A list of Documents which have been reranked. The Documents are sorted in order of relevance,
// with the first being the most relevant.
// - Model: The model used to rerank Documents.
// - Usage: Usage statistics ([Rerank Units]) for the reranking operation.
//
// [Rerank Units]: https://docs.pinecone.io/guides/organizations/manage-cost/understanding-cost#rerank
type RerankResponse struct {
Data []RankedDocument `json:"data,omitempty"`
Model string `json:"model"`
Usage RerankUsage `json:"usage"`
}

// Rerank ranks Documents by their relevance to the provided query using the specified model,
// and returns each Document with an associated relevance score.
//
// Parameters:
// - ctx: A context.Context object controls the request's lifetime, allowing for the request
// to be canceled or to timeout according to the context's deadline.
// - in: A pointer to a [RerankRequest] object that contains the model, query, and documents to use for reranking.
//
// Example:
//
// ctx := context.Background()
//
// clientParams := pinecone.NewClientParams{
// ApiKey: "YOUR_API_KEY",
// SourceTag: "your_source_identifier", // optional
// }
//
// pc, err := pinecone.NewClient(clientParams)
// if err != nil {
// log.Fatalf("Failed to create Client: %v", err)
// }
//
// rerankModel := "bge-reranker-v2-m3"
// topN := 2
// returnDocuments := true
// documents := []pinecone.Document{
// {"id": "doc1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
// {"id": "doc2", "text": "Many people enjoy eating apples as a healthy snack."},
// {"id": "doc3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
// {"id": "doc4", "text": "An apple a day keeps the doctor away, as the saying goes."},
// }
//
// ranking, err := pc.Inference.Rerank(ctx, &pinecone.RerankRequest{
// Model: rerankModel,
// Query: "i love to eat apples",
// ReturnDocuments: &returnDocuments,
// TopN: &topN,
// RankFields: &[]string{"text"},
// Documents: documents,
// })
// if err != nil {
// log.Fatalf("Failed to rerank: %v", err)
// }
// fmt.Printf("Rerank result: %+v\n", ranking)
func (i *InferenceService) Rerank(ctx context.Context, in *RerankRequest) (*RerankResponse, error) {
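// Convert to the generated Document type. The loop index is named idx so it
// does not shadow the method receiver i.
convertedDocuments := make([]inference.Document, len(in.Documents))
for idx, doc := range in.Documents {
convertedDocuments[idx] = inference.Document(doc)
}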
req := inference.RerankJSONRequestBody{
Model: in.Model,
Query: in.Query,
Documents: convertedDocuments,
RankFields: in.RankFields,
ReturnDocuments: in.ReturnDocuments,
TopN: in.TopN,
Parameters: in.Parameters,
}
res, err := i.client.Rerank(ctx, req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, handleErrorResponseBody(res, "failed to rerank: ")
}
return decodeRerankResponse(res.Body)
}
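
Each RankedDocument carries the Index of the Document in the original request, so results can be mapped back to the caller's inputs even when the Document itself is not returned. A minimal sketch of that lookup, using local copies of the types in this diff. `originalFor` is a hypothetical helper, and the scores below are made up for illustration:

```go
package main

import "fmt"

// Local copies of the types in this diff, without JSON tags.
type Document map[string]string

type RankedDocument struct {
	Document *Document
	Index    int
	Score    float32
}

// originalFor is a hypothetical helper (not part of go-pinecone). It returns
// the input document a ranked result refers to, relying on Index rather than
// the optional Document field.
func originalFor(inputs []Document, r RankedDocument) (Document, bool) {
	if r.Index < 0 || r.Index >= len(inputs) {
		return nil, false
	}
	return inputs[r.Index], true
}

func main() {
	inputs := []Document{
		{"id": "vec1", "text": "Apple is a popular fruit."},
		{"id": "vec2", "text": "Many people enjoy eating apples."},
	}
	// Illustrative results only; real scores come from the reranking model.
	ranked := []RankedDocument{
		{Index: 1, Score: 0.8},
		{Index: 0, Score: 0.3},
	}
	for rank, r := range ranked {
		if doc, ok := originalFor(inputs, r); ok {
			fmt.Printf("rank %d: id=%s score=%.2f\n", rank+1, doc["id"], r.Score)
		}
	}
}
```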

func (c *Client) extractAuthHeader() map[string]string {
possibleAuthKeys := []string{
"api-key",
@@ -1423,6 +1547,16 @@ func decodeEmbeddingsList(resBody io.ReadCloser) (*inference.EmbeddingsList, err
return &embeddingsList, nil
}

func decodeRerankResponse(resBody io.ReadCloser) (*RerankResponse, error) {
var rerankResponse RerankResponse
err := json.NewDecoder(resBody).Decode(&rerankResponse)
if err != nil {
return nil, fmt.Errorf("failed to decode rerank response: %w", err)
}

return &rerankResponse, nil
}
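
Because RankedDocument.Document and RerankUsage.RerankUnits are pointers tagged `omitempty`, both decode to nil when the corresponding key is absent from the response body. A minimal sketch of decoding such a payload and guarding against nil before dereferencing. The JSON below is hand-written for illustration, not captured from the API, and it is an assumption that the service omits `document` when ReturnDocuments is false:

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"
	"strings"
)

// Local copies of the response types in this diff.
type Document map[string]string

type RankedDocument struct {
	Document *Document `json:"document,omitempty"`
	Index    int       `json:"index"`
	Score    float32   `json:"score"`
}

type RerankUsage struct {
	RerankUnits *int `json:"rerank_units,omitempty"`
}

type RerankResponse struct {
	Data  []RankedDocument `json:"data,omitempty"`
	Model string           `json:"model"`
	Usage RerankUsage      `json:"usage"`
}

// decodeRerank follows the same decoding approach as decodeRerankResponse,
// but reads from a string so the sketch needs no HTTP response.
func decodeRerank(body string) (*RerankResponse, error) {
	var resp RerankResponse
	if err := json.NewDecoder(strings.NewReader(body)).Decode(&resp); err != nil {
		return nil, fmt.Errorf("failed to decode rerank response: %w", err)
	}
	return &resp, nil
}

func main() {
	// Hand-written payload: the first entry has no "document" key.
	body := `{"model":"bge-reranker-v2-m3","data":[{"index":1,"score":0.5},{"index":0,"score":0.25,"document":{"id":"vec1","text":"Apple is a popular fruit."}}],"usage":{"rerank_units":1}}`
	resp, err := decodeRerank(body)
	if err != nil {
		log.Fatal(err)
	}
	for _, r := range resp.Data {
		if r.Document == nil {
			fmt.Printf("index=%d score=%.2f (no document returned)\n", r.Index, r.Score)
			continue
		}
		fmt.Printf("index=%d score=%.2f id=%s\n", r.Index, r.Score, (*r.Document)["id"])
	}
	if resp.Usage.RerankUnits != nil {
		fmt.Println("rerank units:", *resp.Usage.RerankUnits)
	}
}
```

Code that dereferences `*ranking.Data[0].Document` directly, as the integration tests in this PR do, is safe only when documents are returned.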

func toCollection(cm *db_control.CollectionModel) *Collection {
if cm == nil {
return nil
165 changes: 164 additions & 1 deletion pinecone/client_test.go
@@ -13,11 +13,11 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/pinecone-io/go-pinecone/internal/gen"
"github.com/pinecone-io/go-pinecone/internal/gen/db_control"
"github.com/pinecone-io/go-pinecone/internal/provider"

"github.com/google/uuid"
"github.com/pinecone-io/go-pinecone/internal/utils"

"github.com/stretchr/testify/assert"
@@ -277,6 +277,11 @@ func (ts *IntegrationTests) TestConfigureIndexHitPodLimit() {
}

func (ts *IntegrationTests) TestGenerateEmbeddings() {
// Run Embed tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Embed tests for pods")
}

ctx := context.Background()
embeddingModel := "multilingual-e5-large"
embeddings, err := ts.client.Inference.Embed(ctx, &EmbedRequest{
@@ -313,6 +318,164 @@ func (ts *IntegrationTests) TestGenerateEmbeddingsInvalidInputs() {
require.Contains(ts.T(), err.Error(), "TextInputs must contain at least one value")
}

func (ts *IntegrationTests) TestRerankDocumentDefaultField() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
topN := 2
returnDocuments := true
ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
ReturnDocuments: &returnDocuments,
TopN: &topN,
Documents: []Document{
Review comment (Contributor): Note: this is the situation in which the doc will only be ranked on text, in case the user isn't clear on that re: my comment here

{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
{"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."},
{"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
{"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."},
}})

require.NoError(ts.T(), err)
require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil")
require.Equal(ts.T(), topN, len(ranking.Data), "Expected %v rankings", topN)

doc := *ranking.Data[0].Document
_, exists := doc["text"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "text")
_, exists = doc["id"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id")
}

func (ts *IntegrationTests) TestRerankDocumentCustomField() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
topN := 2
returnDocuments := true
ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
ReturnDocuments: &returnDocuments,
TopN: &topN,
RankFields: &[]string{"customField"},
Documents: []Document{
{"id": "vec1", "customField": "Apple is a popular fruit known for its sweetness and crisp texture."},
{"id": "vec2", "customField": "Many people enjoy eating apples as a healthy snack."},
{"id": "vec3", "customField": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
{"id": "vec4", "customField": "An apple a day keeps the doctor away, as the saying goes."},
}})

require.NoError(ts.T(), err)
require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil")
require.Equal(ts.T(), topN, len(ranking.Data), "Expected %v rankings", topN)

doc := *ranking.Data[0].Document
_, exists := doc["customField"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "customField")
_, exists = doc["id"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id")
}

func (ts *IntegrationTests) TestRerankDocumentAllDefaults() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
Documents: []Document{
{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
{"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."},
{"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
{"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."},
}})

require.NoError(ts.T(), err)
require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil")
require.Equal(ts.T(), 4, len(ranking.Data), "Expected %v rankings", 4)

doc := *ranking.Data[0].Document
_, exists := doc["text"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "text")
_, exists = doc["id"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id")
}

func (ts *IntegrationTests) TestRerankDocumentsMultipleRankFields() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
_, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
RankFields: &[]string{"text", "custom-field"},
Documents: []Document{
{
"id": "vec1",
"text": "Apple is a popular fruit known for its sweetness and crisp texture.",
"custom-field": "another field",
},
{
"id": "vec2",
"text": "Many people enjoy eating apples as a healthy snack.",
"custom-field": "another field",
},
{
"id": "vec3",
"text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces.",
"custom-field": "another field",
},
{
"id": "vec4",
"text": "An apple a day keeps the doctor away, as the saying goes.",
"custom-field": "another field",
},
}})

require.Error(ts.T(), err)
require.Contains(ts.T(), err.Error(), "Only one rank field is supported for model")
}

func (ts *IntegrationTests) TestRerankDocumentFieldError() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
_, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
RankFields: &[]string{"custom-field"},
Documents: []Document{
{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
{"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."},
{"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
{"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."},
}})

require.Error(ts.T(), err)
require.Contains(ts.T(), err.Error(), "field 'custom-field' not found in document")
}

// Unit tests:
func TestExtractAuthHeaderUnit(t *testing.T) {
globalApiKey := os.Getenv("PINECONE_API_KEY")
7 changes: 7 additions & 0 deletions pinecone/models.go
@@ -158,6 +158,13 @@ type Usage struct {
ReadUnits uint32 `json:"read_units"`
}

// RerankUsage is the usage stats ([Rerank Units]) for a reranking request.
//
// [Rerank Units]: https://docs.pinecone.io/guides/organizations/manage-cost/understanding-cost#rerank
type RerankUsage struct {
RerankUnits *int `json:"rerank_units,omitempty"`
}

// MetadataFilter represents the [metadata filters] attached to a Pinecone request.
// These optional metadata filters are applied to query and deletion requests.
//