Merge pull request #345 from DAGWorks-Inc/examples/rag-otel
examples: RAG chatbot with lancedb, dlt, opentelemetry
Showing 9 changed files with 1,004 additions and 0 deletions.
# Burr RAG with LanceDB and dlt document ingestion

This example shows how to build a chatbot with RAG over Substack blogs (or any RSS feed) stored in LanceDB.

![burr ui](burr-ui.gif)

> Burr UI brings a new level of observability to your RAG application via OpenTelemetry

Burr + [LanceDB](https://lancedb.github.io/lancedb/) form a powerful yet lightweight combo for building retrieval-augmented generation (RAG) agents. Burr lets you define complex agents in a way that's easy to understand and debug. It also provides all the right features to help you productionize agents, including monitoring, storing interactions, streaming, and a fully-featured open-source UI.
LanceDB makes it easy to swap embedding providers and hides this concern from the Burr application layer. For this example, we'll use [OpenAI](https://github.com/openai/openai-python) for embedding and response generation.
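To illustrate the idea (this sketch is not code from this commit), LanceDB's embedding registry lets you declare the provider once in the table schema, so swapping providers is a one-line change. The table name, model name, and local path below are made up:

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# Swapping embedding providers only changes this line; the rest of the
# application keeps reading and writing plain text.
embedder = get_registry().get("openai").create(name="text-embedding-3-small")

class TextChunk(LanceModel):
    text: str = embedder.SourceField()  # embedded automatically on insert
    vector: Vector(embedder.ndims()) = embedder.VectorField()

con = lancedb.connect("./lancedb")  # illustrative local path
table = con.create_table("chunks", schema=TextChunk, exist_ok=True)
table.add([{"text": "Burr models agents as state machines."}])
```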
By leveraging the [Burr integration with OpenTelemetry](https://blog.dagworks.io/p/building-generative-ai-agent-based), we get full visibility into the OpenAI API requests/responses and the LanceDB operations for free.
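For reference, a minimal sketch of what such instrumentation can look like, assuming the OpenLLMetry instrumentation packages (`opentelemetry-instrumentation-openai`, `opentelemetry-instrumentation-lancedb`) are installed. This is an assumption about one plausible setup, not necessarily what this example's `utils.py` does:

```python
# Hypothetical sketch: auto-instrument OpenAI and LanceDB calls so Burr's
# tracker can capture them as OpenTelemetry spans. Package and class names
# are assumptions based on the OpenLLMetry project, not from this commit.
from opentelemetry.instrumentation.openai import OpenAIInstrumentor
from opentelemetry.instrumentation.lancedb import LanceInstrumentor

def instrument() -> None:
    OpenAIInstrumentor().instrument()
    LanceInstrumentor().instrument()
```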
To ingest data, we use [dlt and its LanceDB integration](https://dlthub.com/devel/dlt-ecosystem/destinations/lancedb), which makes it very simple to query, embed, and store blogs from the web into LanceDB tables.
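At its core, the pattern looks like this (a stripped-down sketch with a stand-in resource; the real pipeline with RSS parsing, chunking, and contextualization is in `ingestion.py` below):

```python
import dlt
from dlt.destinations.adapters import lancedb_adapter

@dlt.resource
def documents():
    # stand-in resource; the real example pulls entries from an RSS feed
    yield {"id": 1, "text": "hello from a blog post"}

pipeline = dlt.pipeline(
    pipeline_name="demo", destination="lancedb", dataset_name="docs"
)
# `embed="text"` tells the LanceDB destination which field to embed on load
load_info = pipeline.run(lancedb_adapter(documents(), embed="text"))
print(load_info)
```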
## Content

- `notebook.ipynb` contains a tutorial
- `application.py` has the `burr` code for the chatbot
- `ingestion.py` has the `dlt` code for document ingestion
- `utils.py` contains utility functions to set up `OpenTelemetry` instrumentation and environment variables
Empty file.
`application.py`
import os
import textwrap

import lancedb
import openai

from burr.core import Application, ApplicationBuilder, State, action
from burr.lifecycle import PostRunStepHook


@action(reads=[], writes=["relevant_chunks", "chat_history"])
def relevant_chunk_retrieval(
    state: State,
    user_query: str,
    lancedb_con: lancedb.DBConnection,
) -> State:
    """Search LanceDB with the user query and return the top 4 results"""
    text_chunks_table = lancedb_con.open_table("dagworks___contexts")
    search_results = (
        text_chunks_table.search(user_query).select(["text", "id__"]).limit(4).to_list()
    )

    return state.update(relevant_chunks=search_results).append(chat_history=user_query)


@action(reads=["chat_history", "relevant_chunks"], writes=["chat_history"])
def bot_turn(state: State, llm_client: openai.OpenAI) -> State:
    """Collect relevant chunks and produce a response to the user query"""
    user_query = state["chat_history"][-1]
    relevant_chunks = state["relevant_chunks"]

    system_prompt = textwrap.dedent(
        """You are a conversational agent designed to discuss and provide \
        insights about various blog posts. Your task is to engage users in \
        meaningful conversations based on the content of the blog articles they mention.
        """
    )
    joined_chunks = " ".join([c["text"] for c in relevant_chunks])
    user_prompt = "BLOGS CONTENT\n" + joined_chunks + "\nUSER QUERY\n" + user_query

    response = llm_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    bot_answer = response.choices[0].message.content

    return state.append(chat_history=bot_answer)


class PrintBotAnswer(PostRunStepHook):
    """Hook to print the bot's answer"""

    def post_run_step(self, *, state, action, **future_kwargs):
        if action.name == "bot_turn":
            print("\n🤖: ", state["chat_history"][-1])


def build_application() -> Application:
    """Create the Burr `Application`. This is responsible for instantiating the
    OpenAI client and the LanceDB connection.
    """
    llm_client = openai.OpenAI()
    # dlt configures the LanceDB destination through this environment variable
    lancedb_con = lancedb.connect(os.environ["DESTINATION__LANCEDB__CREDENTIALS__URI"])

    return (
        ApplicationBuilder()
        .with_actions(
            relevant_chunk_retrieval.bind(lancedb_con=lancedb_con),
            bot_turn.bind(llm_client=llm_client),
        )
        .with_transitions(
            ("relevant_chunk_retrieval", "bot_turn"),
            ("bot_turn", "relevant_chunk_retrieval"),
        )
        .with_entrypoint("relevant_chunk_retrieval")
        .with_tracker("local", project="substack-rag", use_otel_tracing=True)
        .with_hooks(PrintBotAnswer())
        .build()
    )


if __name__ == "__main__":
    import utils

    utils.set_environment_variables()  # set environment variables for LanceDB
    utils.instrument()  # register the OpenTelemetry instrumentation

    # build the Burr `Application`
    app = build_application()
    app.visualize("statemachine.png")

    # Launch the Burr application in a `while` loop
    print("\n## Launching RAG application ##")
    while True:
        user_query = input("\nAsk something or type `quit/q` to exit: ")
        if user_query.lower() in ["quit", "q"]:
            break

        _, _, _ = app.run(
            halt_after=["bot_turn"],
            inputs={"user_query": user_query},
        )
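Note that `relevant_chunk_retrieval` opens the `dagworks___contexts` table, which is created by the ingestion pipeline below (the LanceDB destination prefixes table names with the dataset name, here `dagworks`), so `ingestion.py` should be run before `application.py`.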
`ingestion.py`
import re
from typing import Generator

import dlt
import feedparser
import requests
import utils
from bs4 import BeautifulSoup


def split_text(text):
    """Split text on punctuation (., !, ?)."""
    sentence_endings = r"[.!?]+"
    for sentence in re.split(sentence_endings, text):
        sentence = sentence.strip()
        if sentence:
            yield sentence


def contextualize(chunks: list[str], window=5, stride=3, min_window_size=2):
    """Rolling window operation to join consecutive sentences into larger chunks.

    For example, with 8 sentences and the defaults `window=5`, `stride=3`,
    this yields chunks[0:5], chunks[3:8], and chunks[6:8].
    """
    n_chunks = len(chunks)
    for start_i in range(0, n_chunks, stride):
        if (start_i + window <= n_chunks) or (n_chunks - start_i >= min_window_size):
            yield " ".join(chunks[start_i : min(start_i + window, n_chunks)])


@dlt.resource(name="substack", write_disposition="merge", primary_key="id")
def rss_entries(substack_url: str) -> Generator:
    """Substack blog entries retrieved from an RSS feed"""
    FIELDS_TO_EXCLUDE = [
        "published_parsed",
        "title_detail",
        "summary_detail",
        "author_detail",
        "guidislink",
        "authors",
        "links",
    ]

    r = requests.get(f"{substack_url}/feed")
    rss_feed = feedparser.parse(r.content)
    for entry in rss_feed["entries"]:
        for field in FIELDS_TO_EXCLUDE:
            entry.pop(field)

        yield entry


@dlt.transformer(primary_key="id")
def parsed_html(rss_entry: dict):
    """Parse the HTML from the RSS entry"""
    soup = BeautifulSoup(rss_entry["content"][0]["value"], "html.parser")
    parsed_text = soup.get_text(separator=" ", strip=True)
    yield {"id": rss_entry["id"], "text": parsed_text}


@dlt.transformer(primary_key="chunk_id")
def chunks(parsed_html: dict) -> list[dict]:
    """Chunk text"""
    return [
        dict(
            document_id=parsed_html["id"],
            chunk_id=idx,
            text=text,
        )
        for idx, text in enumerate(split_text(parsed_html["text"]))
    ]


# order is important for the reduce / rolling step;
# it defaults to the order of the batch, or you can specify a sorting key
@dlt.transformer(primary_key="context_id")
def contexts(chunks: list[dict]) -> Generator:
    """Assemble consecutive chunks into larger context windows"""
    # first handle the m-to-n relationship
    # set of foreign keys (i.e., "chunk_id")
    chunk_id_set = set(chunk["chunk_id"] for chunk in chunks)
    context_id = utils.hash_set(chunk_id_set)

    # create a table only containing the keys
    for chunk_id in chunk_id_set:
        yield dlt.mark.with_table_name(
            {"chunk_id": chunk_id, "context_id": context_id},
            "chunks_to_contexts_keys",
        )

    # main transformation logic
    for contextualized in contextualize([chunk["text"] for chunk in chunks]):
        yield dlt.mark.with_table_name(
            {"context_id": context_id, "text": contextualized}, "contexts"
        )


if __name__ == "__main__":
    from dlt.destinations.adapters import lancedb_adapter

    utils.set_environment_variables()

    pipeline = dlt.pipeline(
        pipeline_name="substack-blog", destination="lancedb", dataset_name="dagworks"
    )

    blog_url = "https://blog.dagworks.io/"

    full_entries = lancedb_adapter(rss_entries(blog_url), embed="summary")
    chunked_entries = rss_entries(blog_url) | parsed_html | chunks
    contextualized_chunks = lancedb_adapter(chunked_entries | contexts, embed="text")

    load_info = pipeline.run([full_entries, chunked_entries, contextualized_chunks])
    print(load_info)