Skip to content

Commit

Permalink
hotfix(agents-api): Fix session.situation not being rendered
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 16, 2024
1 parent 97cb972 commit d50efaf
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def get_chat_environment(self) -> dict[str, dict | list[dict]]:
"session": self.session.model_dump(),
"agents": [agent.model_dump() for agent in self.agents],
"current_agent": current_agent.model_dump(),
"agent": current_agent.model_dump(),
"users": [user.model_dump() for user in self.users],
"settings": settings,
"tools": [tool.model_dump() for tool in tools],
Expand Down
44 changes: 34 additions & 10 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,41 @@ async def chat(
# Merge the settings and prepare environment
chat_context.merge_settings(chat_input)
settings: dict = chat_context.settings.model_dump()
env: dict = chat_context.get_chat_environment()
settings["model"] = f"openai/{settings['model']}" # litellm proxy idiosyncracy

# Render the messages
new_raw_messages = [msg.model_dump() for msg in chat_input.messages]

# Get the past messages and doc references
past_messages, doc_references = await gather_messages(
developer=developer,
session_id=session_id,
chat_context=chat_context,
chat_input=chat_input,
)

# Prepare the environment
env: dict = chat_context.get_chat_environment()
env["docs"] = doc_references
new_messages = await render_template(new_raw_messages, variables=env)

# Render the system message
if situation := chat_context.session.situation:
system_message = dict(
role="system",
content=situation,
)

system_messages: list[dict] = await render_template(
[system_message], variables=env
)
past_messages = system_messages + past_messages

# Render the incoming messages
new_raw_messages = [msg.model_dump() for msg in chat_input.messages]

if chat_context.session.render_templates:
new_messages = await render_template(new_raw_messages, variables=env)
else:
new_messages = new_raw_messages

# Combine the past messages with the new messages
messages = past_messages + new_messages

# Get the tools
Expand All @@ -74,15 +94,17 @@ async def chat(

# FIXME: Hotfix for datetime not serializable. Needs investigation
messages = [
msg.model_dump() if hasattr(msg, "model_dump") else msg
for msg in messages
msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in messages
]

messages = [
dict(role=m["role"], content=m["content"], user=m.get("user"))
for m in messages
dict(role=m["role"], content=m["content"], user=m.get("user")) for m in messages
]

from pprint import pprint

pprint(messages)

# Get the response from the model
model_response = await litellm.acompletion(
messages=messages,
Expand All @@ -104,7 +126,9 @@ async def chat(
# Add the response to the new entries
new_entries.append(
CreateEntryRequest.from_model_input(
model=settings["model"], **model_response.choices[0].model_dump()['message'], source="api_response"
model=settings["model"],
**model_response.choices[0].model_dump()["message"],
source="api_response",
)
)
background_tasks.add_task(
Expand Down

0 comments on commit d50efaf

Please sign in to comment.