Skip to content

Commit

Permalink
Rebase & Implement Rerank (#80)
Browse files Browse the repository at this point in the history
## Problem
The `InferenceService` and `Embed` operation were implemented in a
previous PR: #67

As a part of the `2024-10` release we are also adding `Rerank` to the
client, which will live under `InferenceService` with `Embed`.

## Solution
This work builds on top of external contributions from @Stosan
(#73) Thank you for you
diligence in getting this started!

- Update `InferenceService` to expose the new `Rerank` method.
- Add various request and response types to represent the interface.
- Add doc comments and an integration test for calling the reranker.

## Type of Change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [X] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [ ] None of the above: (explain here)

## Test Plan
`just test`

If you want to use the code yourself, here's an example:

```go
	ctx := context.Background()

	clientParams := pinecone.NewClientParams{
		ApiKey:    "YOUR_API_KEY",
	}

	pc, err := pinecone.NewClient(clientParams)
	if err != nil {
		log.Fatalf("Failed to create Client: %v", err)
	}

	rerankModel := "bge-reranker-v2-m3"
	topN := 2
	retunDocuments := true
	documents := []pinecone.Document{
		{
                    "id": "vec1", 
                    "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
		{
                    "id": "vec2", 
                    "text": "Many people enjoy eating apples as a healthy snack."},
		{
                    "id": "vec3", 
                    "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
		{
                    "id": "vec4", 
                    "text": "An apple a day keeps the doctor away, as the saying goes."},
        }

	ranking, err := pc.Inference.Rerank(ctx, &pinecone.RerankRequest{
		Model:           rerankModel,
		Query:           "i love cats",
		ReturnDocuments: &retunDocuments,
		TopN:            &topN,
		RankFields:      &[]string{"text"},
		Documents:       documents,
	})
	if err != nil {
		log.Fatalf("Failed to rerank: %v", err)
	}
	fmt.Printf("Rerank result: %+v\n", ranking)
```

---------

Co-authored-by: Sam Ayo <[email protected]>
  • Loading branch information
austin-denoble and Stosan committed Oct 23, 2024
1 parent feb3a85 commit 1069770
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 3 deletions.
2 changes: 1 addition & 1 deletion codegen/apis
Submodule apis updated from 3002f1 to 9d3a41
136 changes: 135 additions & 1 deletion pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ type InferenceService struct {
// Parameters:
// - ctx: A context.Context object controls the request's lifetime, allowing for the request
// to be canceled or to timeout according to the context's deadline.
// - in: A pointer to an EmbedRequest object that contains the model t4o use for embedding generation, the
// - in: A pointer to an EmbedRequest object that contains the model to use for embedding generation, the
// list of input strings to generate embeddings for, and any additional parameters to use for generation.
//
// Returns a pointer to an EmbeddingsList object or an error.
Expand Down Expand Up @@ -1343,6 +1343,130 @@ func (i *InferenceService) Embed(ctx context.Context, in *EmbedRequest) (*infere
return decodeEmbeddingsList(res.Body)
}

// Document is a map representing the document to be reranked.
type Document map[string]string

// RerankRequest holds the parameters for calling [InferenceService.Rerank] and reranking documents
// by a specified query and model.
//
// Fields:
// - Model: "The [model] to use for reranking.
// - Query: (Required) The query to rerank Documents against.
// - Documents: (Required) A list of Document objects to be reranked. The default is "text", but you can
// specify this behavior with [RerankRequest.RankFields].
// - RankFields: (Optional) The fields to rank the Documents by. If not provided, the default is "text".
// - ReturnDocuments: (Optional) Whether to include Documents in the response. Defaults to true.
// - TopN: (Optional) How many Documents to return. Defaults to the length of input Documents.
// - Parameters: (Optional) Additional model-specific parameters for the reranker
//
// [model]: https://docs.pinecone.io/guides/inference/understanding-inference#models
type RerankRequest struct {
Model string
Query string
Documents []Document
RankFields *[]string
ReturnDocuments *bool
TopN *int
Parameters *map[string]string
}

// Represents a ranked document with a relevance score and an index position.
//
// Fields:
// - Document: The [Document].
// - Index: The index position of the Document from the original request. This can be used
// to locate the position of the document relative to others described in the request.
// - Score: The relevance score of the Document indicating how closely it matches the query.
type RankedDocument struct {
Document *Document `json:"document,omitempty"`
Index int `json:"index"`
Score float32 `json:"score"`
}

// RerankResponse is the result of a reranking operation.
//
// Fields:
// - Data: A list of Documents which have been reranked. The Documents are sorted in order of relevance,
// with the first being the most relevant.
// - Model: The model used to rerank Documents.
// - Usage: Usage statistics ([Rerank Units]) for the reranking operation.
//
// [Read Units]: https://docs.pinecone.io/guides/organizations/manage-cost/understanding-cost#rerank
type RerankResponse struct {
Data []RankedDocument `json:"data,omitempty"`
Model string `json:"model"`
Usage RerankUsage `json:"usage"`
}

// Rerank Documents with associated relevance scores that represent the relevance of each Document
// to the provided query using the specified model.
//
// Parameters:
// - ctx: A context.Context object controls the request's lifetime, allowing for the request
// to be canceled or to timeout according to the context's deadline.
// - in: A pointer to a [RerankRequest] object that contains the model, query, and documents to use for reranking.
//
// Example:
//
// ctx := context.Background()
//
// clientParams := pinecone.NewClientParams{
// ApiKey: "YOUR_API_KEY",
// SourceTag: "your_source_identifier", // optional
// }
//
// pc, err := pinecone.NewClient(clientParams)
// if err != nil {
// log.Fatalf("Failed to create Client: %v", err)
// }
//
// rerankModel := "bge-reranker-v2-m3"
// topN := 2
// retunDocuments := true
// documents := []pinecone.Document{
// {"id": "doc1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
// {"id": "doc2", "text": "Many people enjoy eating apples as a healthy snack."},
// {"id": "doc3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
// {"id": "doc4", "text": "An apple a day keeps the doctor away, as the saying goes."},
// }
//
// ranking, err := pc.Inference.Rerank(ctx, &pinecone.RerankRequest{
// Model: rerankModel,
// Query: "i love to eat apples",
// ReturnDocuments: &retunDocuments,
// TopN: &topN,
// RankFields: &[]string{"text"},
// Documents: documents,
// })
// if err != nil {
// log.Fatalf("Failed to rerank: %v", err)
// }
// fmt.Printf("Rerank result: %+v\n", ranking)
func (i *InferenceService) Rerank(ctx context.Context, in *RerankRequest) (*RerankResponse, error) {
convertedDocuments := make([]inference.Document, len(in.Documents))
for i, doc := range in.Documents {
convertedDocuments[i] = inference.Document(doc)
}
req := inference.RerankJSONRequestBody{
Model: in.Model,
Query: in.Query,
Documents: convertedDocuments,
RankFields: in.RankFields,
ReturnDocuments: in.ReturnDocuments,
TopN: in.TopN,
Parameters: in.Parameters,
}
res, err := i.client.Rerank(ctx, req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, handleErrorResponseBody(res, "failed to rerank: ")
}
return decodeRerankResponse(res.Body)
}

func (c *Client) extractAuthHeader() map[string]string {
possibleAuthKeys := []string{
"api-key",
Expand Down Expand Up @@ -1423,6 +1547,16 @@ func decodeEmbeddingsList(resBody io.ReadCloser) (*inference.EmbeddingsList, err
return &embeddingsList, nil
}

func decodeRerankResponse(resBody io.ReadCloser) (*RerankResponse, error) {
var rerankResponse RerankResponse
err := json.NewDecoder(resBody).Decode(&rerankResponse)
if err != nil {
return nil, fmt.Errorf("failed to decode rerank response: %w", err)
}

return &rerankResponse, nil
}

func toCollection(cm *db_control.CollectionModel) *Collection {
if cm == nil {
return nil
Expand Down
165 changes: 164 additions & 1 deletion pinecone/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/pinecone-io/go-pinecone/internal/gen"
"github.com/pinecone-io/go-pinecone/internal/gen/db_control"
"github.com/pinecone-io/go-pinecone/internal/provider"

"github.com/google/uuid"
"github.com/pinecone-io/go-pinecone/internal/utils"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -277,6 +277,11 @@ func (ts *IntegrationTests) TestConfigureIndexHitPodLimit() {
}

func (ts *IntegrationTests) TestGenerateEmbeddings() {
// Run Embed tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Embed tests for pods")
}

ctx := context.Background()
embeddingModel := "multilingual-e5-large"
embeddings, err := ts.client.Inference.Embed(ctx, &EmbedRequest{
Expand Down Expand Up @@ -313,6 +318,164 @@ func (ts *IntegrationTests) TestGenerateEmbeddingsInvalidInputs() {
require.Contains(ts.T(), err.Error(), "TextInputs must contain at least one value")
}

func (ts *IntegrationTests) TestRerankDocumentDefaultField() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
topN := 2
retunDocuments := true
ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
ReturnDocuments: &retunDocuments,
TopN: &topN,
Documents: []Document{
{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
{"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."},
{"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
{"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."},
}})

require.NoError(ts.T(), err)
require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil")
require.Equal(ts.T(), topN, len(ranking.Data), "Expected %v rankings", topN)

doc := *ranking.Data[0].Document
_, exists := doc["text"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "text")
_, exists = doc["id"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id")
}

func (ts *IntegrationTests) TestRerankDocumentCustomField() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
topN := 2
retunDocuments := true
ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
ReturnDocuments: &retunDocuments,
TopN: &topN,
RankFields: &[]string{"customField"},
Documents: []Document{
{"id": "vec1", "customField": "Apple is a popular fruit known for its sweetness and crisp texture."},
{"id": "vec2", "customField": "Many people enjoy eating apples as a healthy snack."},
{"id": "vec3", "customField": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
{"id": "vec4", "customField": "An apple a day keeps the doctor away, as the saying goes."},
}})

require.NoError(ts.T(), err)
require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil")
require.Equal(ts.T(), topN, len(ranking.Data), "Expected %v rankings", topN)

doc := *ranking.Data[0].Document
_, exists := doc["customField"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "customField")
_, exists = doc["id"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id")
}

func (ts *IntegrationTests) TestRerankDocumentAllDefaults() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
ranking, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
Documents: []Document{
{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
{"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."},
{"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
{"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."},
}})

require.NoError(ts.T(), err)
require.NotNil(ts.T(), ranking, "Expected reranking result to be non-nil")
require.Equal(ts.T(), 4, len(ranking.Data), "Expected %v rankings", 4)

doc := *ranking.Data[0].Document
_, exists := doc["text"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "text")
_, exists = doc["id"]
require.True(ts.T(), exists, "Expected '%s' to exist in Document map", "id")
}

func (ts *IntegrationTests) TestRerankDocumentsMultipleRankFields() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
_, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
RankFields: &[]string{"text", "custom-field"},
Documents: []Document{
{
"id": "vec1",
"text": "Apple is a popular fruit known for its sweetness and crisp texture.",
"custom-field": "another field",
},
{
"id": "vec2",
"text": "Many people enjoy eating apples as a healthy snack.",
"custom-field": "another field",
},
{
"id": "vec3",
"text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces.",
"custom-field": "another field",
},
{
"id": "vec4",
"text": "An apple a day keeps the doctor away, as the saying goes.",
"custom-field": "another field",
},
}})

require.Error(ts.T(), err)
require.Contains(ts.T(), err.Error(), "Only one rank field is supported for model")
}

func (ts *IntegrationTests) TestRerankDocumentFieldError() {
// Run Rerank tests once rather than duplicating across serverless & pods
if ts.indexType == "pod" {
ts.T().Skip("Skipping Rerank tests for pods")
}

ctx := context.Background()
rerankModel := "bge-reranker-v2-m3"
_, err := ts.client.Inference.Rerank(ctx, &RerankRequest{
Model: rerankModel,
Query: "i love apples",
RankFields: &[]string{"custom-field"},
Documents: []Document{
{"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
{"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."},
{"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
{"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."},
}})

require.Error(ts.T(), err)
require.Contains(ts.T(), err.Error(), "field 'custom-field' not found in document")
}

// Unit tests:
func TestExtractAuthHeaderUnit(t *testing.T) {
globalApiKey := os.Getenv("PINECONE_API_KEY")
Expand Down
7 changes: 7 additions & 0 deletions pinecone/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ type Usage struct {
ReadUnits uint32 `json:"read_units"`
}

// RerankUsage is the usage stats ([Rerank Units]) for a reranking request.
//
// [Rerank Units]: https://docs.pinecone.io/guides/organizations/manage-cost/understanding-cost#rerank
type RerankUsage struct {
RerankUnits *int `json:"rerank_units,omitempty"`
}

// MetadataFilter represents the [metadata filters] attached to a Pinecone request.
// These optional metadata filters are applied to query and deletion requests.
//
Expand Down

0 comments on commit 1069770

Please sign in to comment.