Skip to content

Commit

Permalink
optimizers working, but appending any example
Browse files Browse the repository at this point in the history
  • Loading branch information
vintrocode authored and VVoruganti committed Feb 22, 2024
1 parent 8c5845f commit 07651cb
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 22 deletions.
26 changes: 26 additions & 0 deletions example/discord/honcho-dspy-personas/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
intents.messages = True
intents.message_content = True
intents.members = True
intents.reactions = True # Enable reactions intent

# app_id = str(uuid1())
app_id = "vince-dspy-personas"
Expand All @@ -18,6 +19,9 @@

bot = discord.Bot(intents=intents)

thumbs_up_messages = []
thumbs_down_messages = []

@bot.event
async def on_ready():
print(f'We have logged in as {bot.user}')
Expand Down Expand Up @@ -67,6 +71,28 @@ async def on_message(message):

session.create_message(is_user=False, content=response)

@bot.event
async def on_reaction_add(reaction, user):
    """Collect 👍/👎 feedback on messages as training signal.

    Appends the reacted-to message's content to the module-level
    ``thumbs_up_messages`` / ``thumbs_down_messages`` lists. Reactions
    with any other emoji are ignored.

    Args:
        reaction: The discord Reaction object that was added.
        user: The member who added the reaction.
    """
    # Ignore the bot's own reactions so it cannot vote on itself.
    if user == bot.user:
        return

    # Hoist repeated attribute lookups; removed the previously unused
    # user_id / location_id locals (dead code).
    emoji = str(reaction.emoji)
    content = reaction.message.content

    if emoji == '👍':
        thumbs_up_messages.append(content)
        print(f"Added to thumbs up: {content}")
    elif emoji == '👎':
        thumbs_down_messages.append(content)
        print(f"Added to thumbs down: {content}")

    # TODO: append these to the examples list within the user state json
    # object. That step will need the author/channel ids, e.g.
    # f"discord_{reaction.message.author.id}" and
    # str(reaction.message.channel.id).



@bot.slash_command(name = "restart", description = "Restart the Conversation")
async def restart(ctx):
user_id=f"discord_{str(ctx.author.id)}"
Expand Down
20 changes: 13 additions & 7 deletions example/discord/honcho-dspy-personas/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def __init__(self) -> None:
pass

@classmethod
async def generate_state_commentary(cls, chat_history: List[Message], input: str) -> str:
async def generate_state_commentary(cls, existing_states: List[str], chat_history: List[Message], input: str) -> str:
"""Generate a commentary on the current state of the user"""
# format existing states
existing_states = "\n".join(existing_states)
# format prompt
state_commentary = ChatPromptTemplate.from_messages([
cls.system_state_commentary
Expand All @@ -55,23 +57,27 @@ async def generate_state_commentary(cls, chat_history: List[Message], input: str
# inference
response = await chain.ainvoke({
"chat_history": chat_history,
"user_input": input
"user_input": input,
"existing_states": existing_states,
})
# return output
return response.content

@classmethod
async def generate_state_label(cls, state_commentary: str) -> str:
async def generate_state_label(cls, existing_states: List[str], state_commentary: str) -> str:
"""Generate a state label from a commetary on the user's state"""
# format existing states
existing_states = "\n".join(existing_states)
# format prompt
state_labeling = ChatPromptTemplate.from_messages([
cls.system_state_labeling
cls.system_state_labeling,
])
# LCEL
chain = state_labeling | cls.lc_gpt_4
# inference
response = await chain.ainvoke({
"state_commentary": state_commentary
"state_commentary": state_commentary,
"existing_states": existing_states,
})
# return output
return response.content
Expand Down Expand Up @@ -102,8 +108,8 @@ async def generate_state(cls, existing_states: List[str], chat_history: List[Mes
""""Determine the user's state from the current conversation state"""

# Generate label
state_commentary = await cls.generate_state_commentary(chat_history, input)
state_label = await cls.generate_state_label(state_commentary)
state_commentary = await cls.generate_state_commentary(existing_states, chat_history, input)
state_label = await cls.generate_state_label(existing_states, state_commentary)

# Determine if state is new
# if True, it doesn't exist, state is new
Expand Down
30 changes: 23 additions & 7 deletions example/discord/honcho-dspy-personas/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import dspy
from typing import List
from dspy import Example
from typing import List, Optional
from dspy.teleprompt import BootstrapFewShot
from dotenv import load_dotenv
from chain import StateExtractor, format_chat_history
Expand All @@ -11,7 +12,7 @@
load_dotenv()

# Configure DSPy
dspy_gpt4 = dspy.OpenAI(model="gpt-4")
dspy_gpt4 = dspy.OpenAI(model="gpt-4", max_tokens=1000)
dspy.settings.configure(lm=dspy_gpt4)


Expand All @@ -33,18 +34,23 @@ class ChatWithThought(dspy.Module):
generate_thought = dspy.Predict(Thought)
generate_response = dspy.Predict(Response)

def forward(self, user_message: Message, session: Session, chat_input: str):
def forward(self, chat_input: str, user_message: Optional[Message] = None, session: Optional[Session] = None):
# call the thought predictor
thought = self.generate_thought(user_input=chat_input)
session.create_metamessage(user_message, metamessage_type="thought", content=thought.thought)

if session and user_message:
session.create_metamessage(user_message, metamessage_type="thought", content=thought.thought)

# call the response predictor
response = self.generate_response(user_input=chat_input, thought=thought.thought)

return response.response
# remove ai prefix
response = response.response.replace("ai:", "").strip()

return response

user_state_storage = {}
async def chat(user_message: Message, session: Session, chat_history: List[Message], input: str, optimization_threshold=5):
async def chat(user_message: Message, session: Session, chat_history: List[Message], input: str, optimization_threshold=3):
# first we need to see if the user has any existing states
existing_states = list(user_state_storage.keys())

Expand All @@ -66,6 +72,8 @@ async def chat(user_message: Message, session: Session, chat_history: List[Messa

# Optimize the state's chat module if we've reached the optimization threshold
examples = user_state_data["examples"]
print(f"Num examples: {len(examples)}")

if len(examples) >= optimization_threshold:
# Optimize chat module
optimizer = BootstrapFewShot(metric=metric)
Expand All @@ -74,10 +82,18 @@ async def chat(user_message: Message, session: Session, chat_history: List[Messa
user_state_data["chat_module"] = compiled_chat_module.dump_state()
user_chat_module = compiled_chat_module

# save to file for debugging purposes
# compiled_chat_module.save("module.json")


# use that pipeline to generate a response
chat_input = format_chat_history(chat_history, user_input=input)

response = user_chat_module(user_message=user_message, session=session, chat_input=chat_input)
dspy_gpt4.inspect_history(n=2)

# append example
example = Example(chat_input=chat_input, assessment_dimension=user_state, response=response).with_inputs('chat_input')
examples.append(example)
user_state_storage[user_state]["examples"] = examples

return response
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
_type: prompt
input_variables:
["chat_history", "user_input"]
["existing_states", "chat_history", "user_input"]
template: >
Your job is to make a prediction about the task the user might be engaging in. Some people might be researching, exploring curiosities, or just asking questions for general inquiry. Provide commentary that would shed light on the "mode" the user might be in.
existing states: ```{existing_states}```
chat history: ```{chat_history}```
user input: ```{user_input}```
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
_type: prompt
input_variables:
["state_commentary"]
["state_commentary", "existing_states"]
template: >
Your job is to label the task the user might be engaging in. Some people might be conducting research, exploring a interest, or just asking questions for general inquiry.
Your job is to label the state the user might be in. Some people might be conducting research, exploring a interest, or just asking questions for general inquiry.
commentary: ```{state_commentary}```
Prior states, from oldest to most recent: ```
{existing_states}
````
Take into account the user's prior states when making your prediction. Output your prediction as a concise, single word label.
Output your prediction as a concise, single word label.
17 changes: 13 additions & 4 deletions example/discord/honcho-dspy-personas/response_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,30 @@

class MessageResponseAssess(dspy.Signature):
"""Assess the quality of a response along the specified dimension."""
user_message = dspy.InputField()
chat_input = dspy.InputField()
ai_response = dspy.InputField()
gold_response = dspy.InputField()
assessment_dimension = dspy.InputField()
assessment_answer = dspy.OutputField(desc="Good or not")


def metric(user_message, ai_response, assessment_dimension):
def metric(example, ai_response, trace=None):
"""Assess the quality of a response along the specified dimension."""

assessment_dimension = example.assessment_dimension
chat_input = example.chat_input
gold_response = example.response

with dspy.context(lm=gpt4T):
assessment_result = dspy.Predict(MessageResponseAssess)(
user_message=user_message,
chat_input=chat_input,
ai_response=ai_response,
gold_response=gold_response,
assessment_dimension=assessment_dimension
)

is_positive = assessment_result.assessment_answer.lower() == 'good'

gpt4T.inspect_history(n=3)

return is_positive
return is_positive

0 comments on commit 07651cb

Please sign in to comment.