feat: Add summarizing and quoting capabilities to 05-assistive-chatbot
yoomlam authored Jun 3, 2024
1 parent c650128 commit f1464cd
Showing 17 changed files with 595 additions and 183 deletions.
81 changes: 66 additions & 15 deletions 05-assistive-chatbot/chatbot-chainlit.py
@@ -14,6 +14,7 @@

import chatbot
from chatbot import engines, llms, utils
from chatbot.engines import v2_household_engine

logger = logging.getLogger(f"chatbot.{__name__}")

@@ -41,18 +42,40 @@ async def init_chat():
),
Select(
id="model",
label="LLM Model",
label="Primary LLM Model",
values=llms.available_llms(),
initial_value=chatbot.initial_settings["model"],
),
Slider(
id="temperature",
label="LLM Temperature",
label="Temperature for primary LLM",
initial=chatbot.initial_settings["temperature"],
min=0,
max=2,
step=0.1,
),
Slider(
id="retrieve_k",
label="Guru cards to retrieve",
initial=chatbot.initial_settings["retrieve_k"],
min=1,
max=10,
step=1,
),
Select(
id="model2",
label="LLM Model for summarizer",
values=llms.available_llms(),
initial_value=chatbot.initial_settings["model2"],
),
Slider(
id="temperature2",
label="Temperature for summarizer",
initial=chatbot.initial_settings["temperature2"],
min=0,
max=2,
step=0.1,
),
# TODO: Add LLM response streaming capability
# Switch(id="streaming", label="Stream response tokens", initial=True),
]
@@ -74,7 +97,7 @@ async def update_settings(settings):
@utils.timer
async def apply_settings():
settings = cl.user_session.get("settings")
- await create_llm_client(settings)
+ await create_chat_engine(settings)

# PLACEHOLDER: Apply other settings

@@ -86,11 +109,11 @@ async def apply_settings():
return settings


- async def create_llm_client(settings):
-     msg = cl.Message(author="backend", content=f"Setting up LLM: {settings['model']} ...\n")
+ async def create_chat_engine(settings):
+     msg = cl.Message(author="backend", content=f"Setting up chat engine: {settings['chat_engine']} ...\n")

cl.user_session.set("chat_engine", chatbot.create_chat_engine(settings))
await msg.stream_token("Done setting up LLM")
await msg.stream_token("Done setting up chat engine")
await msg.send()


@@ -101,18 +124,46 @@ async def message_submitted(message: cl.Message):
if not cl.user_session.get("settings_applied", False):
return

- # settings = cl.user_session.get("settings")

chat_engine = cl.user_session.get("chat_engine")
response = chat_engine.gen_response(message.content)

- await cl.Message(content=f"*Response*: {response}").send()
- # generated_results = on_question(message.content)
- # print(json.dumps(dataclasses.asdict(generated_results), indent=2))
- # message_args = format_as_markdown(generated_results)
- await cl.Message(message.content).send()
+ if isinstance(response, v2_household_engine.GenerationResults):
+     message_args = format_v2_results_as_markdown(response)
+     await cl.Message(content=message_args["content"], elements=message_args["elements"]).send()
+ else:
+     await cl.Message(content=f"*Response*: {response}").send()


def format_v2_results_as_markdown(gen_results):
resp = ["", f"## Q: {gen_results.question}"]

dq_resp = ["<details><summary>Derived Questions</summary>", ""]
for dq in gen_results.derived_questions:
dq_resp.append(f"- {dq.derived_question}")
dq_resp += ["</details>", ""]

cards_resp = []
for i, card in enumerate(gen_results.cards, 1):
if card.summary:
cards_resp += [
f"<details><summary>{i}. <a href='https://link/to/guru_card'>{card.card_title}</a></summary>",
"",
f" Summary: {card.summary}",
"",
]
indented_quotes = [q.strip().replace("\n", "\n ") for q in card.quotes]
cards_resp += [f"\n Quote:\n ```\n {q}\n ```" for q in indented_quotes]
cards_resp += ["</details>", ""]

return {
"content": "\n".join(resp + dq_resp + cards_resp),
"elements": [
# Example of how to use cl.Text with different display parameters -- it's not intuitive
# The name argument must exist in the message content so that a link can be created.
# cl.Text(name="Derived Questions", content="\n".join(dq_resp), display="side"),
# cl.Text(name="Guru Cards", content="\n".join(cards_resp), display="inline")
],
}


@cl.on_stop
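For context, here is a minimal sketch (not part of the commit) of the input shape that format_v2_results_as_markdown() expects. The field names are inferred from the attribute accesses above; the actual dataclasses live in chatbot/engines/v2_household_engine.py, which this commit page does not show.

```python
# Hypothetical stand-ins mirroring v2_household_engine.GenerationResults;
# field names are inferred from format_v2_results_as_markdown() above.
from types import SimpleNamespace

sample = SimpleNamespace(
    question="Can my household get SNAP and WIC at the same time?",
    derived_questions=[
        SimpleNamespace(derived_question="What are the SNAP eligibility rules?"),
        SimpleNamespace(derived_question="What are the WIC eligibility rules?"),
    ],
    cards=[
        SimpleNamespace(
            card_title="SNAP Basics",
            summary="SNAP eligibility depends on household size and income.",
            quotes=["Households must meet income limits.\nLimits vary by state."],
        ),
    ],
)

# format_v2_results_as_markdown(sample) would return a dict
# {"content": ..., "elements": []} whose content string starts with
# "## Q: ..." and wraps the derived questions and each card's summary
# and quotes in <details> blocks.
```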
28 changes: 24 additions & 4 deletions 05-assistive-chatbot/chatbot/__init__.py
@@ -30,15 +30,31 @@ def configure_logging():

## Initialize settings

# Opt out of telemetry -- https://docs.trychroma.com/telemetry
os.environ.setdefault("ANONYMIZED_TELEMETRY", "False")

# Used by SentenceTransformerEmbeddings and HuggingFaceEmbeddings
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "./.sentence-transformers-cache")

# Disable the DSPy cache so that retry attempts get different responses
# Set to true to enable caching for faster responses and for optimizing prompts with DSPy
os.environ.setdefault("DSP_CACHEBOOL", "false")


@utils.verbose_timer(logger)
def _init_settings():
# Remember to update ChatSettings in chatbot-chainlit.py when adding new settings
# and update chatbot/engines/__init__.py:CHATBOT_SETTING_KEYS
return {
"env": os.environ.get("ENV", "DEV"),
"enable_api": is_true(os.environ.get("ENABLE_CHATBOT_API", "False")),
"chat_engine": os.environ.get("CHAT_ENGINE", "Direct"),
"model": os.environ.get("LLM_MODEL_NAME", "mock :: llm"),
"temperature": float(os.environ.get("LLM_TEMPERATURE", 0.1)),
"retrieve_k": int(os.environ.get("RETRIEVE_K", 4)),
# Used by SummariesChatEngine
"model2": os.environ.get("LLM_MODEL_NAME_2", os.environ.get("LLM_MODEL_NAME", "mock :: llm")),
"temperature2": float(os.environ.get("LLM_TEMPERATURE2", 0.1)),
}


@@ -55,9 +71,13 @@ def validate_settings(settings):
if chat_engine not in engines._discover_chat_engines():
return f"Unknown chat_engine: '{chat_engine}'"

model_name = settings["model"]
if model_name not in llms._discover_llms():
return f"Unknown model: '{model_name}'"
for setting_name in ["model", "model2"]:
model_name = settings[setting_name]
if model_name not in llms._discover_llms():
return f"Unknown {setting_name}: '{model_name}'"

if chat_engine.startswith("Summaries") and "instruct" not in model_name:
logger.warning("For the %s chat engine, an `*instruct` model is recommended", chat_engine)

# PLACEHOLDER: Validate other settings

@@ -69,4 +89,4 @@

@utils.timer
def create_chat_engine(settings):
- return engines.create(settings["chat_engine"], settings)
+ return engines.create_engine(settings["chat_engine"], settings)
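Because _init_settings() reads the new options from environment variables, the summarizer can be configured before launch. A hedged sketch: only the variable names are taken from the commit, and the values are illustrative ("mock :: llm" is the commit's default; "Summaries-DSPy" is the ENGINE_NAME registered by v2_household_dspy_engine.py below).

```python
# Illustrative configuration for the new summarizer settings; every
# environment variable name below appears in _init_settings() above.
import os

os.environ["CHAT_ENGINE"] = "Summaries-DSPy"
os.environ["LLM_MODEL_NAME"] = "mock :: llm"    # primary (decomposer) LLM
os.environ["LLM_MODEL_NAME_2"] = "mock :: llm"  # summarizer LLM; falls back to LLM_MODEL_NAME
os.environ["LLM_TEMPERATURE2"] = "0.1"          # summarizer temperature
os.environ["RETRIEVE_K"] = "4"                  # number of Guru cards to retrieve
```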
6 changes: 3 additions & 3 deletions 05-assistive-chatbot/chatbot/engines/__init__.py
@@ -37,20 +37,20 @@ def _discover_chat_engines(force=False):


@utils.timer
- def create(engine_name, settings=None):
+ def create_engine(engine_name, settings=None):
_discover_chat_engines()
return _engines[engine_name].init_engine(settings)


## Utility functions

CHATBOT_SETTING_KEYS = ["enable_api", "chat_engine", "model"]
# Settings that are specific to our chatbot and shouldn't be passed onto the LLM client
CHATBOT_SETTING_KEYS = ["env", "enable_api", "chat_engine", "model", "model2", "temperature2", "retrieve_k"]


@utils.timer
def create_llm_client(settings):
llm_name = settings["model"]
- # llm_settings = dict((k, settings[k]) for k in ["temperature"] if k in settings)
+ remaining_settings = {k: settings[k] for k in settings if k not in CHATBOT_SETTING_KEYS}
client = llms.init_client(llm_name, remaining_settings)
return client
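To illustrate the expanded key list: create_llm_client() strips every chatbot-specific key, so only LLM-relevant settings such as temperature reach the client. A small sketch of that filtering, using the defaults from _init_settings():

```python
# Sketch of the filtering in create_llm_client(); the key list matches
# CHATBOT_SETTING_KEYS above, and the settings mirror _init_settings().
CHATBOT_SETTING_KEYS = ["env", "enable_api", "chat_engine", "model", "model2", "temperature2", "retrieve_k"]

settings = {
    "env": "DEV",
    "enable_api": False,
    "chat_engine": "Summaries-DSPy",
    "model": "mock :: llm",
    "temperature": 0.1,
    "model2": "mock :: llm",
    "temperature2": 0.1,
    "retrieve_k": 4,
}
remaining_settings = {k: settings[k] for k in settings if k not in CHATBOT_SETTING_KEYS}
assert remaining_settings == {"temperature": 0.1}  # only the primary temperature is forwarded
```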
145 changes: 145 additions & 0 deletions 05-assistive-chatbot/chatbot/engines/v2_household_dspy_engine.py
@@ -0,0 +1,145 @@
# DSPy version of v2_household_engine.py

import json
import logging
import os
from functools import cached_property

import dspy # type: ignore[import-untyped]
from dspy.signatures.signature import signature_to_template # type: ignore[import-untyped]

from chatbot import engines
from chatbot.engines.v2_household_engine import SummariesChatEngine

logger = logging.getLogger(__name__)

ENGINE_NAME = "Summaries-DSPy"


def init_engine(settings):
return SummariesDspyChatEngine(settings)


## LLM Client classes for (Question) Decomposer and (Guru Card) Summarizer used by the chat engine.
## DSPy LLM clients require different handling than non-DSPy clients.
## Non-DSPy clients also use the prompt generated (but not yet optimized) by DSPy.


class DecomposerDspyClient:
def __init__(self, prompts, settings):
self.prompts = prompts

if os.environ.get("DSP_CACHEBOOL").lower() != "false":
logger.warning("DSP_CACHEBOOL should be set to True to get different responses for retry attempts")

if "predictor" not in settings:
settings["predictor"] = self.decomposer_predictor
self.decomposer_client = engines.create_llm_client(settings)

def decomposer_predictor(self, message):
prediction = self.prompts.decomposer(question=message)
derived_questions = json.loads(prediction.answer)
if "Answer" in derived_questions:
# For OpenAI 'gpt-4-turbo' in json mode
derived_questions = derived_questions["Answer"]
return derived_questions

def generate_derived_questions(self, query):
# generate_reponse() indirectly calls decomposer_predictor()
return self.decomposer_client.generate_reponse(query)


class DecomposerLlmClient:
def __init__(self, prompts, settings):
self.prompts = prompts
self.decomposer_client = engines.create_llm_client(settings)

def generate_derived_questions(self, query):
response = call_llm_with_dspy_prompt(self.decomposer_client, self.prompts.decomposer, question=query)
return json.loads(response)


class SummarizerDspyClient:
def __init__(self, prompts, settings):
self.prompts = prompts

settings["predictor"] = None
self.summarizer_client = engines.create_llm_client(settings)

def summarizer(self, **kwargs):
with dspy.context(lm=self.summarizer_client.llm):
return self.prompts.summarizer(**kwargs).answer


class SummarizerLlmClient:
def __init__(self, prompts, settings):
self.prompts = prompts

self.summarizer_client = engines.create_llm_client(settings)

def summarizer(self, **kwargs):
return call_llm_with_dspy_prompt(self.summarizer_client, self.prompts.summarizer, **kwargs)


def call_llm_with_dspy_prompt(llm_client, dspy_predict_obj: dspy.Predict, **template_inputs):
template = signature_to_template(dspy_predict_obj.signature)
# demos are for in-context learning
dspy_prompt = template({"demos": []} | template_inputs)
logger.debug("Prompt: %s", dspy_prompt)
response = llm_client.generate_reponse(dspy_prompt)
logger.debug("Response object: %s", response)
return response


class DSPyPrompts:
@cached_property
def decomposer(self):
class DecomposeQuestion(dspy.Signature):
"""Rephrase and decompose into multiple questions so that we can search for relevant public benefits eligibility requirements. \
Be concise -- only respond with JSON. Only output the questions as a JSON list: ["question1", "question2", ...]. \
The question is: {question}"""

# TODO: Incorporate https://gist.github.com/hugodutka/6ef19e197feec9e4ce42c3b6994a919d

question = dspy.InputField()
answer = dspy.OutputField(desc='["question1", "question2", ...]')

return dspy.Predict(DecomposeQuestion)

@cached_property
def summarizer(self):
class SummarizeCardGivenQuestion(dspy.Signature):
"""Summarize the following context into 1 sentence without explicitly answering the question(s): {context_question}
Context: {context}
"""

context_question = dspy.InputField()
context = dspy.InputField()
answer = dspy.OutputField()

return dspy.Predict(SummarizeCardGivenQuestion)


## Summaries (using DSPy) Chat Engine


class SummariesDspyChatEngine(SummariesChatEngine):
def _init_llms(self, settings):
prompts = DSPyPrompts()

if settings["model"].startswith("dspy ::"):
decomposer_client = DecomposerDspyClient(prompts, settings.copy())
else:
decomposer_client = DecomposerLlmClient(prompts, settings.copy())
self.decomposer = decomposer_client.generate_derived_questions

if "model2" in settings:
settings["model"] = settings.pop("model2")
if "temperature2" in settings:
settings["temperature"] = settings.pop("temperature2")
if settings["model"].startswith("dspy ::"):
summarizer_client = SummarizerDspyClient(prompts, settings.copy())
else:
summarizer_client = SummarizerLlmClient(prompts, settings.copy())
self.summarizer = summarizer_client.summarizer
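Putting it together, a hypothetical usage sketch of the renamed factory with the new engine. The engine name, setting keys, and gen_response() all appear elsewhere in this commit; the question text is invented.

```python
# Hypothetical usage sketch. _init_llms() above pops "model2"/"temperature2"
# into "model"/"temperature" before creating the summarizer client.
from chatbot import engines

settings = {
    "chat_engine": "Summaries-DSPy",
    "model": "mock :: llm",      # decomposer LLM
    "temperature": 0.1,
    "model2": "mock :: llm",     # summarizer LLM
    "temperature2": 0.1,
    "retrieve_k": 4,
}
engine = engines.create_engine("Summaries-DSPy", settings)
response = engine.gen_response("Can my household get both SNAP and WIC?")
```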