From e3f5893541f4601c885cde273781c8b0028bcd77 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Mon, 7 Oct 2024 23:07:34 +0200 Subject: [PATCH] fix: patch rerankers flashrank issue (#22) --- README.md | 4 +-- src/raglite/_config.py | 3 ++- src/raglite/_flashrank.py | 41 +++++++++++++++++++++++++++++ tests/conftest.py | 34 +++++++++++++++++++----- tests/test_rag.py | 51 ++++++++++++++++-------------------- tests/test_rerank.py | 54 +++++++++++++++++++++++++++++++++++++++ tests/test_search.py | 47 ++++++++++++++++++++++++++++++++++ 7 files changed, 195 insertions(+), 39 deletions(-) create mode 100644 src/raglite/_flashrank.py create mode 100644 tests/test_rerank.py create mode 100644 tests/test_search.py diff --git a/README.md b/README.md index e1adb2a..7005307 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,10 @@ RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with Postgr ## Installing -First, begin by installing SpaCy's multilingual sentence model: +First, begin by installing spaCy's multilingual sentence model: ```sh -# Install SpaCy's xx_sent_ud_sm: +# Install spaCy's xx_sent_ud_sm: pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.7.0/xx_sent_ud_sm-3.7.0-py3-none-any.whl ``` diff --git a/src/raglite/_config.py b/src/raglite/_config.py index 38a9248..87e20b1 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -8,10 +8,11 @@ from llama_cpp import llama_supports_gpu_offload from sqlalchemy.engine import URL +from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker + # Suppress rerankers output on import until [1] is fixed. # [1] https://github.com/AnswerDotAI/rerankers/issues/36 with contextlib.redirect_stdout(StringIO()): - from rerankers.models.flashrank_ranker import FlashRankRanker from rerankers.models.ranker import BaseRanker diff --git a/src/raglite/_flashrank.py b/src/raglite/_flashrank.py new file mode 100644 index 0000000..f69d655 --- /dev/null +++ b/src/raglite/_flashrank.py @@ -0,0 +1,41 @@ +"""Patched version of FlashRankRanker that fixes incorrect reranking [1]. + +[1] https://github.com/AnswerDotAI/rerankers/issues/39 +""" + +import contextlib +from io import StringIO +from typing import Any + +from flashrank import RerankRequest + +# Suppress rerankers output on import until [1] is fixed. +# [1] https://github.com/AnswerDotAI/rerankers/issues/36 +with contextlib.redirect_stdout(StringIO()): + from rerankers.documents import Document + from rerankers.models.flashrank_ranker import FlashRankRanker + from rerankers.results import RankedResults, Result + from rerankers.utils import prep_docs + + +class PatchedFlashRankRanker(FlashRankRanker): + def rank( + self, + query: str, + docs: str | list[str] | Document | list[Document], + doc_ids: list[str] | list[int] | None = None, + metadata: list[dict[str, Any]] | None = None, + ) -> RankedResults: + docs = prep_docs(docs, doc_ids, metadata) + passages = [{"id": doc_idx, "text": doc.text} for doc_idx, doc in enumerate(docs)] + rerank_request = RerankRequest(query=query, passages=passages) + flashrank_results = self.model.rerank(rerank_request) + ranked_results = [ + Result( + document=docs[result["id"]], # This patches the incorrect ranking in the original. + score=result["score"], + rank=idx + 1, + ) + for idx, result in enumerate(flashrank_results) + ] + return RankedResults(results=ranked_results, query=query, has_scores=True) diff --git a/tests/conftest.py b/tests/conftest.py index 96121b4..256bad8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,11 +2,14 @@ import os import socket +import tempfile +from collections.abc import Generator +from pathlib import Path import pytest from sqlalchemy import create_engine, text -from raglite import RAGLiteConfig +from raglite import RAGLiteConfig, insert_document POSTGRES_URL = "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres" @@ -26,7 +29,7 @@ def is_openai_available() -> bool: def pytest_sessionstart(session: pytest.Session) -> None: - """Reset the PostgreSQL database.""" + """Reset the PostgreSQL and SQLite databases.""" if is_postgres_running(): engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT") with engine.connect() as conn: @@ -35,9 +38,18 @@ def pytest_sessionstart(session: pytest.Session) -> None: conn.execute(text(f"CREATE DATABASE raglite_test_{variant}")) +@pytest.fixture(scope="session") +def sqlite_url() -> Generator[str, None, None]: + """Create a temporary SQLite database file and return the database URL.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_file = Path(temp_dir) / "raglite_test.sqlite" + yield f"sqlite:///{db_file}" + + @pytest.fixture( + scope="session", params=[ - pytest.param("sqlite:///:memory:", id="sqlite"), + pytest.param("sqlite", id="sqlite"), pytest.param( POSTGRES_URL, id="postgres", @@ -47,11 +59,14 @@ def pytest_sessionstart(session: pytest.Session) -> None: ) def database(request: pytest.FixtureRequest) -> str: """Get a database URL to test RAGLite with.""" - db_url: str = request.param + db_url: str = ( + request.getfixturevalue("sqlite_url") if request.param == "sqlite" else request.param + ) return db_url @pytest.fixture( + scope="session", params=[ pytest.param( "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf", @@ -70,13 +85,18 @@ def embedder(request: pytest.FixtureRequest) -> str: return embedder -@pytest.fixture +@pytest.fixture(scope="session") def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig: """Create a lightweight in-memory config for testing SQLite and PostgreSQL.""" - # Select the PostgreSQL database based on the embedder. + # Select the database based on the embedder. + variant = "local" if embedder.startswith("llama-cpp-python") else "remote" if "postgres" in database: - variant = "local" if embedder.startswith("llama-cpp-python") else "remote" database = database.replace("/postgres", f"/raglite_test_{variant}") + elif "sqlite" in database: + database = database.replace(".sqlite", f"_{variant}.sqlite") # Create a RAGLite config for the given database and embedder. db_config = RAGLiteConfig(db_url=database, embedder=embedder) + # Insert a document and update the index. + doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. + insert_document(doc_path, config=db_config) return db_config diff --git a/tests/test_rag.py b/tests/test_rag.py index b265e21..150a31b 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -1,12 +1,16 @@ """Test RAGLite's RAG functionality.""" import os -from pathlib import Path +from typing import TYPE_CHECKING import pytest from llama_cpp import llama_supports_gpu_offload -from raglite import RAGLiteConfig, hybrid_search, insert_document, rag, retrieve_segments +from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks + +if TYPE_CHECKING: + from raglite._database import Chunk + from raglite._typing import SearchMethod def is_accelerator_available() -> bool: @@ -14,34 +18,23 @@ def is_accelerator_available() -> bool: return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004 -def test_insert_index_search(raglite_test_config: RAGLiteConfig) -> None: - """Test inserting a document, updating the indexes, and searching for a query.""" - # Insert a document and update the index. - doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. - insert_document(doc_path, config=raglite_test_config) - # Search for a query. - query = "What does it mean for two events to be simultaneous?" - chunk_ids, scores = hybrid_search(query, config=raglite_test_config) - assert len(chunk_ids) == len(scores) - assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids) - assert all(isinstance(score, float) for score in scores) - # Group the chunks into segments and retrieve them. - segments = retrieve_segments(chunk_ids, neighbors=None, config=raglite_test_config) - assert all(isinstance(segment, str) for segment in segments) - assert "Definition of Simultaneity" in "".join(segments[:2]) - - @pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available") def test_rag(raglite_test_config: RAGLiteConfig) -> None: """Test Retrieval-Augmented Generation.""" - # Insert a document and update the index. - doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. - insert_document(doc_path, config=raglite_test_config) - # Answer a question with RAG. + # Assemble different types of search inputs for RAG. prompt = "What does it mean for two events to be simultaneous?" - stream = rag(prompt, search=hybrid_search, config=raglite_test_config) - answer = "" - for update in stream: - assert isinstance(update, str) - answer += update - assert "simultaneous" in answer.lower() + search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [ + hybrid_search, # A search method as input. + hybrid_search(prompt, config=raglite_test_config)[0], # Chunk ids as input. + retrieve_chunks( # Chunks as input. + hybrid_search(prompt, config=raglite_test_config)[0], config=raglite_test_config + ), + ] + # Answer a question with RAG. + for search_input in search_inputs: + stream = rag(prompt, search=search_input, config=raglite_test_config) + answer = "" + for update in stream: + assert isinstance(update, str) + answer += update + assert "simultaneous" in answer.lower() diff --git a/tests/test_rerank.py b/tests/test_rerank.py new file mode 100644 index 0000000..901f3b0 --- /dev/null +++ b/tests/test_rerank.py @@ -0,0 +1,54 @@ +"""Test RAGLite's reranking functionality.""" + +import pytest +from rerankers.models.ranker import BaseRanker + +from raglite import RAGLiteConfig, hybrid_search, rerank, retrieve_chunks +from raglite._database import Chunk +from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker + + +@pytest.fixture( + params=[ + pytest.param(None, id="no_reranker"), + pytest.param(FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), id="flashrank_english"), + pytest.param( + ( + ("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)), + ("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)), + ), + id="flashrank_multilingual", + ), + ], +) +def reranker( + request: pytest.FixtureRequest, +) -> BaseRanker | tuple[tuple[str, BaseRanker], ...] | None: + """Get a reranker to test RAGLite with.""" + reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = request.param + return reranker + + +def test_reranker( + raglite_test_config: RAGLiteConfig, + reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None, +) -> None: + """Test inserting a document, updating the indexes, and searching for a query.""" + # Update the config with the reranker. + raglite_test_config = RAGLiteConfig( + db_url=raglite_test_config.db_url, embedder=raglite_test_config.embedder, reranker=reranker + ) + # Search for a query. + query = "What does it mean for two events to be simultaneous?" + chunk_ids, _ = hybrid_search(query, num_results=3, config=raglite_test_config) + # Retrieve the chunks. + chunks = retrieve_chunks(chunk_ids, config=raglite_test_config) + assert all(isinstance(chunk, Chunk) for chunk in chunks) + # Rerank the chunks given an inverted chunk order. + reranked_chunks = rerank(query, chunks[::-1], config=raglite_test_config) + if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder: + assert reranked_chunks[:3] == chunks[:3] + # Test that we can also rerank given the chunk_ids only. + reranked_chunks = rerank(query, chunk_ids[::-1], config=raglite_test_config) + if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder: + assert reranked_chunks[:3] == chunks[:3] diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 0000000..5249b5e --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,47 @@ +"""Test RAGLite's search functionality.""" + +import pytest + +from raglite import ( + RAGLiteConfig, + hybrid_search, + keyword_search, + retrieve_chunks, + retrieve_segments, + vector_search, +) +from raglite._database import Chunk +from raglite._typing import SearchMethod + + +@pytest.fixture( + params=[ + pytest.param(keyword_search, id="keyword_search"), + pytest.param(vector_search, id="vector_search"), + pytest.param(hybrid_search, id="hybrid_search"), + ], +) +def search_method( + request: pytest.FixtureRequest, +) -> SearchMethod: + """Get a search method to test RAGLite with.""" + search_method: SearchMethod = request.param + return search_method + + +def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None: + """Test searching for a query.""" + # Search for a query. + query = "What does it mean for two events to be simultaneous?" + num_results = 5 + chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config) + assert len(chunk_ids) == len(scores) == num_results + assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids) + assert all(isinstance(score, float) for score in scores) + # Retrieve the chunks. + chunks = retrieve_chunks(chunk_ids, config=raglite_test_config) + assert all(isinstance(chunk, Chunk) for chunk in chunks) + assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks) + # Extend the chunks with their neighbours and group them into contiguous segments. + segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config) + assert all(isinstance(segment, str) for segment in segments)