diff --git a/agents-api/agents_api/common/protocol/sessions.py b/agents-api/agents_api/common/protocol/sessions.py index f12ff8108..121afe702 100644 --- a/agents-api/agents_api/common/protocol/sessions.py +++ b/agents-api/agents_api/common/protocol/sessions.py @@ -116,6 +116,7 @@ def get_chat_environment(self) -> dict[str, dict | list[dict]]: "session": self.session.model_dump(), "agents": [agent.model_dump() for agent in self.agents], "current_agent": current_agent.model_dump(), + "agent": current_agent.model_dump(), "users": [user.model_dump() for user in self.users], "settings": settings, "tools": [tool.model_dump() for tool in tools], diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index e391fc716..90e397a9d 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -46,12 +46,9 @@ async def chat( # Merge the settings and prepare environment chat_context.merge_settings(chat_input) settings: dict = chat_context.settings.model_dump() - env: dict = chat_context.get_chat_environment() settings["model"] = f"openai/{settings['model']}" # litellm proxy idiosyncracy - - # Render the messages - new_raw_messages = [msg.model_dump() for msg in chat_input.messages] + # Get the past messages and doc references past_messages, doc_references = await gather_messages( developer=developer, session_id=session_id, @@ -59,8 +56,31 @@ async def chat( chat_input=chat_input, ) + # Prepare the environment + env: dict = chat_context.get_chat_environment() env["docs"] = doc_references - new_messages = await render_template(new_raw_messages, variables=env) + + # Render the system message + if situation := chat_context.session.situation: + system_message = dict( + role="system", + content=situation, + ) + + system_messages: list[dict] = await render_template( + [system_message], variables=env + ) + past_messages = system_messages + past_messages + + # Render the incoming messages + new_raw_messages = [msg.model_dump() for msg in chat_input.messages] + + if chat_context.session.render_templates: + new_messages = await render_template(new_raw_messages, variables=env) + else: + new_messages = new_raw_messages + + # Combine the past messages with the new messages messages = past_messages + new_messages # Get the tools @@ -74,15 +94,17 @@ async def chat( # FIXME: Hotfix for datetime not serializable. Needs investigation messages = [ - msg.model_dump() if hasattr(msg, "model_dump") else msg - for msg in messages + msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in messages ] messages = [ - dict(role=m["role"], content=m["content"], user=m.get("user")) - for m in messages + dict(role=m["role"], content=m["content"], user=m.get("user")) for m in messages ] + from pprint import pprint + + pprint(messages) + # Get the response from the model model_response = await litellm.acompletion( messages=messages, @@ -104,7 +126,9 @@ async def chat( # Add the response to the new entries new_entries.append( CreateEntryRequest.from_model_input( - model=settings["model"], **model_response.choices[0].model_dump()['message'], source="api_response" + model=settings["model"], + **model_response.choices[0].model_dump()["message"], + source="api_response", ) ) background_tasks.add_task(