-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: patch rerankers flashrank issue (#22)
- Loading branch information
Showing
7 changed files
with
195 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
"""Patched version of FlashRankRanker that fixes incorrect reranking [1]. | ||
[1] https://github.com/AnswerDotAI/rerankers/issues/39 | ||
""" | ||
|
||
import contextlib | ||
from io import StringIO | ||
from typing import Any | ||
|
||
from flashrank import RerankRequest | ||
|
||
# Suppress rerankers output on import until [1] is fixed. | ||
# [1] https://github.com/AnswerDotAI/rerankers/issues/36 | ||
with contextlib.redirect_stdout(StringIO()): | ||
from rerankers.documents import Document | ||
from rerankers.models.flashrank_ranker import FlashRankRanker | ||
from rerankers.results import RankedResults, Result | ||
from rerankers.utils import prep_docs | ||
|
||
|
||
class PatchedFlashRankRanker(FlashRankRanker): | ||
def rank( | ||
self, | ||
query: str, | ||
docs: str | list[str] | Document | list[Document], | ||
doc_ids: list[str] | list[int] | None = None, | ||
metadata: list[dict[str, Any]] | None = None, | ||
) -> RankedResults: | ||
docs = prep_docs(docs, doc_ids, metadata) | ||
passages = [{"id": doc_idx, "text": doc.text} for doc_idx, doc in enumerate(docs)] | ||
rerank_request = RerankRequest(query=query, passages=passages) | ||
flashrank_results = self.model.rerank(rerank_request) | ||
ranked_results = [ | ||
Result( | ||
document=docs[result["id"]], # This patches the incorrect ranking in the original. | ||
score=result["score"], | ||
rank=idx + 1, | ||
) | ||
for idx, result in enumerate(flashrank_results) | ||
] | ||
return RankedResults(results=ranked_results, query=query, has_scores=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,40 @@ | ||
"""Test RAGLite's RAG functionality.""" | ||
|
||
import os | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING | ||
|
||
import pytest | ||
from llama_cpp import llama_supports_gpu_offload | ||
|
||
from raglite import RAGLiteConfig, hybrid_search, insert_document, rag, retrieve_segments | ||
from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks | ||
|
||
if TYPE_CHECKING: | ||
from raglite._database import Chunk | ||
from raglite._typing import SearchMethod | ||
|
||
|
||
def is_accelerator_available() -> bool: | ||
"""Check if an accelerator is available.""" | ||
return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004 | ||
|
||
|
||
def test_insert_index_search(raglite_test_config: RAGLiteConfig) -> None: | ||
"""Test inserting a document, updating the indexes, and searching for a query.""" | ||
# Insert a document and update the index. | ||
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. | ||
insert_document(doc_path, config=raglite_test_config) | ||
# Search for a query. | ||
query = "What does it mean for two events to be simultaneous?" | ||
chunk_ids, scores = hybrid_search(query, config=raglite_test_config) | ||
assert len(chunk_ids) == len(scores) | ||
assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids) | ||
assert all(isinstance(score, float) for score in scores) | ||
# Group the chunks into segments and retrieve them. | ||
segments = retrieve_segments(chunk_ids, neighbors=None, config=raglite_test_config) | ||
assert all(isinstance(segment, str) for segment in segments) | ||
assert "Definition of Simultaneity" in "".join(segments[:2]) | ||
|
||
|
||
@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available") | ||
def test_rag(raglite_test_config: RAGLiteConfig) -> None: | ||
"""Test Retrieval-Augmented Generation.""" | ||
# Insert a document and update the index. | ||
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. | ||
insert_document(doc_path, config=raglite_test_config) | ||
# Answer a question with RAG. | ||
# Assemble different types of search inputs for RAG. | ||
prompt = "What does it mean for two events to be simultaneous?" | ||
stream = rag(prompt, search=hybrid_search, config=raglite_test_config) | ||
answer = "" | ||
for update in stream: | ||
assert isinstance(update, str) | ||
answer += update | ||
assert "simultaneous" in answer.lower() | ||
search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [ | ||
hybrid_search, # A search method as input. | ||
hybrid_search(prompt, config=raglite_test_config)[0], # Chunk ids as input. | ||
retrieve_chunks( # Chunks as input. | ||
hybrid_search(prompt, config=raglite_test_config)[0], config=raglite_test_config | ||
), | ||
] | ||
# Answer a question with RAG. | ||
for search_input in search_inputs: | ||
stream = rag(prompt, search=search_input, config=raglite_test_config) | ||
answer = "" | ||
for update in stream: | ||
assert isinstance(update, str) | ||
answer += update | ||
assert "simultaneous" in answer.lower() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
"""Test RAGLite's reranking functionality.""" | ||
|
||
import pytest | ||
from rerankers.models.ranker import BaseRanker | ||
|
||
from raglite import RAGLiteConfig, hybrid_search, rerank, retrieve_chunks | ||
from raglite._database import Chunk | ||
from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker | ||
|
||
|
||
@pytest.fixture( | ||
params=[ | ||
pytest.param(None, id="no_reranker"), | ||
pytest.param(FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), id="flashrank_english"), | ||
pytest.param( | ||
( | ||
("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)), | ||
("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)), | ||
), | ||
id="flashrank_multilingual", | ||
), | ||
], | ||
) | ||
def reranker( | ||
request: pytest.FixtureRequest, | ||
) -> BaseRanker | tuple[tuple[str, BaseRanker], ...] | None: | ||
"""Get a reranker to test RAGLite with.""" | ||
reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = request.param | ||
return reranker | ||
|
||
|
||
def test_reranker( | ||
raglite_test_config: RAGLiteConfig, | ||
reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None, | ||
) -> None: | ||
"""Test inserting a document, updating the indexes, and searching for a query.""" | ||
# Update the config with the reranker. | ||
raglite_test_config = RAGLiteConfig( | ||
db_url=raglite_test_config.db_url, embedder=raglite_test_config.embedder, reranker=reranker | ||
) | ||
# Search for a query. | ||
query = "What does it mean for two events to be simultaneous?" | ||
chunk_ids, _ = hybrid_search(query, num_results=3, config=raglite_test_config) | ||
# Retrieve the chunks. | ||
chunks = retrieve_chunks(chunk_ids, config=raglite_test_config) | ||
assert all(isinstance(chunk, Chunk) for chunk in chunks) | ||
# Rerank the chunks given an inverted chunk order. | ||
reranked_chunks = rerank(query, chunks[::-1], config=raglite_test_config) | ||
if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder: | ||
assert reranked_chunks[:3] == chunks[:3] | ||
# Test that we can also rerank given the chunk_ids only. | ||
reranked_chunks = rerank(query, chunk_ids[::-1], config=raglite_test_config) | ||
if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder: | ||
assert reranked_chunks[:3] == chunks[:3] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""Test RAGLite's search functionality.""" | ||
|
||
import pytest | ||
|
||
from raglite import ( | ||
RAGLiteConfig, | ||
hybrid_search, | ||
keyword_search, | ||
retrieve_chunks, | ||
retrieve_segments, | ||
vector_search, | ||
) | ||
from raglite._database import Chunk | ||
from raglite._typing import SearchMethod | ||
|
||
|
||
@pytest.fixture( | ||
params=[ | ||
pytest.param(keyword_search, id="keyword_search"), | ||
pytest.param(vector_search, id="vector_search"), | ||
pytest.param(hybrid_search, id="hybrid_search"), | ||
], | ||
) | ||
def search_method( | ||
request: pytest.FixtureRequest, | ||
) -> SearchMethod: | ||
"""Get a search method to test RAGLite with.""" | ||
search_method: SearchMethod = request.param | ||
return search_method | ||
|
||
|
||
def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None: | ||
"""Test searching for a query.""" | ||
# Search for a query. | ||
query = "What does it mean for two events to be simultaneous?" | ||
num_results = 5 | ||
chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config) | ||
assert len(chunk_ids) == len(scores) == num_results | ||
assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids) | ||
assert all(isinstance(score, float) for score in scores) | ||
# Retrieve the chunks. | ||
chunks = retrieve_chunks(chunk_ids, config=raglite_test_config) | ||
assert all(isinstance(chunk, Chunk) for chunk in chunks) | ||
assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks) | ||
# Extend the chunks with their neighbours and group them into contiguous segments. | ||
segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config) | ||
assert all(isinstance(segment, str) for segment in segments) |