-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rebase & Implement Rerank #80
Changes from 6 commits
d29451b
343b9d4
d76c4aa
7df61e9
be67550
8253494
e1c3343
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1261,7 +1261,7 @@ type InferenceService struct { | |
// Parameters: | ||
// - ctx: A context.Context object controls the request's lifetime, allowing for the request | ||
// to be canceled or to timeout according to the context's deadline. | ||
// - in: A pointer to an EmbedRequest object that contains the model t4o use for embedding generation, the | ||
// - in: A pointer to an EmbedRequest object that contains the model to use for embedding generation, the | ||
// list of input strings to generate embeddings for, and any additional parameters to use for generation. | ||
// | ||
// Returns a pointer to an EmbeddingsList object or an error. | ||
|
@@ -1343,6 +1343,130 @@ func (i *InferenceService) Embed(ctx context.Context, in *EmbedRequest) (*infere | |
return decodeEmbeddingsList(res.Body) | ||
} | ||
|
||
// Document is a map representing the document to be reranked. | ||
type Document map[string]string | ||
|
||
// RerankRequest holds the parameters for calling [InferenceService.Rerank] and reranking documents | ||
// by a specified query and model. | ||
// | ||
// Fields: | ||
// - Model: (Required) The [model] to use for reranking. | ||
// - Query: (Required) The query to rerank Documents against. | ||
// - Documents: (Required) A list of Document objects to be reranked. By default each Document is ranked on its | ||
// "text" field; use [RerankRequest.RankFields] to rank on a different field. | ||
// - RankFields: (Optional) The fields to rank the Documents by. If not provided, the default is "text". | ||
// - ReturnDocuments: (Optional) Whether to include Documents in the response. Defaults to true. | ||
// - TopN: (Optional) How many Documents to return. Defaults to the length of input Documents. | ||
// - Parameters: (Optional) Additional model-specific parameters for the reranker | ||
// | ||
// [model]: https://docs.pinecone.io/guides/inference/understanding-inference#models | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice |
||
type RerankRequest struct { | ||
Model string | ||
Query string | ||
Documents []Document | ||
RankFields *[]string | ||
ReturnDocuments *bool | ||
TopN *int | ||
Parameters *map[string]string | ||
} | ||
|
||
// RankedDocument represents a ranked document with a relevance score and an index position. | ||
// | ||
// Fields: | ||
// - Document: The [Document]. | ||
// - Index: The index position of the Document from the original request. This can be used | ||
// to locate the position of the document relative to others described in the request. | ||
// - Score: The relevance score of the Document indicating how closely it matches the query. | ||
type RankedDocument struct { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd specify in the docstring here (or somewhere) that the Docs w/2+ fields in them will only be ranked on 1 field -- either Here's how I did it in TS: https://github.com/pinecone-io/pinecone-ts-client/pull/303/files#diff-77967717b18045071d22a72f646f1c4f3cbddc394751cda28af4dfa8732cbd7fR114 (the backend throws the error, so we just have to catch it) |
||
Document *Document `json:"document,omitempty"` | ||
Index int `json:"index"` | ||
Score float32 `json:"score"` | ||
} | ||
|
||
// RerankResponse is the result of a reranking operation. | ||
// | ||
// Fields: | ||
// - Data: A list of Documents which have been reranked. The Documents are sorted in order of relevance, | ||
// with the first being the most relevant. | ||
// - Model: The model used to rerank Documents. | ||
// - Usage: Usage statistics ([Rerank Units]) for the reranking operation. | ||
// | ||
// [Rerank Units]: https://docs.pinecone.io/guides/organizations/manage-cost/understanding-cost#rerank | ||
type RerankResponse struct { | ||
Data []RankedDocument `json:"data,omitempty"` | ||
Model string `json:"model"` | ||
Usage RerankUsage `json:"usage"` | ||
} | ||
|
||
// Rerank Documents with associated relevance scores that represent the relevance of each Document | ||
// to the provided query using the specified model. | ||
// | ||
// Parameters: | ||
// - ctx: A context.Context object controls the request's lifetime, allowing for the request | ||
// to be canceled or to timeout according to the context's deadline. | ||
// - in: A pointer to a [RerankRequest] object that contains the model, query, and documents to use for reranking. | ||
// | ||
// Example: | ||
// | ||
// ctx := context.Background() | ||
// | ||
// clientParams := pinecone.NewClientParams{ | ||
// ApiKey: "YOUR_API_KEY", | ||
// SourceTag: "your_source_identifier", // optional | ||
// } | ||
// | ||
// pc, err := pinecone.NewClient(clientParams) | ||
// if err != nil { | ||
// log.Fatalf("Failed to create Client: %v", err) | ||
// } | ||
// | ||
// rerankModel := "bge-reranker-v2-m3" | ||
// topN := 2 | ||
// retunDocuments := true | ||
// documents := []pinecone.Document{ | ||
// {"id": "doc1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."}, | ||
// {"id": "doc2", "text": "Many people enjoy eating apples as a healthy snack."}, | ||
// {"id": "doc3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, | ||
// {"id": "doc4", "text": "An apple a day keeps the doctor away, as the saying goes."}, | ||
// } | ||
// | ||
// ranking, err := pc.Inference.Rerank(ctx, &pinecone.RerankRequest{ | ||
// Model: rerankModel, | ||
// Query: "i love to eat apples", | ||
// ReturnDocuments: &retunDocuments, | ||
// TopN: &topN, | ||
// RankFields: &[]string{"text"}, | ||
// Documents: documents, | ||
// }) | ||
// if err != nil { | ||
// log.Fatalf("Failed to rerank: %v", err) | ||
// } | ||
// fmt.Printf("Rerank result: %+v\n", ranking) | ||
func (i *InferenceService) Rerank(ctx context.Context, in *RerankRequest) (*RerankResponse, error) { | ||
convertedDocuments := make([]inference.Document, len(in.Documents)) | ||
for i, doc := range in.Documents { | ||
convertedDocuments[i] = inference.Document(doc) | ||
} | ||
req := inference.RerankJSONRequestBody{ | ||
Model: in.Model, | ||
Query: in.Query, | ||
Documents: convertedDocuments, | ||
RankFields: in.RankFields, | ||
ReturnDocuments: in.ReturnDocuments, | ||
TopN: in.TopN, | ||
Parameters: in.Parameters, | ||
} | ||
res, err := i.client.Rerank(ctx, req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer res.Body.Close() | ||
if res.StatusCode != http.StatusOK { | ||
return nil, handleErrorResponseBody(res, "failed to rerank: ") | ||
} | ||
return decodeRerankResponse(res.Body) | ||
} | ||
|
||
func (c *Client) extractAuthHeader() map[string]string { | ||
possibleAuthKeys := []string{ | ||
"api-key", | ||
|
@@ -1423,6 +1547,16 @@ func decodeEmbeddingsList(resBody io.ReadCloser) (*inference.EmbeddingsList, err | |
return &embeddingsList, nil | ||
} | ||
|
||
func decodeRerankResponse(resBody io.ReadCloser) (*RerankResponse, error) { | ||
var rerankResponse RerankResponse | ||
err := json.NewDecoder(resBody).Decode(&rerankResponse) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to decode rerank response: %w", err) | ||
} | ||
|
||
return &rerankResponse, nil | ||
} | ||
|
||
func toCollection(cm *db_control.CollectionModel) *Collection { | ||
if cm == nil { | ||
return nil | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,11 +13,11 @@ import ( | |
"time" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
"github.com/google/uuid" | ||
"github.com/pinecone-io/go-pinecone/internal/gen" | ||
"github.com/pinecone-io/go-pinecone/internal/gen/db_control" | ||
"github.com/pinecone-io/go-pinecone/internal/provider" | ||
|
||
"github.com/google/uuid" | ||
"github.com/pinecone-io/go-pinecone/internal/utils" | ||
|
||
"github.com/stretchr/testify/assert" | ||
|
@@ -277,6 +277,11 @@ func (ts *IntegrationTests) TestConfigureIndexHitPodLimit() { | |
} | ||
|
||
func (ts *IntegrationTests) TestGenerateEmbeddings() { | ||
// Run Embed tests once rather than duplicating across serverless & pods | ||
if ts.indexType == "pod" { | ||
ts.T().Skip("Skipping Embed tests for pods") | ||
} | ||
|
||
ctx := context.Background() | ||
embeddingModel := "multilingual-e5-large" | ||
embeddings, err := ts.client.Inference.Embed(ctx, &EmbedRequest{ | ||
|
@@ -313,6 +318,164 @@ func (ts *IntegrationTests) TestGenerateEmbeddingsInvalidInputs() { | |
require.Contains(ts.T(), err.Error(), "TextInputs must contain at least one value") | ||
} | ||
|
||
func (ts *IntegrationTests) TestRerankDocumentDefaultField() { | ||
// Run Rerank tests once rather than duplicating across serverless & pods | ||
if ts.indexType == "pod" { | ||
ts.T().Skip("Skipping Rerank tests for pods") | ||
} | ||
|
||
ctx := context.Background() | ||
rerankModel := "bge-reranker-v2-m3" | ||
topN := 2 | ||
retunDocuments := true | ||
ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ | ||
Model: rerankModel, | ||
Query: "i love apples", | ||
ReturnDocuments: &retunDocuments, | ||
TopN: &topN, | ||
Documents: []Document{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: this is the situation in which the doc will only be ranked on |
||
{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."}, | ||
{"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."}, | ||
{"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, | ||
{"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."}, | ||
}}) | ||
|
||
require.NoError(ts.T(), err) | ||
require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil") | ||
require.Equal(ts.T(), topN, len(ranking.Data), "Expected %v rankings", topN) | ||
|
||
doc := *ranking.Data[0].Document | ||
_, exists := doc["text"] | ||
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "text") | ||
_, exists = doc["id"] | ||
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id") | ||
} | ||
|
||
func (ts *IntegrationTests) TestRerankDocumentCustomField() { | ||
// Run Rerank tests once rather than duplicating across serverless & pods | ||
if ts.indexType == "pod" { | ||
ts.T().Skip("Skipping Rerank tests for pods") | ||
} | ||
|
||
ctx := context.Background() | ||
rerankModel := "bge-reranker-v2-m3" | ||
topN := 2 | ||
retunDocuments := true | ||
ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ | ||
Model: rerankModel, | ||
Query: "i love apples", | ||
ReturnDocuments: &retunDocuments, | ||
TopN: &topN, | ||
RankFields: &[]string{"customField"}, | ||
Documents: []Document{ | ||
{"id": "vec1", "customField": "Apple is a popular fruit known for its sweetness and crisp texture."}, | ||
{"id": "vec2", "customField": "Many people enjoy eating apples as a healthy snack."}, | ||
{"id": "vec3", "customField": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, | ||
{"id": "vec4", "customField": "An apple a day keeps the doctor away, as the saying goes."}, | ||
}}) | ||
|
||
require.NoError(ts.T(), err) | ||
require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil") | ||
require.Equal(ts.T(), topN, len(ranking.Data), "Expected %v rankings", topN) | ||
|
||
doc := *ranking.Data[0].Document | ||
_, exists := doc["customField"] | ||
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "customField") | ||
_, exists = doc["id"] | ||
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id") | ||
} | ||
|
||
func (ts *IntegrationTests) TestRerankDocumentAllDefaults() { | ||
// Run Rerank tests once rather than duplicating across serverless & pods | ||
if ts.indexType == "pod" { | ||
ts.T().Skip("Skipping Rerank tests for pods") | ||
} | ||
|
||
ctx := context.Background() | ||
rerankModel := "bge-reranker-v2-m3" | ||
ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ | ||
Model: rerankModel, | ||
Query: "i love apples", | ||
Documents: []Document{ | ||
{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."}, | ||
{"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."}, | ||
{"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, | ||
{"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."}, | ||
}}) | ||
|
||
require.NoError(ts.T(), err) | ||
require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil") | ||
require.Equal(ts.T(), 4, len(ranking.Data), "Expected %v rankings", 4) | ||
|
||
doc := *ranking.Data[0].Document | ||
_, exists := doc["text"] | ||
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "text") | ||
_, exists = doc["id"] | ||
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id") | ||
} | ||
|
||
func (ts *IntegrationTests) TestRerankDocumentsMultipleRankFields() { | ||
// Run Rerank tests once rather than duplicating across serverless & pods | ||
if ts.indexType == "pod" { | ||
ts.T().Skip("Skipping Rerank tests for pods") | ||
} | ||
|
||
ctx := context.Background() | ||
rerankModel := "bge-reranker-v2-m3" | ||
_, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ | ||
Model: rerankModel, | ||
Query: "i love apples", | ||
RankFields: &[]string{"text", "custom-field"}, | ||
Documents: []Document{ | ||
{ | ||
"id": "vec1", | ||
"text": "Apple is a popular fruit known for its sweetness and crisp texture.", | ||
"custom-field": "another field", | ||
}, | ||
{ | ||
"id": "vec2", | ||
"text": "Many people enjoy eating apples as a healthy snack.", | ||
"custom-field": "another field", | ||
}, | ||
{ | ||
"id": "vec3", | ||
"text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces.", | ||
"custom-field": "another field", | ||
}, | ||
{ | ||
"id": "vec4", | ||
"text": "An apple a day keeps the doctor away, as the saying goes.", | ||
"custom-field": "another field", | ||
}, | ||
}}) | ||
|
||
require.Error(ts.T(), err) | ||
require.Contains(ts.T(), err.Error(), "Only one rank field is supported for model") | ||
} | ||
|
||
func (ts *IntegrationTests) TestRerankDocumentFieldError() { | ||
// Run Rerank tests once rather than duplicating across serverless & pods | ||
if ts.indexType == "pod" { | ||
ts.T().Skip("Skipping Rerank tests for pods") | ||
} | ||
|
||
ctx := context.Background() | ||
rerankModel := "bge-reranker-v2-m3" | ||
_, err := ts.client.Inference.Rerank(ctx, &RerankRequest{ | ||
Model: rerankModel, | ||
Query: "i love apples", | ||
RankFields: &[]string{"custom-field"}, | ||
Documents: []Document{ | ||
{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."}, | ||
{"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."}, | ||
{"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."}, | ||
{"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."}, | ||
}}) | ||
|
||
require.Error(ts.T(), err) | ||
require.Contains(ts.T(), err.Error(), "field 'custom-field' not found in document") | ||
} | ||
|
||
// Unit tests: | ||
func TestExtractAuthHeaderUnit(t *testing.T) { | ||
globalApiKey := os.Getenv("PINECONE_API_KEY") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch