Skip to content

Commit

Permalink
feat: Add prepare chat context query
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Dec 19, 2024
1 parent a7d10da commit 63dc1bc
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 108 deletions.
12 changes: 5 additions & 7 deletions agents-api/agents_api/queries/chat/gather_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import ChatInput, DocReference, History
from ...clients import litellm
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ..docs.search_docs_by_embedding import search_docs_by_embedding
from ..docs.search_docs_by_text import search_docs_by_text
from ..docs.search_docs_hybrid import search_docs_hybrid
from ..entry.get_history import get_history
from ..session.get_session import get_session
# from ..docs.search_docs_by_embedding import search_docs_by_embedding
# from ..docs.search_docs_by_text import search_docs_by_text
# from ..docs.search_docs_hybrid import search_docs_hybrid
# from ..entry.get_history import get_history
from ..sessions.get_session import get_session
from ..utils import (
partialclass,
rewrap_exceptions,
Expand All @@ -25,7 +24,6 @@

@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
Expand Down
225 changes: 124 additions & 101 deletions agents-api/agents_api/queries/chat/prepare_chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,118 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pydantic import ValidationError

from ...common.protocol.sessions import ChatContext, make_session
from ..session.prepare_session_data import prepare_session_data
from ..utils import (
cozo_query,
fix_uuid_if_present,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
pg_query,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


# TODO: implement this part
# @rewrap_exceptions(
# {
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@wrap_in_class(
ChatContext,
one=True,
transform=lambda d: {
query = """
SELECT * FROM
(
SELECT jsonb_agg(u) AS users FROM (
SELECT
session_lookup.participant_id,
users.user_id AS id,
users.developer_id,
users.name,
users.about,
users.created_at,
users.updated_at,
users.metadata
FROM session_lookup
INNER JOIN users ON session_lookup.participant_id = users.user_id
WHERE
session_lookup.developer_id = $1 AND
session_id = $2 AND
session_lookup.participant_type = 'user'
) u
) AS users,
(
SELECT jsonb_agg(a) AS agents FROM (
SELECT
session_lookup.participant_id,
agents.agent_id AS id,
agents.developer_id,
agents.canonical_name,
agents.name,
agents.about,
agents.instructions,
agents.model,
agents.created_at,
agents.updated_at,
agents.metadata,
agents.default_settings
FROM session_lookup
INNER JOIN agents ON session_lookup.participant_id = agents.agent_id
WHERE
session_lookup.developer_id = $1 AND
session_id = $2 AND
session_lookup.participant_type = 'agent'
) a
) AS agents,
(
SELECT to_jsonb(s) AS session FROM (
SELECT
sessions.session_id AS id,
sessions.developer_id,
sessions.situation,
sessions.system_template,
sessions.created_at,
sessions.metadata,
sessions.render_templates,
sessions.token_budget,
sessions.context_overflow,
sessions.forward_tool_calls,
sessions.recall_options
FROM sessions
WHERE
developer_id = $1 AND
session_id = $2
LIMIT 1
) s
) AS session,
(
SELECT jsonb_agg(r) AS toolsets FROM (
SELECT
session_lookup.participant_id,
tools.tool_id as id,
tools.developer_id,
tools.agent_id,
tools.task_id,
tools.task_version,
tools.type,
tools.name,
tools.description,
tools.spec,
tools.updated_at,
tools.created_at
FROM session_lookup
INNER JOIN tools ON session_lookup.participant_id = tools.agent_id
WHERE
session_lookup.developer_id = $1 AND
session_id = $2 AND
session_lookup.participant_type = 'agent'
) r
) AS toolsets
"""


def _transform(d):
toolsets = {}
for tool in d["toolsets"]:
agent_id = tool["agent_id"]
if agent_id in toolsets:
toolsets[agent_id].append(tool)
else:
toolsets[agent_id] = [tool]

return {
**d,
"session": make_session(
agents=[a["id"] for a in d["agents"]],
Expand All @@ -40,103 +122,44 @@
),
"toolsets": [
{
**ts,
"agent_id": agent_id,
"tools": [
{
tool["type"]: tool.pop("spec"),
**tool,
}
for tool in map(fix_uuid_if_present, ts["tools"])
for tool in tools
],
}
for ts in d["toolsets"]
for agent_id, tools in toolsets.items()
],
},
}


# TODO: implement this part
# @rewrap_exceptions(
# {
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@wrap_in_class(
ChatContext,
one=True,
transform=_transform,
)
@cozo_query
@pg_query
@beartype
def prepare_chat_context(
async def prepare_chat_context(
*,
developer_id: UUID,
session_id: UUID,
) -> tuple[list[str], dict]:
) -> tuple[list[str], list]:
"""
Executes a complex query to retrieve memory context based on session ID.
"""

[*_, session_data_query], sd_vars = prepare_session_data.__wrapped__(
developer_id=developer_id, session_id=session_id
)

session_data_fields = ("session", "agents", "users")

session_data_query += """
:create _session_data_json {
agents: [Json],
users: [Json],
session: Json,
}
"""

toolsets_query = """
input[session_id] <- [[to_uuid($session_id)]]
tools_by_agent[agent_id, collect(tool)] :=
input[session_id],
*session_lookup{
session_id,
participant_id: agent_id,
participant_type: "agent",
},
*tools { agent_id, tool_id, name, type, spec, description, updated_at, created_at },
tool = {
"id": tool_id,
"name": name,
"type": type,
"spec": spec,
"description": description,
"updated_at": updated_at,
"created_at": created_at,
}
agent_toolsets[collect(toolset)] :=
tools_by_agent[agent_id, tools],
toolset = {
"agent_id": agent_id,
"tools": tools,
}
?[toolsets] :=
agent_toolsets[toolsets]
:create _toolsets_json {
toolsets: [Json],
}
"""

combine_query = f"""
?[{', '.join(session_data_fields)}, toolsets] :=
*_session_data_json {{ {', '.join(session_data_fields)} }},
*_toolsets_json {{ toolsets }}
:limit 1
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id, "sessions", session_id=session_id
),
session_data_query,
toolsets_query,
combine_query,
]

return (
queries,
{
"session_id": str(session_id),
**sd_vars,
},
[query],
[developer_id, session_id],
)

0 comments on commit 63dc1bc

Please sign in to comment.