feat: support prompt caching and apply Anthropic's long-context prompt format (#52)

undo76 authored Dec 3, 2024
1 parent 003967b commit 0a08bf5
Showing 12 changed files with 396 additions and 322 deletions.
77 changes: 63 additions & 14 deletions README.md
@@ -23,6 +23,8 @@ RAGLite is a Python toolkit for Retrieval-Augmented Generation (RAG) with Postgr
- 🧬 Multi-vector chunk embedding with [late chunking](https://weaviate.io/blog/late-chunking) and [contextual chunk headings](https://d-star.ai/solving-the-out-of-context-chunk-problem-for-rag)
- ✂️ Optimal [level 4 semantic chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) by solving a [binary integer programming problem](https://en.wikipedia.org/wiki/Integer_programming)
- 🔍 [Hybrid search](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) with the database's native keyword & vector search ([tsvector](https://www.postgresql.org/docs/current/datatype-textsearch.html)+[pgvector](https://github.com/pgvector/pgvector), [FTS5](https://www.sqlite.org/fts5.html)+[sqlite-vec](https://github.com/asg017/sqlite-vec)[^1])
- 💰 Improved cost and latency with a [prompt caching-aware message array structure](https://platform.openai.com/docs/guides/prompt-caching) (see the sketch after this list)
- 🍰 Improved output quality with [Anthropic's long-context prompt format](https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips)
- 🌀 Optimal [closed-form linear query adapter](src/raglite/_query_adapter.py) by solving an [orthogonal Procrustes problem](https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem)
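
A minimal sketch of what a prompt caching-aware message array can look like, assuming OpenAI-style message dicts (the tags and layout below are our illustration, not RAGLite's internals): the retrieved context is embedded in the user messages and the history is append-only, so each request shares a long, cacheable prefix with the previous one.

```python
# A minimal sketch, assuming OpenAI-style message dicts; the tags and layout
# are illustrative only. The retrieved context lives inside the user messages
# and the history is append-only, so each request shares a cacheable prefix
# with the previous one.
messages = [
    {"role": "user", "content": "<documents>…retrieved context…</documents>\n\nFirst question?"},
    {"role": "assistant", "content": "First answer."},
    # The next turn only appends; everything above stays byte-identical and cacheable.
    {"role": "user", "content": "<documents>…new context…</documents>\n\nFollow-up question?"},
]
```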

##### Extensible
@@ -157,38 +159,85 @@ insert_document(Path("Special Relativity.pdf"), config=my_config)

### 3. Searching and Retrieval-Augmented Generation (RAG)

Now, you can search for chunks with vector search, keyword search, or a hybrid of the two. You can also rerank the search results with the configured reranker. And you can use any search method of your choice (`hybrid_search` is the default) together with reranking to answer questions with RAG:
#### 3.1 Simple RAG pipeline

Now you can run a simple but powerful RAG pipeline that retrieves the most relevant chunk spans (each a list of consecutive chunks) with hybrid search and reranking, converts the user prompt to a RAG instruction and appends it to the message history, and finally generates the RAG response:

```python
from raglite import create_rag_instruction, rag, retrieve_rag_context

# Retrieve relevant chunk spans with hybrid search and reranking:
user_prompt = "How is intelligence measured?"
chunk_spans = retrieve_rag_context(query=user_prompt, num_chunks=5, config=my_config)

# Append a RAG instruction based on the user prompt and context to the message history:
messages = [] # Or start with an existing message history.
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))

# Stream the RAG response:
stream = rag(messages, config=my_config)
for update in stream:
print(update, end="")

# Access the documents cited in the RAG response:
documents = [chunk_span.document for chunk_span in chunk_spans]
```
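
Because the retrieved context lives inside the message history, a follow-up turn only appends new messages, which lets providers reuse the cached prompt prefix. A hedged sketch of a multi-turn continuation (collecting the streamed reply into an OpenAI-style assistant message is our assumption, not a documented RAGLite API):

```python
# Variant of the streaming loop above that also collects the reply:
response = ""
for update in rag(messages, config=my_config):
    print(update, end="")
    response += update
messages.append({"role": "assistant", "content": response})

# Follow-up turn: retrieve fresh context for the new question and continue.
followup_prompt = "And can intelligence be increased?"  # Hypothetical follow-up.
chunk_spans = retrieve_rag_context(query=followup_prompt, num_chunks=5, config=my_config)
messages.append(create_rag_instruction(user_prompt=followup_prompt, context=chunk_spans))
for update in rag(messages, config=my_config):
    print(update, end="")
```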

#### 3.2 Advanced RAG pipeline

> [!TIP]
> 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks.
In addition to the simple RAG pipeline, RAGLite also offers fine-grained control over each individual step. A full pipeline consists of the following steps:

1. Searching for relevant chunks with keyword, vector, or hybrid search
2. Retrieving the chunks from the database
3. Reranking the chunks and selecting the top 5 results
4. Extending the chunks with their neighbors and grouping them into chunk spans
5. Converting the user prompt to a RAG instruction and appending it to the message history
6. Streaming an LLM response to the message history
7. Accessing the cited documents from the chunk spans

```python
# Search for chunks:
from raglite import hybrid_search, keyword_search, vector_search

prompt = "How is intelligence measured?"
chunk_ids_vector, _ = vector_search(prompt, num_results=20, config=my_config)
chunk_ids_keyword, _ = keyword_search(prompt, num_results=20, config=my_config)
chunk_ids_hybrid, _ = hybrid_search(prompt, num_results=20, config=my_config)
user_prompt = "How is intelligence measured?"
chunk_ids_vector, _ = vector_search(user_prompt, num_results=20, config=my_config)
chunk_ids_keyword, _ = keyword_search(user_prompt, num_results=20, config=my_config)
chunk_ids_hybrid, _ = hybrid_search(user_prompt, num_results=20, config=my_config)

# Retrieve chunks:
from raglite import retrieve_chunks

chunks_hybrid = retrieve_chunks(chunk_ids_hybrid, config=my_config)

# Rerank chunks:
# Rerank chunks and keep the top 5 (optional, but recommended):
from raglite import rerank_chunks

chunks_reranked = rerank_chunks(prompt, chunks_hybrid, config=my_config)
chunks_reranked = rerank_chunks(user_prompt, chunks_hybrid, config=my_config)
chunks_reranked = chunks_reranked[:5]

# Extend chunks with their neighbors and group them into chunk spans:
from raglite import retrieve_chunk_spans

chunk_spans = retrieve_chunk_spans(chunks_reranked, config=my_config)

# Append a RAG instruction based on the user prompt and context to the message history:
from raglite import create_rag_instruction

messages = [] # Or start with an existing message history.
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))

# Answer questions with RAG:
# Stream the RAG response:
from raglite import rag

prompt = "What does it mean for two events to be simultaneous?"
stream = rag(prompt, config=my_config)
stream = rag(messages, config=my_config)
for update in stream:
print(update, end="")

# You can also pass a search method or search results directly:
stream = rag(prompt, search=hybrid_search, config=my_config)
stream = rag(prompt, search=chunks_reranked, config=my_config)
# Access the documents cited in the RAG response:
documents = [chunk_span.document for chunk_span in chunk_spans]
```
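
For reference, an illustrative guess at the shape of the message that `create_rag_instruction` appends (the exact tags are internal to RAGLite): following Anthropic's long-context tips, the long retrieved context is placed first and the user's question last.

```python
# Illustrative only — the exact tags are internal to RAGLite. The long
# retrieved context comes first and the question comes last, per Anthropic's
# long-context prompting tips.
rag_instruction = {
    "role": "user",
    "content": (
        "<documents>\n"
        '  <document index="1">\n'
        "    <source>Special Relativity.pdf</source>\n"
        "    <document_contents>…chunk span text…</document_contents>\n"
        "  </document>\n"
        "</documents>\n\n"
        "How is intelligence measured?"
    ),
}
```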

### 4. Computing and using an optimal query adapter
@@ -200,7 +249,7 @@ RAGLite can compute and apply an [optimal closed-form query adapter](src/raglite
from raglite import insert_evals, update_query_adapter

insert_evals(num_evals=100, config=my_config)
update_query_adapter(config=my_config) # From here, simply call vector_search to use the query adapter.
update_query_adapter(config=my_config) # From here, every vector search will use the query adapter.
```
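
The closed form exists because the adapter solves an orthogonal Procrustes problem. A rough NumPy sketch of the idea, with made-up shapes and data standing in for the (query, chunk) pairs that `insert_evals` produces (not RAGLite's actual implementation, which lives in `src/raglite/_query_adapter.py`):

```python
import numpy as np

# Stand-ins for the (query, chunk) embedding pairs: Q holds query embeddings,
# D the embeddings of their best-matching chunks (both n x d, rows L2-normalised).
rng = np.random.default_rng(42)
Q = rng.standard_normal((100, 384))
D = rng.standard_normal((100, 384))
Q /= np.linalg.norm(Q, axis=1, keepdims=True)
D /= np.linalg.norm(D, axis=1, keepdims=True)

# Orthogonal Procrustes: the orthogonal matrix A minimising ||Q A^T - D||_F
# is U @ Vt, where U, S, Vt = SVD(D^T Q).
U, _, Vt = np.linalg.svd(D.T @ Q)
A = U @ Vt

# At query time, search with the adapted embedding A @ q instead of q.
```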

### 5. Evaluation of retrieval and generation
8 changes: 5 additions & 3 deletions src/raglite/__init__.py
@@ -5,13 +5,13 @@
from raglite._eval import answer_evals, evaluate, insert_evals
from raglite._insert import insert_document
from raglite._query_adapter import update_query_adapter
from raglite._rag import async_rag, rag
from raglite._rag import async_rag, create_rag_instruction, rag, retrieve_rag_context
from raglite._search import (
hybrid_search,
keyword_search,
rerank_chunks,
retrieve_chunk_spans,
retrieve_chunks,
retrieve_segments,
vector_search,
)

@@ -25,9 +25,11 @@
"keyword_search",
"vector_search",
"retrieve_chunks",
"retrieve_segments",
"retrieve_chunk_spans",
"rerank_chunks",
# RAG
"retrieve_rag_context",
"create_rag_instruction",
"async_rag",
"rag",
# Query adapter
36 changes: 21 additions & 15 deletions src/raglite/_chainlit.py
@@ -9,16 +9,19 @@
from raglite import (
RAGLiteConfig,
async_rag,
create_rag_instruction,
hybrid_search,
insert_document,
rerank_chunks,
retrieve_chunk_spans,
retrieve_chunks,
)
from raglite._markdown import document_to_markdown

async_insert_document = cl.make_async(insert_document)
async_hybrid_search = cl.make_async(hybrid_search)
async_retrieve_chunks = cl.make_async(retrieve_chunks)
async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans)
async_rerank_chunks = cl.make_async(rerank_chunks)


@@ -84,34 +87,37 @@ async def handle_message(user_message: cl.Message) -> None:
step.input = Path(file.path).name
await async_insert_document(Path(file.path), config=config)
# Append any inline attachments to the user prompt.
user_prompt = f"{user_message.content}\n\n" + "\n\n".join(
f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
for i, attachment in enumerate(inline_attachments)
user_prompt = (
"\n\n".join(
f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
for i, attachment in enumerate(inline_attachments)
)
+ f"\n\n{user_message.content}"
)
# Search for relevant contexts for RAG.
async with cl.Step(name="search", type="retrieval") as step:
step.input = user_message.content
chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
step.output = chunks
step.elements = [ # Show the top 3 chunks inline.
cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
step.elements = [ # Show the top chunks inline.
cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
]
# Rerank the chunks.
await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
# Rerank the chunks and group them into chunk spans.
async with cl.Step(name="rerank", type="rerank") as step:
step.input = chunks
chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
step.output = chunks
step.elements = [ # Show the top 3 chunks inline.
cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config)
step.output = chunk_spans
step.elements = [ # Show the top chunk spans inline.
cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
]
await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
# Stream the LLM response.
assistant_message = cl.Message(content="")
async for token in async_rag(
prompt=user_prompt,
search=chunks,
messages=cl.chat_context.to_openai()[-5:], # type: ignore[no-untyped-call]
config=config,
):
messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call]
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
async for token in async_rag(messages, config=config):
await assistant_message.stream_token(token)
await assistant_message.update() # type: ignore[no-untyped-call]