Commit
Add baseline DSPy implementation (#7)
yoomlam authored Apr 9, 2024
1 parent 817addc commit 4a66abe
Showing 5 changed files with 444 additions and 40 deletions.
101 changes: 101 additions & 0 deletions 02-household-queries/debugging.py
@@ -0,0 +1,101 @@
import functools
import time
import code
import traceback
from typing import List, Dict, Any
import chainlit as cl
from chainlit.types import ThreadDict

# import readline # enables Up/Down/History in the console
from langchain_core.runnables import RunnableLambda
from langchain.callbacks.base import BaseCallbackHandler


def timer(func):
    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        tic = time.perf_counter()
        value = func(*args, **kwargs)
        toc = time.perf_counter()
        elapsed_time = toc - tic
        print(f"(Elapsed time of {func.__name__}: {elapsed_time:0.4f} seconds)")
        return value

    return wrapper_timer


def stacktrace():
    traceback.print_stack()


def debug_here(local_vars):
    """Usage: debug_here(locals())"""
    variables = globals().copy()
    variables.update(local_vars)
    shell = code.InteractiveConsole(variables)
    shell.interact()


def debug_runnable(prefix: str):
    """Useful for inspecting the output/input passed between Runnables in a LangChain chain"""

    def debug_chainlink(x):
        print(f"DEBUG_CHAINLINK {prefix}", x)
        return x

    return RunnableLambda(debug_chainlink)


def print_prompt_templates(chain):
    print("RUNNABLE", chain)  # .json(indent=2))
    if chain.middle:
        print(
            "combine_documents_chain.llm_chain\n",
            chain.middle[0].combine_documents_chain.llm_chain.prompt.template,
        )
        print(
            "combine_documents_chain.document_prompt\n",
            chain.middle[0].combine_documents_chain.document_prompt.template,
        )


class CaptureLlmPromptHandler(BaseCallbackHandler):
    """Prints the prompt being sent to an LLM"""

    def __init__(self, printToStdOut=True):
        self.toStdout = printToStdOut

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        formatted_prompts = "\n".join(prompts).replace("```", "``")
        if self.toStdout:
            print(f"\nPROMPT:\n{formatted_prompts}")
        await cl.Message(
            author="prompt debug",
            content=f"Prompt sent to LLM:\n```\n{formatted_prompts}\n```",
        ).send()


@cl.on_chat_start
async def print_user_session():
    # https://docs.chainlit.io/concepts/user-session
    for key in ["id", "env", "chat_settings", "user", "chat_profile", "root_message"]:
        print(key, cl.user_session.get(key))


@cl.on_stop
def on_stop():
    print("The user wants to stop the task!")


# When a user resumes a chat session that was previously disconnected.
# This can only happen if authentication and data persistence are enabled.
@cl.on_chat_resume
async def on_chat_resume(thread: ThreadDict):
    print("The user resumed a previous chat session!", thread.keys())


@cl.on_chat_end
def on_chat_end():
    print("The user disconnected!")
179 changes: 179 additions & 0 deletions 02-household-queries/dspy_engine.py
@@ -0,0 +1,179 @@
import os
import json
from typing import Optional

import dotenv
import dspy
from dsp.utils import dotdict

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

import debugging


dotenv.load_dotenv()


class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


def run_basic_predictor(query):
    # Define the predictor.
    generate_answer = dspy.Predict(BasicQA)

    # Call the predictor on a particular input.
    pred = generate_answer(question=query)

    # Print the input and the prediction.
    print(f"Query: {query}")
    print(f"Answer: {pred.answer}")
    return pred


def run_cot_predictor(query):
    generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)

    # Call the predictor on the same input.
    pred = generate_answer_with_chain_of_thought(question=query)
    print(f"\nQUERY : {query}")
    print(f"\nRATIONALE: {pred.rationale.split(':', 1)[1].strip()}")
    print(f"\nANSWER : {pred.answer}")
    # debugging.debug_here(locals())


class GenerateAnswer(dspy.Signature):
    """Answer the question with a short factoid answer."""

    context = dspy.InputField(
        desc="may contain relevant facts used to answer the question"
    )
    question = dspy.InputField()
    answer = dspy.OutputField(
        desc="Start with one of these words: Yes, No, Maybe. Between 1 and 5 words"
    )


class RAG(dspy.Module):
    def __init__(self, num_passages):
        super().__init__()

        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, query):
        context = self.retrieve(query).passages
        prediction = self.generate_answer(context=context, question=query)
        return dspy.Prediction(context=context, answer=prediction.answer)


@debugging.timer
def run_retrieval(query, retrieve_k):
    retrieve = dspy.Retrieve(k=retrieve_k)
    retrieval = retrieve(query)
    topK_passages = retrieval.passages

    print(f"Top {retrieve.k} passages for query: {query} \n", "-" * 30, "\n")
    for i, passage in enumerate(topK_passages):
        print(f"[{i+1}]", passage, "\n")
    return retrieval


def run_rag(query, retrieve_k):
    rag = RAG(retrieve_k)
    pred = rag(query=query)
    print(f"\nRATIONALE: {pred.get('rationale')}")
    print(f"\nANSWER : {pred.answer}")
    print(f"\nCONTEXT: {len(pred.context)}")
    for i, d in enumerate(pred.context):
        print(i + 1, d, "\n")
    # debugging.debug_here(locals())


# https://dspy-docs.vercel.app/docs/deep-dive/retrieval_models_clients/custom-rm-client
class RetrievalModelWrapper(dspy.Retrieve):
    def __init__(self, vectordb):
        super().__init__()
        self.vectordb = vectordb

    def forward(self, query: str, k: Optional[int] = None) -> dspy.Prediction:
        k = self.k if k is None else k
        # print("k=", k)
        # k parameter is specific to the Chroma retriever
        # See other parameters in .../site-packages/langchain_core/vectorstores.py
        retriever = self.vectordb.as_retriever(search_kwargs={"k": k})
        retrievals = retriever.invoke(query)
        # print("Retrieved")
        # for d in retrievals:
        #     print(d)
        # print()

        # DSPy expects a `long_text` attribute for each retrieved item
        retrievals_as_text = [
            dotdict({"long_text": doc.page_content}) for doc in retrievals
        ]
        return retrievals_as_text


@debugging.timer
def create_retriever_model():
    # "The all-mpnet-base-v2 model provides the best quality, while all-MiniLM-L6-v2 is 5 times faster and still offers good quality."
    _embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME", "all-MiniLM-L6-v2")
    embeddings = HuggingFaceEmbeddings(model_name=_embeddings_model_name)
    vectordb = Chroma(
        embedding_function=embeddings,
        collection_name="resources",
        persist_directory="./chroma_db",
    )

    # https://dspy-docs.vercel.app/docs/deep-dive/retrieval_models_clients/ChromadbRM
    # return ChromadbRM(collection_name="resources", persist_directory="./chroma_db", embedding_function=embedding_function)

    return RetrievalModelWrapper(vectordb)


@debugging.timer
def create_llm_model():
    llm_name = "openhermes"  # "openhermes", "llama2", "mistral"
    return dspy.OllamaLocal(model=llm_name, temperature=0.1)


def load_training_json():
    with open("question_answer_citations.json", encoding="utf-8") as data_file:
        json_data = json.load(data_file)
        # print(json.dumps(json_data, indent=2))
        return json_data


def main(query):
    retrieve_k = int(os.environ.get("RETRIEVE_K", "2"))

    # run_basic_predictor(query)
    # run_cot_predictor(query)
    # run_retrieval(query, retrieve_k)
    run_rag(query, retrieve_k)


if __name__ == "__main__":
    qa = load_training_json()
    for qa_dict in qa:
        orig_question = qa_dict["orig_question"]
        question = qa_dict.get("question", orig_question)
        print(f"\nQUESTION {qa_dict['id']}: {question}")
        answer = qa_dict["answer"]
        short_answer = qa_dict.get("short_answer", answer)
        print(f" SHORT ANSWER : {short_answer}")
        print(f" Desired ANSWER : {answer}")
        print()

    llm_model = create_llm_model()
    dspy.settings.configure(lm=llm_model, rm=create_retriever_model())

    # Run the RAG pipeline on the last question loaded from the training set
    main(question)

    print("----- llm_model.inspect_history ------------------")
    llm_model.inspect_history(n=10)
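
For a single ad-hoc query outside the training-set loop above, the same pieces can be wired together directly (a sketch; the question text is illustrative, and it assumes a local Ollama server with the openhermes model plus a populated ./chroma_db collection):

import dspy

from dspy_engine import RAG, create_llm_model, create_retriever_model

# Point DSPy at the local Ollama LLM and the Chroma-backed retriever wrapper
dspy.settings.configure(lm=create_llm_model(), rm=create_retriever_model())

rag = RAG(num_passages=2)
pred = rag(query="Who counts as a mandatory household member for SNAP?")
print("ANSWER :", pred.answer)
for i, passage in enumerate(pred.context, start=1):
    print(i, passage)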
2 changes: 1 addition & 1 deletion 02-household-queries/question_answer_citations.json
@@ -166,4 +166,4 @@
"Who are mandatory HH members for food stamps?"
]
}
]
]
4 changes: 3 additions & 1 deletion 02-household-queries/requirements.in
@@ -5,6 +5,8 @@
beautifulsoup4
chainlit
chromadb
dspy-ai
jinja2
jq
langchain
langchain_community
@@ -13,4 +15,4 @@ langchain-text-splitters
# Needed by langchain_community/document_loaders/pdf.py
pdfminer.six
rapidocr-onnxruntime
# sentence-transformers
sentence-transformers
