Commit
add ad-hoc rerank implementation to embedding, add async rerank (#1572)
emrgnt-cmplxty authored Nov 11, 2024
1 parent 22c0e26 commit 6ee6d9e
Showing 7 changed files with 437 additions and 299 deletions.
17 changes: 15 additions & 2 deletions py/core/base/providers/embedding.py
@@ -25,8 +25,7 @@ class EmbeddingConfig(ProviderConfig):
     base_model: str
     base_dimension: int
     rerank_model: Optional[str] = None
-    rerank_dimension: Optional[int] = None
-    rerank_transformer_type: Optional[str] = None
+    rerank_url: Optional[str] = None
     batch_size: int = 1
     prefixes: Optional[dict[str, str]] = None
     add_title_as_prefix: bool = True
@@ -38,6 +37,10 @@ class EmbeddingConfig(ProviderConfig):
         VectorQuantizationSettings()
     )
 
+    ## deprecated
+    rerank_dimension: Optional[int] = None
+    rerank_transformer_type: Optional[str] = None
+
     def validate_config(self) -> None:
         if self.provider not in self.supported_providers:
             raise ValueError(f"Provider '{self.provider}' is not supported.")
@@ -171,6 +174,16 @@ def rerank(
     ):
         pass
 
+    @abstractmethod
+    async def arerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: PipeStage = PipeStage.RERANK,
+        limit: int = 10,
+    ):
+        pass
+
     def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
         self.prefixes = {}
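The new abstract `arerank` gives every embedding provider an awaitable counterpart to `rerank`. As a rough usage sketch (the `provider` and `search_results` names below are hypothetical stand-ins, not part of this commit):

import asyncio

async def rerank_async(provider, query: str, search_results):
    # arerank mirrors rerank's signature, but can be awaited inside an
    # existing event loop instead of blocking it with a sync HTTP call.
    return await provider.arerank(query=query, results=search_results, limit=5)

# e.g. asyncio.run(rerank_async(provider, "example query", search_results))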
125 changes: 121 additions & 4 deletions py/core/providers/embeddings/litellm.py
@@ -1,7 +1,11 @@
 import logging
+import os
+from copy import copy
 from typing import Any
 
 import litellm
+import requests
+from aiohttp import ClientError, ClientSession
 from litellm import AuthenticationError, aembedding, embedding
 
 from core.base import (
@@ -36,10 +40,21 @@ def __init__(
             raise ValueError(
                 "LiteLLMEmbeddingProvider must be initialized with provider `litellm`."
             )
+
+        self.rerank_url = None
         if config.rerank_model:
-            raise ValueError(
-                "LiteLLMEmbeddingProvider does not support separate reranking."
-            )
+
+            if "huggingface" not in config.rerank_model:
+                raise ValueError(
+                    "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"
+                )
+
+            url = os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url
+            if not url:
+                raise ValueError(
+                    "LiteLLMEmbeddingProvider requires a valid reranking API url to be set via `embedding.rerank_url` in the r2r.toml, or via the environment variable `HUGGINGFACE_API_BASE`."
+                )
+            self.rerank_url = url
 
         self.base_model = config.base_model
         if "amazon" in self.base_model:
@@ -182,4 +197,106 @@ def rerank(
         stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
         limit: int = 10,
     ):
-        return results[:limit]
+        if self.config.rerank_model is not None:
+            if not self.rerank_url:
+                raise ValueError(
+                    "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
+                )
+
+            texts = [result.text for result in results]
+
+            payload = {
+                "query": query,
+                "texts": texts,
+                "model-id": self.config.rerank_model.split("huggingface/")[1],
+            }
+
+            headers = {"Content-Type": "application/json"}
+
+            try:
+                response = requests.post(
+                    self.rerank_url, json=payload, headers=headers
+                )
+                response.raise_for_status()
+                reranked_results = response.json()
+
+                # Copy reranked results into new array
+                scored_results = []
+                for rank_info in reranked_results:
+                    original_result = results[rank_info["index"]]
+                    copied_result = copy(original_result)
+                    # Inject the reranking score into the result object
+                    copied_result.score = rank_info["score"]
+                    scored_results.append(copied_result)
+
+                # Return only the VectorSearchResult objects, limited to specified count
+                return scored_results[:limit]
+
+            except requests.RequestException as e:
+                logger.error(f"Error during reranking: {str(e)}")
+                # Fall back to returning the original results if reranking fails
+                return results[:limit]
+        else:
+            return results[:limit]
+
+    async def arerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ) -> list[VectorSearchResult]:
+        """
+        Asynchronously rerank search results using the configured rerank model.
+        Args:
+            query: The search query string
+            results: List of VectorSearchResult objects to rerank
+            stage: The pipeline stage (must be RERANK)
+            limit: Maximum number of results to return
+        Returns:
+            List of reranked VectorSearchResult objects, limited to specified count
+        """
+        if self.config.rerank_model is not None:
+            if not self.rerank_url:
+                raise ValueError(
+                    "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
+                )
+
+            texts = [result.text for result in results]
+
+            payload = {
+                "query": query,
+                "texts": texts,
+                "model-id": self.config.rerank_model.split("huggingface/")[1],
+            }
+
+            headers = {"Content-Type": "application/json"}
+
+            try:
+                async with ClientSession() as session:
+                    async with session.post(
+                        self.rerank_url, json=payload, headers=headers
+                    ) as response:
+                        response.raise_for_status()
+                        reranked_results = await response.json()
+
+                # Copy reranked results into new array
+                scored_results = []
+                for rank_info in reranked_results:
+                    original_result = results[rank_info["index"]]
+                    copied_result = copy(original_result)
+                    # Inject the reranking score into the result object
+                    copied_result.score = rank_info["score"]
+                    scored_results.append(copied_result)
+
+                # Return only the VectorSearchResult objects, limited to specified count
+                return scored_results[:limit]
+
+            except (ClientError, Exception) as e:
+                logger.error(f"Error during async reranking: {str(e)}")
+                # Fall back to returning the original results if reranking fails
+                return results[:limit]
+        else:
+            return results[:limit]
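Both methods assume the rerank endpoint returns a JSON array of objects with "index" and "score" fields, and they fall back to the unreranked top-k results on any request failure. A self-contained sketch of that score-injection contract, using plain dicts and invented scores so it runs standalone:

from copy import copy

# Shape of the response the code above expects from the rerank endpoint
# (indices refer back into the original results list; values invented):
reranked_results = [
    {"index": 2, "score": 0.91},
    {"index": 0, "score": 0.45},
    {"index": 1, "score": 0.07},
]

results = [{"text": "a"}, {"text": "b"}, {"text": "c"}]

# Same loop as the provider code: copy each original result in the
# endpoint's ranking order and attach the rerank score to the copy.
scored = []
for rank_info in reranked_results:
    copied = copy(results[rank_info["index"]])
    copied["score"] = rank_info["score"]
    scored.append(copied)

assert [r["text"] for r in scored] == ["c", "a", "b"]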
9 changes: 9 additions & 0 deletions py/core/providers/embeddings/ollama.py
@@ -183,3 +183,12 @@ def rerank(
         limit: int = 10,
     ) -> list[VectorSearchResult]:
         return results[:limit]
+
+    async def arerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ):
+        return results[:limit]
9 changes: 9 additions & 0 deletions py/core/providers/embeddings/openai.py
@@ -224,6 +224,15 @@ def rerank(
     ):
         return results[:limit]
 
+    async def arerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ):
+        return results[:limit]
+
     def tokenize_string(self, text: str, model: str) -> list[int]:
         try:
             import tiktoken
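The tokenize_string body is truncated in this view. A plausible completion, assuming tiktoken's standard public API rather than anything shown in this commit:

def tokenize_string(self, text: str, model: str) -> list[int]:
    try:
        import tiktoken
    except ImportError:
        raise ValueError("tiktoken is required to tokenize strings.")
    try:
        # Look up the encoding registered for this model name.
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model: fall back to a general-purpose encoding.
        encoding = tiktoken.get_encoding("cl100k_base")
    return encoding.encode(text)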