Skip to content

Commit

Permalink
fix: patch rerankers flashrank issue (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsorber authored Oct 7, 2024
1 parent 8066ebe commit e3f5893
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 39 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with Postgr

## Installing

First, begin by installing SpaCy's multilingual sentence model:
First, begin by installing spaCy's multilingual sentence model:

```sh
# Install SpaCy's xx_sent_ud_sm:
# Install spaCy's xx_sent_ud_sm:
pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.7.0/xx_sent_ud_sm-3.7.0-py3-none-any.whl
```

Expand Down
3 changes: 2 additions & 1 deletion src/raglite/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from llama_cpp import llama_supports_gpu_offload
from sqlalchemy.engine import URL

from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker

# Suppress rerankers output on import until [1] is fixed.
# [1] https://github.com/AnswerDotAI/rerankers/issues/36
with contextlib.redirect_stdout(StringIO()):
from rerankers.models.flashrank_ranker import FlashRankRanker
from rerankers.models.ranker import BaseRanker


Expand Down
41 changes: 41 additions & 0 deletions src/raglite/_flashrank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Patched version of FlashRankRanker that fixes incorrect reranking [1].
[1] https://github.com/AnswerDotAI/rerankers/issues/39
"""

import contextlib
from io import StringIO
from typing import Any

from flashrank import RerankRequest

# Suppress rerankers output on import until [1] is fixed.
# [1] https://github.com/AnswerDotAI/rerankers/issues/36
with contextlib.redirect_stdout(StringIO()):
from rerankers.documents import Document
from rerankers.models.flashrank_ranker import FlashRankRanker
from rerankers.results import RankedResults, Result
from rerankers.utils import prep_docs


class PatchedFlashRankRanker(FlashRankRanker):
    """FlashRankRanker with a ``rank`` that looks documents up by passage id.

    Each passage sent to FlashRank is tagged with its index in the prepared
    document list, and every returned result is mapped back to its document
    through that index instead of through its position in the output [1].

    [1] https://github.com/AnswerDotAI/rerankers/issues/39
    """

    def rank(
        self,
        query: str,
        docs: str | list[str] | Document | list[Document],
        doc_ids: list[str] | list[int] | None = None,
        metadata: list[dict[str, Any]] | None = None,
    ) -> RankedResults:
        """Rerank ``docs`` against ``query`` with the underlying FlashRank model.

        Args:
            query: The query to rank the documents against.
            docs: One or more documents, as raw strings or ``Document`` objects.
            doc_ids: Optional document ids, forwarded to ``prep_docs``.
            metadata: Optional per-document metadata, forwarded to ``prep_docs``.

        Returns:
            A ``RankedResults`` with one ``Result`` per entry returned by
            FlashRank, ranked 1, 2, 3, ... in the order FlashRank returned them.
        """
        prepared_docs = prep_docs(docs, doc_ids, metadata)
        # The passage id is the document's index in `prepared_docs`, so that a
        # result can be traced back to its document after reranking.
        passages = [
            {"id": position, "text": prepared_doc.text}
            for position, prepared_doc in enumerate(prepared_docs)
        ]
        request = RerankRequest(query=query, passages=passages)
        ranked_results = []
        for rank, flashrank_result in enumerate(self.model.rerank(request), start=1):
            ranked_results.append(
                Result(
                    # Index by passage id rather than by output position: this
                    # is what patches the incorrect ranking in the original.
                    document=prepared_docs[flashrank_result["id"]],
                    score=flashrank_result["score"],
                    rank=rank,
                )
            )
        return RankedResults(results=ranked_results, query=query, has_scores=True)
34 changes: 27 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import os
import socket
import tempfile
from collections.abc import Generator
from pathlib import Path

import pytest
from sqlalchemy import create_engine, text

from raglite import RAGLiteConfig
from raglite import RAGLiteConfig, insert_document

POSTGRES_URL = "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres"

Expand All @@ -26,7 +29,7 @@ def is_openai_available() -> bool:


def pytest_sessionstart(session: pytest.Session) -> None:
"""Reset the PostgreSQL database."""
"""Reset the PostgreSQL and SQLite databases."""
if is_postgres_running():
engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT")
with engine.connect() as conn:
Expand All @@ -35,9 +38,18 @@ def pytest_sessionstart(session: pytest.Session) -> None:
conn.execute(text(f"CREATE DATABASE raglite_test_{variant}"))


@pytest.fixture(scope="session")
def sqlite_url() -> Generator[str, None, None]:
    """Yield the URL of a SQLite database file inside a temporary directory.

    The temporary directory is removed again when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        database_path = Path(temp_dir).joinpath("raglite_test.sqlite")
        yield f"sqlite:///{database_path}"


@pytest.fixture(
scope="session",
params=[
pytest.param("sqlite:///:memory:", id="sqlite"),
pytest.param("sqlite", id="sqlite"),
pytest.param(
POSTGRES_URL,
id="postgres",
Expand All @@ -47,11 +59,14 @@ def pytest_sessionstart(session: pytest.Session) -> None:
)
def database(request: pytest.FixtureRequest) -> str:
"""Get a database URL to test RAGLite with."""
db_url: str = request.param
db_url: str = (
request.getfixturevalue("sqlite_url") if request.param == "sqlite" else request.param
)
return db_url


@pytest.fixture(
scope="session",
params=[
pytest.param(
"llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf",
Expand All @@ -70,13 +85,18 @@ def embedder(request: pytest.FixtureRequest) -> str:
return embedder


@pytest.fixture
@pytest.fixture(scope="session")
def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig:
"""Create a lightweight config for testing SQLite and PostgreSQL, with the test document already inserted."""
# Select the PostgreSQL database based on the embedder.
# Select the database based on the embedder.
variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
if "postgres" in database:
variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
database = database.replace("/postgres", f"/raglite_test_{variant}")
elif "sqlite" in database:
database = database.replace(".sqlite", f"_{variant}.sqlite")
# Create a RAGLite config for the given database and embedder.
db_config = RAGLiteConfig(db_url=database, embedder=embedder)
# Insert a document and update the index.
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
insert_document(doc_path, config=db_config)
return db_config
51 changes: 22 additions & 29 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,40 @@
"""Test RAGLite's RAG functionality."""

import os
from pathlib import Path
from typing import TYPE_CHECKING

import pytest
from llama_cpp import llama_supports_gpu_offload

from raglite import RAGLiteConfig, hybrid_search, insert_document, rag, retrieve_segments
from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks

if TYPE_CHECKING:
from raglite._database import Chunk
from raglite._typing import SearchMethod


def is_accelerator_available() -> bool:
"""Check if an accelerator is available."""
return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004


def test_insert_index_search(raglite_test_config: RAGLiteConfig) -> None:
"""Test inserting a document, updating the indexes, and searching for a query."""
# Insert a document and update the index.
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
insert_document(doc_path, config=raglite_test_config)
# Search for a query.
query = "What does it mean for two events to be simultaneous?"
chunk_ids, scores = hybrid_search(query, config=raglite_test_config)
assert len(chunk_ids) == len(scores)
assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
assert all(isinstance(score, float) for score in scores)
# Group the chunks into segments and retrieve them.
segments = retrieve_segments(chunk_ids, neighbors=None, config=raglite_test_config)
assert all(isinstance(segment, str) for segment in segments)
assert "Definition of Simultaneity" in "".join(segments[:2])


@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available")
def test_rag(raglite_test_config: RAGLiteConfig) -> None:
"""Test Retrieval-Augmented Generation."""
# Insert a document and update the index.
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
insert_document(doc_path, config=raglite_test_config)
# Answer a question with RAG.
# Assemble different types of search inputs for RAG.
prompt = "What does it mean for two events to be simultaneous?"
stream = rag(prompt, search=hybrid_search, config=raglite_test_config)
answer = ""
for update in stream:
assert isinstance(update, str)
answer += update
assert "simultaneous" in answer.lower()
search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [
hybrid_search, # A search method as input.
hybrid_search(prompt, config=raglite_test_config)[0], # Chunk ids as input.
retrieve_chunks( # Chunks as input.
hybrid_search(prompt, config=raglite_test_config)[0], config=raglite_test_config
),
]
# Answer a question with RAG.
for search_input in search_inputs:
stream = rag(prompt, search=search_input, config=raglite_test_config)
answer = ""
for update in stream:
assert isinstance(update, str)
answer += update
assert "simultaneous" in answer.lower()
54 changes: 54 additions & 0 deletions tests/test_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Test RAGLite's reranking functionality."""

import pytest
from rerankers.models.ranker import BaseRanker

from raglite import RAGLiteConfig, hybrid_search, rerank, retrieve_chunks
from raglite._database import Chunk
from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker


@pytest.fixture(
    params=[
        pytest.param(None, id="no_reranker"),
        pytest.param(FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), id="flashrank_english"),
        pytest.param(
            (
                ("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)),
                ("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)),
            ),
            id="flashrank_multilingual",
        ),
    ],
)
def reranker(
    request: pytest.FixtureRequest,
) -> BaseRanker | tuple[tuple[str, BaseRanker], ...] | None:
    """Provide each reranker setup under test.

    The parameters are: no reranker at all, a single FlashRank reranker, and a
    tuple of ``(language, reranker)`` pairs.
    """
    selected: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = request.param
    return selected


def test_reranker(
    raglite_test_config: RAGLiteConfig,
    reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None,
) -> None:
    """Test that reranking reversed search results restores the search order."""
    # Build a config with the same database and embedder, but with the reranker under test.
    config = RAGLiteConfig(
        db_url=raglite_test_config.db_url,
        embedder=raglite_test_config.embedder,
        reranker=reranker,
    )
    # Search for a query and retrieve the matching chunks.
    query = "What does it mean for two events to be simultaneous?"
    chunk_ids, _ = hybrid_search(query, num_results=3, config=config)
    chunks = retrieve_chunks(chunk_ids, config=config)
    for chunk in chunks:
        assert isinstance(chunk, Chunk)
    # Rerank an inverted order, first given as chunks and then as chunk ids only.
    for make_reversed_input in (lambda: chunks[::-1], lambda: chunk_ids[::-1]):
        reranked_chunks = rerank(query, make_reversed_input(), config=config)
        # The order is only checked when a reranker is configured and the embedder is not
        # text-embedding-3-small.
        # NOTE(review): presumably that embedder's search order need not agree with the
        # reranker's order; confirm.
        if reranker is not None and "text-embedding-3-small" not in config.embedder:
            assert reranked_chunks[:3] == chunks[:3]
47 changes: 47 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Test RAGLite's search functionality."""

import pytest

from raglite import (
RAGLiteConfig,
hybrid_search,
keyword_search,
retrieve_chunks,
retrieve_segments,
vector_search,
)
from raglite._database import Chunk
from raglite._typing import SearchMethod


@pytest.fixture(
    params=[
        pytest.param(keyword_search, id="keyword_search"),
        pytest.param(vector_search, id="vector_search"),
        pytest.param(hybrid_search, id="hybrid_search"),
    ],
)
def search_method(request: pytest.FixtureRequest) -> SearchMethod:
    """Provide each of RAGLite's search functions in turn: keyword, vector, and hybrid."""
    selected_method: SearchMethod = request.param
    return selected_method


def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None:
    """Test that a search method returns scored chunk ids that resolve to chunks and segments."""
    # Run the search method on a query.
    query = "What does it mean for two events to be simultaneous?"
    num_results = 5
    chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
    # Exactly `num_results` ids must come back, each paired with a score.
    assert len(chunk_ids) == num_results
    assert len(scores) == num_results
    for chunk_id in chunk_ids:
        assert isinstance(chunk_id, str)
    for score in scores:
        assert isinstance(score, float)
    # Resolve the ids to chunks; at least one must mention the expected heading.
    chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
    for chunk in chunks:
        assert isinstance(chunk, Chunk)
    chunk_texts = [str(chunk) for chunk in chunks]
    assert any("Definition of Simultaneity" in chunk_text for chunk_text in chunk_texts)
    # Extend the chunks with their neighbours and group them into contiguous segments.
    segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
    for segment in segments:
        assert isinstance(segment, str)

0 comments on commit e3f5893

Please sign in to comment.