Skip to content

Commit

Permalink
fix(agents-api): Misc fixes for chat endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
HamadaSalhab committed Dec 27, 2024
1 parent 718d612 commit 23de839
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def get_chat_environment(self) -> dict[str, dict | list[dict] | None]:
"agents": [agent.model_dump() for agent in self.agents],
"current_agent": current_agent.model_dump(),
"agent": current_agent.model_dump(),
"user": self.users[0].model_dump() if len(self.users) > 0 else None,
"users": [user.model_dump() for user in self.users],
"settings": settings,
"tools": [tool.model_dump() for tool in tools],
Expand Down
12 changes: 6 additions & 6 deletions agents-api/agents_api/queries/chat/gather_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi import HTTPException
from pydantic import ValidationError

from ...autogen.openapi_model import ChatInput, DocReference, History
from ...autogen.openapi_model import ChatInput, DocReference, History, Session
from ...clients import litellm
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
Expand Down Expand Up @@ -42,7 +42,7 @@ async def gather_messages(
assert len(new_raw_messages) > 0

# Get the session history
history: History = get_history(
history: History = await get_history(
developer_id=developer.id,
session_id=session_id,
allowed_sources=["api_request", "api_response", "tool_response", "summarizer"],
Expand All @@ -69,7 +69,7 @@ async def gather_messages(
return past_messages, []

# Get recall options
session = get_session(
session: Session = await get_session(
developer_id=developer.id,
session_id=session_id,
)
Expand Down Expand Up @@ -117,20 +117,20 @@ async def gather_messages(
doc_references: list[DocReference] = []
match recall_options.mode:
case "vector":
doc_references: list[DocReference] = search_docs_by_embedding(
doc_references: list[DocReference] = await search_docs_by_embedding(
developer_id=developer.id,
owners=owners,
query_embedding=query_embedding,
)
case "hybrid":
doc_references: list[DocReference] = search_docs_hybrid(
doc_references: list[DocReference] = await search_docs_hybrid(
developer_id=developer.id,
owners=owners,
query=query_text,
query_embedding=query_embedding,
)
case "text":
doc_references: list[DocReference] = search_docs_by_text(
doc_references: list[DocReference] = await search_docs_by_text(
developer_id=developer.id,
owners=owners,
query=query_text,
Expand Down
27 changes: 17 additions & 10 deletions agents-api/agents_api/queries/chat/prepare_chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from beartype import beartype

from sqlglot import parse_one
from ...common.protocol.sessions import ChatContext, make_session
from ..utils import (
pg_query,
Expand All @@ -13,7 +13,7 @@
T = TypeVar("T")


sql_query = """
sql_query = parse_one("""
SELECT * FROM
(
SELECT jsonb_agg(u) AS users FROM (
Expand Down Expand Up @@ -65,6 +65,7 @@
sessions.situation,
sessions.system_template,
sessions.created_at,
sessions.updated_at,
sessions.metadata,
sessions.render_templates,
sessions.token_budget,
Expand All @@ -86,7 +87,6 @@
tools.developer_id,
tools.agent_id,
tools.task_id,
tools.task_version,
tools.type,
tools.name,
tools.description,
Expand All @@ -100,23 +100,28 @@
session_id = $2 AND
session_lookup.participant_type = 'agent'
) r
) AS toolsets"""
) AS toolsets""").sql(pretty=True)


def _transform(d):
toolsets = {}
for tool in d["toolsets"]:

# Default to empty lists when users/agents are not present
d["users"] = d.get("users") or []
d["agents"] = d.get("agents") or []

for tool in d.get("toolsets") or []:
agent_id = tool["agent_id"]
if agent_id in toolsets:
toolsets[agent_id].append(tool)
else:
toolsets[agent_id] = [tool]

return {
transformed_data = {
**d,
"session": make_session(
agents=[a["id"] for a in d["agents"]],
users=[u["id"] for u in d["users"]],
agents=[a["id"] for a in d.get("agents") or []],
users=[u["id"] for u in d.get("users") or []],
**d["session"],
),
"toolsets": [
Expand All @@ -134,6 +139,8 @@ def _transform(d):
],
}

return transformed_data


# TODO: implement this part
# @rewrap_exceptions(
Expand All @@ -153,12 +160,12 @@ async def prepare_chat_context(
*,
developer_id: UUID,
session_id: UUID,
) -> tuple[list[str], list]:
) -> tuple[str, list]:
"""
Executes a complex query to retrieve memory context based on session ID.
"""

return (
[sql_query.format()],
sql_query,
[developer_id, session_id],
)

0 comments on commit 23de839

Please sign in to comment.