Skip to content

Commit

Permalink
Decouple search logic from respond
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobvm04 committed Nov 1, 2023
1 parent 33478d6 commit a3c045b
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions agent/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,36 @@ def think(cls, cache: Conversation, input: str):
chain.astream({}, {"tags": ["thought"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}),
lambda thought: cache.add_message("thought", AIMessage(content=thought))
)


@classmethod
@sentry_sdk.trace
def respond(cls, cache: Conversation, thought: str, input: str):
"""Generate Bloom's response to the user."""

messages = [
response_prompt = ChatPromptTemplate.from_messages([
cls.system_response,
*cache.messages("response"),
HumanMessage(content=input)
]
])

# apply search step
response_prompt = cls.search_step(response_prompt.format_messages(thought=thought))

chain = response_prompt | cls.llm

cache.add_message("response", HumanMessage(content=input))

return Streamable(
chain.astream({"thought": thought}, {"tags": ["response"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}),
lambda response: cache.add_message("response", AIMessage(content=response))
)


search_messages = ChatPromptTemplate.from_messages(messages).format_messages(thought=thought).copy()
@classmethod
@sentry_sdk.trace
def search_step(cls, messages: list[BaseMessage]):
search_messages = messages.copy()
search_messages.append(SystemMessage(content=f"Reason about whether or not a google search would be benificial to answer the question. Always use it if you are unsure about your knowledge.\n\nf{search_ready_output_parser.get_format_instructions()}"))

search_ready_message = cls.llm.predict_messages(search_messages)
Expand All @@ -107,16 +124,9 @@ def respond(cls, cache: Conversation, thought: str, input: str):
search_result_summary = cls.search_tool.run(search_query_message.content)

messages.append(SystemMessage(content=f"Use the information from these searchs to help answer your question.\nMake sure to not just repeat answers from sources, provide the sources justifications when possible. More detail is better.\n\nRelevant Google Search: {search_query_message.content}\n\n{search_result_summary}\n\nCite your sources via bracket notation with numbers (don't use any other special characters), and include the full links at the end."))

return ChatPromptTemplate.from_messages(messages)

response_prompt = ChatPromptTemplate.from_messages(messages)
chain = response_prompt | cls.llm

cache.add_message("response", HumanMessage(content=input))

return Streamable(
chain.astream({ "thought": thought }, {"tags": ["response"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}),
lambda response: cache.add_message("response", AIMessage(content=response))
)

@classmethod
@sentry_sdk.trace
Expand Down

0 comments on commit a3c045b

Please sign in to comment.