diff --git a/README.md b/README.md
index 2cf9e4c..4b1be88 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,8 @@ RAGLite is a Python toolkit for Retrieval-Augmented Generation (RAG) with Postgr
- 🧬 Multi-vector chunk embedding with [late chunking](https://weaviate.io/blog/late-chunking) and [contextual chunk headings](https://d-star.ai/solving-the-out-of-context-chunk-problem-for-rag)
- ✂️ Optimal [level 4 semantic chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) by solving a [binary integer programming problem](https://en.wikipedia.org/wiki/Integer_programming)
- 🔍 [Hybrid search](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) with the database's native keyword & vector search ([tsvector](https://www.postgresql.org/docs/current/datatype-textsearch.html)+[pgvector](https://github.com/pgvector/pgvector), [FTS5](https://www.sqlite.org/fts5.html)+[sqlite-vec](https://github.com/asg017/sqlite-vec)[^1])
+- 💰 Improved cost and latency with a [prompt caching-aware message array structure](https://platform.openai.com/docs/guides/prompt-caching)
+- 🍰 Improved output quality with [Anthropic's long-context prompt format](https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips)
- 🌀 Optimal [closed-form linear query adapter](src/raglite/_query_adapter.py) by solving an [orthogonal Procrustes problem](https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem)
##### Extensible
@@ -157,38 +159,85 @@ insert_document(Path("Special Relativity.pdf"), config=my_config)
### 3. Searching and Retrieval-Augmented Generation (RAG)
-Now, you can search for chunks with vector search, keyword search, or a hybrid of the two. You can also rerank the search results with the configured reranker. And you can use any search method of your choice (`hybrid_search` is the default) together with reranking to answer questions with RAG:
+#### 3.1 Simple RAG pipeline
+
+Now you can run a simple but powerful RAG pipeline that consists of three steps: retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:
+
+```python
+from raglite import create_rag_instruction, rag, retrieve_rag_context
+
+# Retrieve relevant chunk spans with hybrid search and reranking:
+user_prompt = "How is intelligence measured?"
+chunk_spans = retrieve_rag_context(query=user_prompt, num_chunks=5, config=my_config)
+
+# Append a RAG instruction based on the user prompt and context to the message history:
+messages = [] # Or start with an existing message history.
+messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
+
+# Stream the RAG response:
+stream = rag(messages, config=my_config)
+for update in stream:
+ print(update, end="")
+
+# Access the documents cited in the RAG response:
+documents = [chunk_span.document for chunk_span in chunk_spans]
+```
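+
+The `rag` function streams the response one token at a time. If you don't need streaming, the tokens can simply be joined into a single string:
+
+```python
+# Generate the full RAG response at once instead of printing it token by token:
+answer = "".join(rag(messages, config=my_config))
+```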
+
+#### 3.2 Advanced RAG pipeline
+
+> [!TIP]
+> 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks.
+
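+For example, you can enable reranking by configuring a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker. A minimal sketch, where the database URL and choice of model are illustrative:
+
+```python
+from rerankers import Reranker
+
+from raglite import RAGLiteConfig
+
+my_config = RAGLiteConfig(
+    db_url="sqlite:///raglite.sqlite",  # Illustrative database URL.
+    reranker=Reranker("ms-marco-MiniLM-L-12-v2", model_type="flashrank"),
+)
+```
+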
+In addition to the simple RAG pipeline, RAGLite also offers more advanced control over the individual steps. A full pipeline consists of the following steps:
+
+1. Searching for relevant chunks with keyword, vector, or hybrid search
+2. Retrieving the chunks from the database
+3. Reranking the chunks and selecting the top 5 results
+4. Extending the chunks with their neighbors and grouping them into chunk spans
+5. Converting the user prompt to a RAG instruction and appending it to the message history
+6. Streaming an LLM response to the message history
+7. Accessing the cited documents from the chunk spans
+
```python
# Search for chunks:
from raglite import hybrid_search, keyword_search, vector_search
-prompt = "How is intelligence measured?"
-chunk_ids_vector, _ = vector_search(prompt, num_results=20, config=my_config)
-chunk_ids_keyword, _ = keyword_search(prompt, num_results=20, config=my_config)
-chunk_ids_hybrid, _ = hybrid_search(prompt, num_results=20, config=my_config)
+user_prompt = "How is intelligence measured?"
+chunk_ids_vector, _ = vector_search(user_prompt, num_results=20, config=my_config)
+chunk_ids_keyword, _ = keyword_search(user_prompt, num_results=20, config=my_config)
+chunk_ids_hybrid, _ = hybrid_search(user_prompt, num_results=20, config=my_config)
# Retrieve chunks:
from raglite import retrieve_chunks
chunks_hybrid = retrieve_chunks(chunk_ids_hybrid, config=my_config)
-# Rerank chunks:
+# Rerank chunks and keep the top 5 (optional, but recommended):
from raglite import rerank_chunks
-chunks_reranked = rerank_chunks(prompt, chunks_hybrid, config=my_config)
+chunks_reranked = rerank_chunks(user_prompt, chunks_hybrid, config=my_config)
+chunks_reranked = chunks_reranked[:5]
+
+# Extend chunks with their neighbors and group them into chunk spans:
+from raglite import retrieve_chunk_spans
+
+chunk_spans = retrieve_chunk_spans(chunks_reranked, config=my_config)
+
+# Append a RAG instruction based on the user prompt and context to the message history:
+from raglite import create_rag_instruction
+
+messages = [] # Or start with an existing message history.
+messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
-# Answer questions with RAG:
+# Stream the RAG response:
from raglite import rag
-prompt = "What does it mean for two events to be simultaneous?"
-stream = rag(prompt, config=my_config)
+stream = rag(messages, config=my_config)
for update in stream:
print(update, end="")
-# You can also pass a search method or search results directly:
-stream = rag(prompt, search=hybrid_search, config=my_config)
-stream = rag(prompt, search=chunks_reranked, config=my_config)
+# Access the documents cited in the RAG response:
+documents = [chunk_span.document for chunk_span in chunk_spans]
```
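+
+Because `rag` operates on a plain message history, multi-turn RAG amounts to appending messages. A minimal sketch, where the follow-up question is illustrative:
+
+```python
+from raglite import retrieve_rag_context
+
+# Collect the RAG response and append it to the message history:
+answer = "".join(rag(messages, config=my_config))
+messages.append({"role": "assistant", "content": answer})
+
+# Answer a follow-up question with freshly retrieved context:
+followup_prompt = "Can intelligence be increased?"
+chunk_spans = retrieve_rag_context(query=followup_prompt, num_chunks=5, config=my_config)
+messages.append(create_rag_instruction(user_prompt=followup_prompt, context=chunk_spans))
+stream = rag(messages, config=my_config)
+```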
### 4. Computing and using an optimal query adapter
@@ -200,7 +249,7 @@ RAGLite can compute and apply an [optimal closed-form query adapter](src/raglite
from raglite import insert_evals, update_query_adapter
insert_evals(num_evals=100, config=my_config)
-update_query_adapter(config=my_config) # From here, simply call vector_search to use the query adapter.
+update_query_adapter(config=my_config) # From here, every vector search will use the query adapter.
```
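+
+From then on, the query adapter is applied transparently whenever you search. For example:
+
+```python
+from raglite import vector_search
+
+# Vector search (and the hybrid search built on top of it) now uses the query adapter:
+chunk_ids, scores = vector_search("How is intelligence measured?", num_results=10, config=my_config)
+```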
### 5. Evaluation of retrieval and generation
diff --git a/src/raglite/__init__.py b/src/raglite/__init__.py
index c1c9dc9..8ef7a26 100644
--- a/src/raglite/__init__.py
+++ b/src/raglite/__init__.py
@@ -5,13 +5,13 @@
from raglite._eval import answer_evals, evaluate, insert_evals
from raglite._insert import insert_document
from raglite._query_adapter import update_query_adapter
-from raglite._rag import async_rag, rag
+from raglite._rag import async_rag, create_rag_instruction, rag, retrieve_rag_context
from raglite._search import (
hybrid_search,
keyword_search,
rerank_chunks,
+ retrieve_chunk_spans,
retrieve_chunks,
- retrieve_segments,
vector_search,
)
@@ -25,9 +25,11 @@
"keyword_search",
"vector_search",
"retrieve_chunks",
- "retrieve_segments",
+ "retrieve_chunk_spans",
"rerank_chunks",
# RAG
+ "retrieve_rag_context",
+ "create_rag_instruction",
"async_rag",
"rag",
# Query adapter
diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py
index 9499baf..1f3eeeb 100644
--- a/src/raglite/_chainlit.py
+++ b/src/raglite/_chainlit.py
@@ -9,9 +9,11 @@
from raglite import (
RAGLiteConfig,
async_rag,
+ create_rag_instruction,
hybrid_search,
insert_document,
rerank_chunks,
+ retrieve_chunk_spans,
retrieve_chunks,
)
from raglite._markdown import document_to_markdown
@@ -19,6 +21,7 @@
async_insert_document = cl.make_async(insert_document)
async_hybrid_search = cl.make_async(hybrid_search)
async_retrieve_chunks = cl.make_async(retrieve_chunks)
+async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans)
async_rerank_chunks = cl.make_async(rerank_chunks)
@@ -84,9 +87,12 @@ async def handle_message(user_message: cl.Message) -> None:
step.input = Path(file.path).name
await async_insert_document(Path(file.path), config=config)
# Append any inline attachments to the user prompt.
- user_prompt = f"{user_message.content}\n\n" + "\n\n".join(
-        f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
- for i, attachment in enumerate(inline_attachments)
+ user_prompt = (
+ "\n\n".join(
+            f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
+ for i, attachment in enumerate(inline_attachments)
+ )
+ + f"\n\n{user_message.content}"
)
# Search for relevant contexts for RAG.
async with cl.Step(name="search", type="retrieval") as step:
@@ -94,24 +100,24 @@ async def handle_message(user_message: cl.Message) -> None:
chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
step.output = chunks
- step.elements = [ # Show the top 3 chunks inline.
- cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
+ step.elements = [ # Show the top chunks inline.
+ cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
]
- # Rerank the chunks.
+ await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
+ # Rerank the chunks and group them into chunk spans.
async with cl.Step(name="rerank", type="rerank") as step:
step.input = chunks
chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
- step.output = chunks
- step.elements = [ # Show the top 3 chunks inline.
- cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
+ chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config)
+ step.output = chunk_spans
+ step.elements = [ # Show the top chunk spans inline.
+ cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
]
+ await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
# Stream the LLM response.
assistant_message = cl.Message(content="")
- async for token in async_rag(
- prompt=user_prompt,
- search=chunks,
- messages=cl.chat_context.to_openai()[-5:], # type: ignore[no-untyped-call]
- config=config,
- ):
+ messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call]
+ messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
+ async for token in async_rag(messages, config=config):
await assistant_message.stream_token(token)
await assistant_message.update() # type: ignore[no-untyped-call]
diff --git a/src/raglite/_database.py b/src/raglite/_database.py
index 36a7fe4..510a3bb 100644
--- a/src/raglite/_database.py
+++ b/src/raglite/_database.py
@@ -2,29 +2,31 @@
import datetime
import json
+from dataclasses import dataclass, field
from functools import lru_cache
from hashlib import sha256
from pathlib import Path
from typing import Any
+from xml.sax.saxutils import escape
import numpy as np
from markdown_it import MarkdownIt
from pydantic import ConfigDict
from sqlalchemy.engine import Engine, make_url
-from sqlmodel import (
- JSON,
- Column,
- Field,
- Relationship,
- Session,
- SQLModel,
- create_engine,
- text,
-)
+from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, create_engine, text
from raglite._config import RAGLiteConfig
from raglite._litellm import get_embedding_dim
-from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject
+from raglite._typing import (
+ ChunkId,
+ DocumentId,
+ Embedding,
+ EvalId,
+ FloatMatrix,
+ FloatVector,
+ IndexId,
+ PickledObject,
+)
def hash_bytes(data: bytes, max_len: int = 16) -> str:
@@ -39,7 +41,7 @@ class Document(SQLModel, table=True):
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
# Table columns.
- id: str = Field(..., primary_key=True)
+ id: DocumentId = Field(..., primary_key=True)
filename: str
url: str | None = Field(default=None)
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
@@ -70,8 +72,8 @@ class Chunk(SQLModel, table=True):
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
# Table columns.
- id: str = Field(..., primary_key=True)
- document_id: str = Field(..., foreign_key="document.id", index=True)
+ id: ChunkId = Field(..., primary_key=True)
+ document_id: DocumentId = Field(..., foreign_key="document.id", index=True)
index: int = Field(..., index=True)
headings: str
body: str
@@ -83,11 +85,7 @@ class Chunk(SQLModel, table=True):
@staticmethod
def from_body(
- document_id: str,
- index: int,
- body: str,
- headings: str = "",
- **kwargs: Any,
+ document_id: DocumentId, index: int, body: str, headings: str = "", **kwargs: Any
) -> "Chunk":
"""Create a chunk from Markdown."""
return Chunk(
@@ -139,10 +137,62 @@ def __repr__(self) -> str:
indent=4,
)
- def __str__(self) -> str:
- """Context representation of this chunk."""
+ @property
+ def content(self) -> str:
+ """Return this chunk's contextual heading and body."""
return f"{self.headings.strip()}\n\n{self.body.strip()}".strip()
+ def __str__(self) -> str:
+ """Return this chunk's content."""
+ return self.content
+
+
+@dataclass
+class ChunkSpan:
+ """A consecutive sequence of chunks from a single document."""
+
+ chunks: list[Chunk]
+ document: Document = field(init=False)
+
+ def __post_init__(self) -> None:
+ """Set the document field."""
+ if self.chunks:
+ self.document = self.chunks[0].document
+
+ def to_xml(self, index: int | None = None) -> str:
+ """Convert this chunk span to an XML representation.
+
+ The XML representation follows Anthropic's best practices [1].
+
+ [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
+ """
+ if not self.chunks:
+ return ""
+ index_attribute = f' index="{index}"' if index is not None else ""
+ xml = "\n".join(
+ [
+                f'<document{index_attribute} id="{self.document.id}">',
+                f"<source>{self.document.url if self.document.url else self.document.filename}</source>",
+                f'<span from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[-1].id}">',
+                f"<headings>\n{escape(self.chunks[0].headings.strip())}\n</headings>",
+                f"<content>\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n</content>",
+                "</span>",
+                "</document>",
+ ]
+ )
+ return xml
+
+ @property
+ def content(self) -> str:
+ """Return this chunk span's contextual heading and chunk bodies."""
+ heading = self.chunks[0].headings.strip() if self.chunks else ""
+ bodies = "".join(chunk.body for chunk in self.chunks)
+ return f"{heading}\n\n{bodies}".strip()
+
+ def __str__(self) -> str:
+ """Return this chunk span's content."""
+ return self.content
+
class ChunkEmbedding(SQLModel, table=True):
"""A (sub-)chunk embedding."""
@@ -154,7 +204,7 @@ class ChunkEmbedding(SQLModel, table=True):
# Table columns.
id: int = Field(..., primary_key=True)
- chunk_id: str = Field(..., foreign_key="chunk.id", index=True)
+ chunk_id: ChunkId = Field(..., foreign_key="chunk.id", index=True)
embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1)))
# Add relationship so we can access embedding.chunk.
@@ -175,7 +225,7 @@ class IndexMetadata(SQLModel, table=True):
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
# Table columns.
- id: str = Field(..., primary_key=True)
+ id: IndexId = Field(..., primary_key=True)
version: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
)
@@ -208,9 +258,9 @@ class Eval(SQLModel, table=True):
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
# Table columns.
- id: str = Field(..., primary_key=True)
- document_id: str = Field(..., foreign_key="document.id", index=True)
- chunk_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON))
+ id: EvalId = Field(..., primary_key=True)
+ document_id: DocumentId = Field(..., foreign_key="document.id", index=True)
+ chunk_ids: list[ChunkId] = Field(default_factory=list, sa_column=Column(JSON))
question: str
contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON))
ground_truth: str
@@ -221,10 +271,7 @@ class Eval(SQLModel, table=True):
@staticmethod
def from_chunks(
- question: str,
- contexts: list[Chunk],
- ground_truth: str,
- **kwargs: Any,
+ question: str, contexts: list[Chunk], ground_truth: str, **kwargs: Any
) -> "Eval":
"""Create a chunk from Markdown."""
document_id = contexts[0].document_id
@@ -284,18 +331,22 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
with Session(engine) as session:
metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"}
session.execute(
- text("""
+ text(
+ """
CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body));
- """)
+ """
+ )
)
session.execute(
- text(f"""
+ text(
+ f"""
CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding
USING hnsw (
(embedding::halfvec({embedding_dim}))
halfvec_{metrics[config.vector_search_index_metric]}_ops
);
- """)
+ """
+ )
)
session.commit()
elif db_backend == "sqlite":
@@ -304,31 +355,39 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
# [1] https://www.sqlite.org/fts5.html#external_content_tables
with Session(engine) as session:
session.execute(
- text("""
+ text(
+ """
CREATE VIRTUAL TABLE IF NOT EXISTS keyword_search_chunk_index USING fts5(body, content='chunk', content_rowid='rowid');
- """)
+ """
+ )
)
session.execute(
- text("""
+ text(
+ """
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
END;
- """)
+ """
+ )
)
session.execute(
- text("""
+ text(
+ """
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
END;
- """)
+ """
+ )
)
session.execute(
- text("""
+ text(
+ """
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
END;
- """)
+ """
+ )
)
session.commit()
return engine
diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py
index bfce7c4..f26789c 100644
--- a/src/raglite/_eval.py
+++ b/src/raglite/_eval.py
@@ -12,8 +12,8 @@
from raglite._config import RAGLiteConfig
from raglite._database import Chunk, Document, Eval, create_database_engine
from raglite._extract import extract_with_llm
-from raglite._rag import rag
-from raglite._search import hybrid_search, retrieve_segments, vector_search
+from raglite._rag import create_rag_instruction, rag, retrieve_rag_context
+from raglite._search import hybrid_search, retrieve_chunk_spans, vector_search
from raglite._typing import SearchMethod
@@ -74,11 +74,14 @@ def validate_question(cls, value: str) -> str:
continue
# Expand the seed chunk into a set of related chunks.
related_chunk_ids, _ = vector_search(
- np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
+ query=np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
num_results=randint(2, max_contexts_per_eval // 2), # noqa: S311
config=config,
)
- related_chunks = retrieve_segments(related_chunk_ids, config=config)
+ related_chunks = [
+            str(chunk_span)
+            for chunk_span in retrieve_chunk_spans(related_chunk_ids, config=config)
+ ]
# Extract a question from the seed chunk's related chunks.
try:
question_response = extract_with_llm(
@@ -90,7 +93,7 @@ def validate_question(cls, value: str) -> str:
question = question_response.question
# Search for candidate chunks to answer the generated question.
candidate_chunk_ids, _ = hybrid_search(
- question, num_results=max_contexts_per_eval, config=config
+ query=question, num_results=max_contexts_per_eval, config=config
)
candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]
@@ -157,9 +160,7 @@ class AnswerResponse(BaseModel):
answer = answer_response.answer
# Store the eval in the database.
eval_ = Eval.from_chunks(
- question=question,
- contexts=relevant_chunks,
- ground_truth=answer,
+ question=question, contexts=relevant_chunks, ground_truth=answer
)
session.add(eval_)
session.commit()
@@ -181,11 +182,15 @@ def answer_evals(
answers: list[str] = []
contexts: list[list[str]] = []
for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
- response = rag(eval_.question, search=search, config=config)
+ chunk_spans = retrieve_rag_context(query=eval_.question, search=search, config=config)
+ messages = [create_rag_instruction(user_prompt=eval_.question, context=chunk_spans)]
+ response = rag(messages, config=config)
answer = "".join(response)
answers.append(answer)
- chunk_ids, _ = search(eval_.question, config=config)
- contexts.append(retrieve_segments(chunk_ids))
+ chunk_ids, _ = search(query=eval_.question, config=config)
+ contexts.append(
+ [str(chunk_span) for chunk_span in retrieve_chunk_spans(chunk_ids, config=config)]
+ )
# Collect the answered evals.
answered_evals: dict[str, list[str] | list[list[str]]] = {
"question": [eval_.question for eval_ in evals],
@@ -199,8 +204,7 @@ def answer_evals(
def evaluate(
- answered_evals: pd.DataFrame | int = 100,
- config: RAGLiteConfig | None = None,
+ answered_evals: pd.DataFrame | int = 100, config: RAGLiteConfig | None = None
) -> pd.DataFrame:
"""Evaluate the performance of a set of answered evals with Ragas."""
try:
diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py
index 95e46e3..f3d73ff 100644
--- a/src/raglite/_extract.py
+++ b/src/raglite/_extract.py
@@ -61,7 +61,7 @@ class MyNameResponse(BaseModel):
# Concatenate the user prompt if it is a list of strings.
if isinstance(user_prompt, list):
user_prompt = "\n\n".join(
-            f'<document index="{i}">\n{chunk.strip()}\n</document>'
+            f'<context index="{i}">\n{chunk.strip()}\n</context>'
for i, chunk in enumerate(user_prompt)
)
# Enable JSON schema validation.
diff --git a/src/raglite/_insert.py b/src/raglite/_insert.py
index 804d0b7..42061a6 100644
--- a/src/raglite/_insert.py
+++ b/src/raglite/_insert.py
@@ -13,11 +13,11 @@
from raglite._markdown import document_to_markdown
from raglite._split_chunks import split_chunks
from raglite._split_sentences import split_sentences
-from raglite._typing import FloatMatrix
+from raglite._typing import DocumentId, FloatMatrix
def _create_chunk_records(
- document_id: str,
+ document_id: DocumentId,
chunks: list[str],
chunk_embeddings: list[FloatMatrix],
config: RAGLiteConfig,
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index 81ffda2..8fb1a0c 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -2,159 +2,93 @@
from collections.abc import AsyncIterator, Iterator
+import numpy as np
from litellm import acompletion, completion
from raglite._config import RAGLiteConfig
-from raglite._database import Chunk
+from raglite._database import ChunkSpan
from raglite._litellm import get_context_size
-from raglite._search import hybrid_search, rerank_chunks, retrieve_segments
+from raglite._search import hybrid_search, rerank_chunks, retrieve_chunk_spans
from raglite._typing import SearchMethod
-RAG_SYSTEM_PROMPT = """
+# The default RAG instruction template follows Anthropic's best practices [1].
+# [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
+RAG_INSTRUCTION_TEMPLATE = """
You are a friendly and knowledgeable assistant that provides complete and insightful answers.
-Answer the user's question using only the context below.
+Whenever possible, use only the provided context to respond to the question at the end.
When responding, you MUST NOT reference the existence of the context, directly or indirectly.
Instead, you MUST treat the context as if its contents are entirely part of your working memory.
+
+<context>
+{context}
+</context>
+
+{user_prompt}
""".strip()
-def _max_contexts(
- prompt: str,
+def retrieve_rag_context(
+ query: str,
*,
- max_contexts: int = 5,
- context_neighbors: tuple[int, ...] | None = (-1, 1),
- messages: list[dict[str, str]] | None = None,
+ num_chunks: int = 5,
+ chunk_neighbors: tuple[int, ...] | None = (-1, 1),
+ search: SearchMethod = hybrid_search,
config: RAGLiteConfig | None = None,
-) -> int:
- """Determine the maximum number of contexts for RAG."""
- # Get the model's context size.
+) -> list[ChunkSpan]:
+ """Retrieve context for RAG."""
+ # If the user has configured a reranker, we retrieve extra contexts to rerank.
config = config or RAGLiteConfig()
- max_tokens = get_context_size(config)
- # Reduce the maximum number of contexts to take into account the LLM's context size.
- max_context_tokens = (
- max_tokens
- - sum(len(message["content"]) // 3 for message in messages or []) # Previous messages.
- - len(RAG_SYSTEM_PROMPT) // 3 # System prompt.
- - len(prompt) // 3 # User prompt.
- )
- max_tokens_per_context = config.chunk_max_size // 3
- max_tokens_per_context *= 1 + len(context_neighbors or [])
- max_contexts = min(max_contexts, max_context_tokens // max_tokens_per_context)
- if max_contexts <= 0:
- error_message = "Not enough context tokens available for RAG."
- raise ValueError(error_message)
- return max_contexts
-
-
-def _contexts( # noqa: PLR0913
- prompt: str,
- *,
- max_contexts: int = 5,
- context_neighbors: tuple[int, ...] | None = (-1, 1),
- search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
- messages: list[dict[str, str]] | None = None,
- config: RAGLiteConfig | None = None,
-) -> list[str]:
- """Retrieve contexts for RAG."""
- # Determine the maximum number of contexts.
- max_contexts = _max_contexts(
- prompt,
- max_contexts=max_contexts,
- context_neighbors=context_neighbors,
- messages=messages,
- config=config,
- )
- # Retrieve the top chunks.
- config = config or RAGLiteConfig()
- chunks: list[str] | list[Chunk]
- if callable(search):
- # If the user has configured a reranker, we retrieve extra contexts to rerank.
- extra_contexts = 3 * max_contexts if config.reranker else 0
- # Retrieve relevant contexts.
- chunk_ids, _ = search(prompt, num_results=max_contexts + extra_contexts, config=config)
- # Rerank the relevant contexts.
- chunks = rerank_chunks(query=prompt, chunk_ids=chunk_ids, config=config)
- else:
- # The user has passed a list of chunk_ids or chunks directly.
- chunks = search
+ extra_chunks = 3 * num_chunks if config.reranker else 0
+ # Search for relevant chunks.
+ chunk_ids, _ = search(query, num_results=num_chunks + extra_chunks, config=config)
+ # Rerank the chunks from most to least relevant.
+ chunks = rerank_chunks(query, chunk_ids=chunk_ids, config=config)
-    # Extend the top contexts with their neighbors and group chunks into contiguous segments.
+    # Extend the top chunks with their neighbors and group them into contiguous chunk spans.
- segments = retrieve_segments(chunks[:max_contexts], neighbors=context_neighbors, config=config)
- return segments
+ context = retrieve_chunk_spans(chunks[:num_chunks], neighbors=chunk_neighbors, config=config)
+ return context
-def rag( # noqa: PLR0913
- prompt: str,
+def create_rag_instruction(
+ user_prompt: str,
+ context: list[ChunkSpan],
*,
- max_contexts: int = 5,
- context_neighbors: tuple[int, ...] | None = (-1, 1),
- search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
- messages: list[dict[str, str]] | None = None,
- system_prompt: str = RAG_SYSTEM_PROMPT,
- config: RAGLiteConfig | None = None,
-) -> Iterator[str]:
- """Retrieval-augmented generation."""
- # Get the contexts for RAG as contiguous segments of chunks.
- config = config or RAGLiteConfig()
- segments = _contexts(
- prompt,
- max_contexts=max_contexts,
- context_neighbors=context_neighbors,
- search=search,
- config=config,
- )
- system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
-        f'<context index="{i}">\n{segment.strip()}\n</context>'
- for i, segment in enumerate(segments)
- )
+ rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE,
+) -> dict[str, str]:
+ """Convert a user prompt to a RAG instruction.
+
+ The RAG instruction's format follows Anthropic's best practices [1].
+
+ [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
+ """
+ message = {
+ "role": "user",
+ "content": rag_instruction_template.format(
+ user_prompt=user_prompt.strip(),
+ context="\n".join(
+ chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context)
+ ),
+ ),
+ }
+ return message
+
+
+def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[str]:
+    """Retrieval-augmented generation."""
+ # Truncate the oldest messages so we don't hit the context limit.
+ max_tokens = get_context_size(config)
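+    # A message's token count is estimated as len(content) // 3, a rough characters-per-token
+    # heuristic. The cumulative counts run from the newest message backwards, so searchsorted
+    # yields the number of most recent messages that fit within the model's context size.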
+ cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1])
+ messages = messages[-np.searchsorted(cum_tokens, max_tokens) :]
# Stream the LLM response.
- stream = completion(
- model=config.llm,
- messages=[
- *(messages or []),
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": prompt},
- ],
- stream=True,
- )
+ stream = completion(model=config.llm, messages=messages, stream=True)
for output in stream:
token: str = output["choices"][0]["delta"].get("content") or ""
yield token
-async def async_rag( # noqa: PLR0913
- prompt: str,
- *,
- max_contexts: int = 5,
- context_neighbors: tuple[int, ...] | None = (-1, 1),
- search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
- messages: list[dict[str, str]] | None = None,
- system_prompt: str = RAG_SYSTEM_PROMPT,
- config: RAGLiteConfig | None = None,
-) -> AsyncIterator[str]:
- """Retrieval-augmented generation."""
- # Get the contexts for RAG as contiguous segments of chunks.
- config = config or RAGLiteConfig()
- segments = _contexts(
- prompt,
- max_contexts=max_contexts,
- context_neighbors=context_neighbors,
- search=search,
- config=config,
- )
- system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
-        f'<context index="{i}">\n{segment.strip()}\n</context>'
- for i, segment in enumerate(segments)
- )
- # Stream the LLM response.
- async_stream = await acompletion(
- model=config.llm,
- messages=[
- *(messages or []),
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": prompt},
- ],
- stream=True,
- )
+async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> AsyncIterator[str]:
+    """Retrieval-augmented generation."""
+ # Truncate the oldest messages so we don't hit the context limit.
+ max_tokens = get_context_size(config)
+ cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1])
+ messages = messages[-np.searchsorted(cum_tokens, max_tokens) :]
+ # Asynchronously stream the LLM response.
+ async_stream = await acompletion(model=config.llm, messages=messages, stream=True)
async for output in async_stream:
token: str = output["choices"][0]["delta"].get("content") or ""
yield token
diff --git a/src/raglite/_search.py b/src/raglite/_search.py
index 30c7982..b7976cb 100644
--- a/src/raglite/_search.py
+++ b/src/raglite/_search.py
@@ -1,4 +1,4 @@
-"""Query documents."""
+"""Search and retrieve chunks."""
import re
import string
@@ -10,20 +10,24 @@
import numpy as np
from langdetect import detect
from sqlalchemy.engine import make_url
+from sqlalchemy.orm import joinedload
from sqlmodel import Session, and_, col, or_, select, text
from raglite._config import RAGLiteConfig
-from raglite._database import Chunk, ChunkEmbedding, IndexMetadata, create_database_engine
+from raglite._database import (
+ Chunk,
+ ChunkEmbedding,
+ ChunkSpan,
+ IndexMetadata,
+ create_database_engine,
+)
from raglite._embed import embed_sentences
-from raglite._typing import FloatMatrix
+from raglite._typing import ChunkId, FloatMatrix
def vector_search(
- query: str | FloatMatrix,
- *,
- num_results: int = 3,
- config: RAGLiteConfig | None = None,
-) -> tuple[list[str], list[float]]:
+ query: str | FloatMatrix, *, num_results: int = 3, config: RAGLiteConfig | None = None
+) -> tuple[list[ChunkId], list[float]]:
"""Search chunks using ANN vector search."""
# Read the config.
config = config or RAGLiteConfig()
@@ -90,7 +94,7 @@ def vector_search(
def keyword_search(
query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
-) -> tuple[list[str], list[float]]:
+) -> tuple[list[ChunkId], list[float]]:
"""Search chunks using BM25 keyword search."""
# Read the config.
config = config or RAGLiteConfig()
@@ -104,13 +108,15 @@ def keyword_search(
query_escaped = re.sub(r"[&|!():<>\"]", " ", query)
tsv_query = " | ".join(query_escaped.split())
# Perform keyword search with tsvector.
- statement = text("""
+ statement = text(
+ """
SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score
FROM chunk
WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query)
ORDER BY score DESC
LIMIT :limit;
- """)
+ """
+ )
results = session.execute(statement, params={"query": tsv_query, "limit": num_results})
elif db_backend == "sqlite":
# Convert the query to an FTS5 query [1].
@@ -120,13 +126,15 @@ def keyword_search(
# Perform keyword search with FTS5. In FTS5, BM25 scores are negative [1], so we
# negate them to make them positive.
# [1] https://www.sqlite.org/fts5.html#the_bm25_function
- statement = text("""
+ statement = text(
+ """
SELECT chunk.id as chunk_id, -bm25(keyword_search_chunk_index) as score
FROM chunk JOIN keyword_search_chunk_index ON chunk.rowid = keyword_search_chunk_index.rowid
WHERE keyword_search_chunk_index MATCH :match
ORDER BY score DESC
LIMIT :limit;
- """)
+ """
+ )
results = session.execute(statement, params={"match": fts5_query, "limit": num_results})
# Unpack the results.
results = list(results) # type: ignore[assignment]
@@ -136,8 +144,8 @@ def keyword_search(
def reciprocal_rank_fusion(
- rankings: list[list[str]], *, k: int = 60
-) -> tuple[list[str], list[float]]:
+ rankings: list[list[ChunkId]], *, k: int = 60
+) -> tuple[list[ChunkId], list[float]]:
"""Reciprocal Rank Fusion."""
# Compute the RRF score.
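+    # Each chunk's RRF score is the sum of 1 / (k + rank) over all rankings in which it appears.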
chunk_ids = {chunk_id for ranking in rankings for chunk_id in ranking}
@@ -155,7 +163,7 @@ def reciprocal_rank_fusion(
def hybrid_search(
query: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None
-) -> tuple[list[str], list[float]]:
+) -> tuple[list[ChunkId], list[float]]:
"""Search chunks by combining ANN vector search with BM25 keyword search."""
# Run both searches.
vs_chunk_ids, _ = vector_search(query, num_results=num_rerank, config=config)
@@ -167,84 +175,33 @@ def hybrid_search(
def retrieve_chunks(
- chunk_ids: list[str],
- *,
- config: RAGLiteConfig | None = None,
+ chunk_ids: list[ChunkId], *, config: RAGLiteConfig | None = None
) -> list[Chunk]:
"""Retrieve chunks by their ids."""
config = config or RAGLiteConfig()
engine = create_database_engine(config)
with Session(engine) as session:
- chunks = list(session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all())
+ chunks = list(
+ session.exec(
+ select(Chunk)
+ .where(col(Chunk.id).in_(chunk_ids))
+ # Eagerly load chunk.document.
+ .options(joinedload(Chunk.document)) # type: ignore[arg-type]
+ ).all()
+ )
chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id))
return chunks
-def retrieve_segments(
- chunk_ids: list[str] | list[Chunk],
- *,
- neighbors: tuple[int, ...] | None = (-1, 1),
- config: RAGLiteConfig | None = None,
-) -> list[str]:
- """Group chunks into contiguous segments and retrieve them."""
- # Retrieve the chunks.
- config = config or RAGLiteConfig()
- chunks: list[Chunk] = (
- retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
- if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
- else chunk_ids
- )
- # Extend the chunks with their neighbouring chunks.
- if neighbors:
- engine = create_database_engine(config)
- with Session(engine) as session:
- neighbor_conditions = [
- and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset)
- for chunk in chunks
- for offset in neighbors
- ]
- chunks += list(session.exec(select(Chunk).where(or_(*neighbor_conditions))).all())
- # Keep only the unique chunks.
- chunks = list(set(chunks))
- # Sort the chunks by document_id and index (needed for groupby).
- chunks = sorted(chunks, key=lambda chunk: (chunk.document_id, chunk.index))
- # Group the chunks into contiguous segments.
- segments: list[list[Chunk]] = []
- for _, group in groupby(chunks, key=lambda chunk: chunk.document_id):
- segment: list[Chunk] = []
- for chunk in group:
- if not segment or chunk.index == segment[-1].index + 1:
- segment.append(chunk)
- else:
- segments.append(segment)
- segment = [chunk]
- segments.append(segment)
- # Rank segments according to the aggregate relevance of their chunks.
- chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)}
- segments.sort(
- key=lambda segment: sum(chunk_id_to_score.get(chunk.id, 0.0) for chunk in segment),
- reverse=True,
- )
- # Convert the segments into strings.
- segments = [
- segment[0].headings.strip() + "\n\n" + "".join(chunk.body for chunk in segment).strip() # type: ignore[misc]
- for segment in segments
- ]
- return segments # type: ignore[return-value]
-
-
def rerank_chunks(
- query: str,
- chunk_ids: list[str] | list[Chunk],
- *,
- config: RAGLiteConfig | None = None,
+ query: str, chunk_ids: list[ChunkId] | list[Chunk], *, config: RAGLiteConfig | None = None
) -> list[Chunk]:
"""Rerank chunks according to their relevance to a given query."""
# Retrieve the chunks.
config = config or RAGLiteConfig()
chunks: list[Chunk] = (
retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
- if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
+ if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids)
else chunk_ids
)
# Early exit if no reranker is configured.
@@ -269,3 +226,63 @@ def rerank_chunks(
results = reranker.rank(query=query, docs=[str(chunk) for chunk in chunks])
chunks = [chunks[result.doc_id] for result in results.results]
return chunks
+
+
+def retrieve_chunk_spans(
+ chunk_ids: list[ChunkId] | list[Chunk],
+ *,
+ neighbors: tuple[int, ...] | None = (-1, 1),
+ config: RAGLiteConfig | None = None,
+) -> list[ChunkSpan]:
+ """Group chunks into contiguous chunk spans and retrieve them.
+
+ Chunk spans are ordered according to the aggregate relevance of their underlying chunks, as
+ determined by the order in which they are provided to this function.
+ """
+ # Retrieve the chunks.
+ config = config or RAGLiteConfig()
+ chunks: list[Chunk] = (
+ retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
+ if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids)
+ else chunk_ids
+ )
+ # Assign a reciprocal ranking score to each chunk based on its position in the original list.
+ chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)}
+ # Extend the chunks with their neighbouring chunks.
+ engine = create_database_engine(config)
+ with Session(engine) as session:
+ if neighbors:
+ neighbor_conditions = [
+ and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset)
+ for chunk in chunks
+ for offset in neighbors
+ ]
+ chunks += list(
+ session.exec(
+ select(Chunk)
+ .where(or_(*neighbor_conditions))
+ # Eagerly load chunk.document.
+ .options(joinedload(Chunk.document)) # type: ignore[arg-type]
+ ).all()
+ )
+ # Deduplicate and sort the chunks by document_id and index (needed for groupby).
+ unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index))
+ # Group the chunks into contiguous segments.
+ chunk_spans: list[ChunkSpan] = []
+ for _, group in groupby(unique_chunks, key=lambda chunk: chunk.document_id):
+ chunk_sequence: list[Chunk] = []
+ for chunk in group:
+ if not chunk_sequence or chunk.index == chunk_sequence[-1].index + 1:
+ chunk_sequence.append(chunk)
+ else:
+ chunk_spans.append(ChunkSpan(chunks=chunk_sequence))
+ chunk_sequence = [chunk]
+ chunk_spans.append(ChunkSpan(chunks=chunk_sequence))
+    # Rank chunk spans according to the aggregate relevance of their chunks.
+ chunk_spans.sort(
+ key=lambda chunk_span: sum(
+ chunk_id_to_score.get(chunk.id, 0.0) for chunk in chunk_span.chunks
+ ),
+ reverse=True,
+ )
+ return chunk_spans
diff --git a/src/raglite/_typing.py b/src/raglite/_typing.py
index 07a6904..9846ecc 100644
--- a/src/raglite/_typing.py
+++ b/src/raglite/_typing.py
@@ -12,6 +12,11 @@
from raglite._config import RAGLiteConfig
+ChunkId = str
+DocumentId = str
+EvalId = str
+IndexId = str
+
FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]
diff --git a/tests/test_rag.py b/tests/test_rag.py
index 150a31b..7643bcf 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -1,16 +1,16 @@
"""Test RAGLite's RAG functionality."""
import os
-from typing import TYPE_CHECKING
import pytest
from llama_cpp import llama_supports_gpu_offload
-from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks
-
-if TYPE_CHECKING:
- from raglite._database import Chunk
- from raglite._typing import SearchMethod
+from raglite import (
+ RAGLiteConfig,
+ create_rag_instruction,
+ retrieve_rag_context,
+)
+from raglite._rag import rag
def is_accelerator_available() -> bool:
@@ -21,20 +21,13 @@ def is_accelerator_available() -> bool:
@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available")
def test_rag(raglite_test_config: RAGLiteConfig) -> None:
"""Test Retrieval-Augmented Generation."""
- # Assemble different types of search inputs for RAG.
- prompt = "What does it mean for two events to be simultaneous?"
- search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [
- hybrid_search, # A search method as input.
- hybrid_search(prompt, config=raglite_test_config)[0], # Chunk ids as input.
- retrieve_chunks( # Chunks as input.
- hybrid_search(prompt, config=raglite_test_config)[0], config=raglite_test_config
- ),
- ]
# Answer a question with RAG.
- for search_input in search_inputs:
- stream = rag(prompt, search=search_input, config=raglite_test_config)
- answer = ""
- for update in stream:
- assert isinstance(update, str)
- answer += update
- assert "simultaneous" in answer.lower()
+ user_prompt = "What does it mean for two events to be simultaneous?"
+ chunk_spans = retrieve_rag_context(query=user_prompt, config=raglite_test_config)
+ messages = [create_rag_instruction(user_prompt, context=chunk_spans)]
+ stream = rag(messages, config=raglite_test_config)
+ answer = ""
+ for update in stream:
+ assert isinstance(update, str)
+ answer += update
+ assert "simultaneous" in answer.lower()
diff --git a/tests/test_search.py b/tests/test_search.py
index 9cea9d1..e677465 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -6,11 +6,11 @@
RAGLiteConfig,
hybrid_search,
keyword_search,
+ retrieve_chunk_spans,
retrieve_chunks,
- retrieve_segments,
vector_search,
)
-from raglite._database import Chunk
+from raglite._database import Chunk, ChunkSpan, Document
from raglite._typing import SearchMethod
@@ -43,9 +43,14 @@ def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod)
assert all(isinstance(chunk, Chunk) for chunk in chunks)
assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks)
+ assert all(isinstance(chunk.document, Document) for chunk in chunks)
-    # Extend the chunks with their neighbours and group them into contiguous segments.
+    # Extend the chunks with their neighbours and group them into contiguous chunk spans.
- segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
- assert all(isinstance(segment, str) for segment in segments)
+ chunk_spans = retrieve_chunk_spans(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
+ assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
+ assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans)
+ chunk_spans = retrieve_chunk_spans(chunks, neighbors=(-1, 1), config=raglite_test_config)
+ assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
+ assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans)
def test_search_no_results(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None: