From 7d73a6aa7fdf83767940d4af987a8d2830781ef7 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Thu, 19 Dec 2024 15:26:58 +0300 Subject: [PATCH] feat: Add prepare chat context query --- .../queries/chat/gather_messages.py | 12 +- .../queries/chat/prepare_chat_context.py | 161 +++++++++++++----- 2 files changed, 124 insertions(+), 49 deletions(-) diff --git a/agents-api/agents_api/queries/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py index 28dc6607f..34a7c564f 100644 --- a/agents-api/agents_api/queries/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -3,18 +3,17 @@ from beartype import beartype from fastapi import HTTPException -from pycozo.client import QueryException from pydantic import ValidationError from ...autogen.openapi_model import ChatInput, DocReference, History from ...clients import litellm from ...common.protocol.developers import Developer from ...common.protocol.sessions import ChatContext -from ..docs.search_docs_by_embedding import search_docs_by_embedding -from ..docs.search_docs_by_text import search_docs_by_text -from ..docs.search_docs_hybrid import search_docs_hybrid -from ..entry.get_history import get_history -from ..session.get_session import get_session +# from ..docs.search_docs_by_embedding import search_docs_by_embedding +# from ..docs.search_docs_by_text import search_docs_by_text +# from ..docs.search_docs_hybrid import search_docs_hybrid +# from ..entry.get_history import get_history +from ..sessions.get_session import get_session from ..utils import ( partialclass, rewrap_exceptions, @@ -25,7 +24,6 @@ @rewrap_exceptions( { - QueryException: partialclass(HTTPException, status_code=400), ValidationError: partialclass(HTTPException, status_code=400), TypeError: partialclass(HTTPException, status_code=400), } diff --git a/agents-api/agents_api/queries/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py index 645abd2fe..23926ea4c 100644 --- a/agents-api/agents_api/queries/chat/prepare_chat_context.py +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -4,7 +4,6 @@ from beartype import beartype from ...common.protocol.sessions import ChatContext, make_session -from ..session.prepare_session_data import prepare_session_data from ..utils import ( pg_query, wrap_in_class, @@ -13,37 +12,108 @@ ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") -# tool_id, name, type, spec, description, updated_at, created_at -get_agent_query = """ -SELECT - session_lookup.agent_id, - tools.tool_id, - tools.name, - tools.type, - tools.spec, - tools.description, - tools.updated_at, - tools.created_at -FROM session_lookup -JOIN ON tools -WHERE - developer_id = $1 AND - session_id = $2 and participant_type == 'agent AND - session_lookup.agent_id = tools.agent_id + +query = """ +SELECT * FROM +( + SELECT jsonb_agg(u) AS users FROM ( + SELECT + session_lookup.participant_id, + users.user_id AS id, + users.developer_id, + users.name, + users.about, + users.created_at, + users.updated_at, + users.metadata + FROM session_lookup + INNER JOIN users ON session_lookup.participant_id = users.user_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'user' + ) u +) AS users, +( + SELECT jsonb_agg(a) AS agents FROM ( + SELECT + session_lookup.participant_id, + agents.agent_id AS id, + agents.developer_id, + agents.canonical_name, + agents.name, + agents.about, + agents.instructions, + agents.model, + agents.created_at, + agents.updated_at, + agents.metadata, + agents.default_settings + FROM session_lookup + INNER JOIN agents ON session_lookup.participant_id = agents.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) a +) AS agents, +( + SELECT to_jsonb(s) AS session FROM ( + SELECT + sessions.session_id AS id, + sessions.developer_id, + sessions.situation, + sessions.system_template, + sessions.created_at, + sessions.metadata, + sessions.render_templates, + sessions.token_budget, + sessions.context_overflow, + sessions.forward_tool_calls, + sessions.recall_options + FROM sessions + WHERE + developer_id = $1 AND + session_id = $2 + LIMIT 1 + ) s +) AS session, +( + SELECT jsonb_agg(r) AS toolsets FROM ( + SELECT + session_lookup.participant_id, + tools.tool_id as id, + tools.developer_id, + tools.agent_id, + tools.task_id, + tools.task_version, + tools.type, + tools.name, + tools.description, + tools.spec, + tools.updated_at, + tools.created_at + FROM session_lookup + INNER JOIN tools ON session_lookup.participant_id = tools.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) r +) AS toolsets """ -# TODO: implement this part -# @rewrap_exceptions( -# { -# ValidationError: partialclass(HTTPException, status_code=400), -# TypeError: partialclass(HTTPException, status_code=400), -# } -# ) -@wrap_in_class( - ChatContext, - one=True, - transform=lambda d: { +def _transform(d): + toolsets = {} + for tool in d["toolsets"]: + agent_id = tool["agent_id"] + if agent_id in toolsets: + toolsets[agent_id].append(tool) + else: + toolsets[agent_id] = [tool] + + return { **d, "session": make_session( agents=[a["id"] for a in d["agents"]], @@ -52,37 +122,44 @@ ), "toolsets": [ { - **ts, + "agent_id": agent_id, "tools": [ { tool["type"]: tool.pop("spec"), **tool, } - for tool in map(fix_uuid_if_present, ts["tools"]) + for tool in tools ], } - for ts in d["toolsets"] + for agent_id, tools in toolsets.items() ], - }, + } + + +# TODO: implement this part +# @rewrap_exceptions( +# { +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class( + ChatContext, + one=True, + transform=_transform, ) @pg_query @beartype -def prepare_chat_context( +async def prepare_chat_context( *, developer_id: UUID, session_id: UUID, -) -> tuple[list[str], dict]: +) -> tuple[list[str], list]: """ Executes a complex query to retrieve memory context based on session ID. """ - [*_, session_data_query], sd_vars = prepare_session_data.__wrapped__( - developer_id=developer_id, session_id=session_id - ) - - queries = [session_data_query, get_agent_query] - return ( - queries, + [query], [developer_id, session_id], )