From 892226d79db1151eeda07dfbeb1390d155f93f50 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Sun, 11 Aug 2024 19:23:42 -0400 Subject: [PATCH] feat(agents-api): Preliminary implementation of session.chat Signed-off-by: Diwank Tomer --- .../agents_api/models/entry/create_entries.py | 4 + .../agents_api/models/entry/delete_entries.py | 5 +- agents-api/agents_api/models/utils.py | 22 +++++ .../agents_api/routers/sessions/chat.py | 86 +++++++++++++++---- 4 files changed, 98 insertions(+), 19 deletions(-) diff --git a/agents-api/agents_api/models/entry/create_entries.py b/agents-api/agents_api/models/entry/create_entries.py index 72551dacc..68d644266 100644 --- a/agents-api/agents_api/models/entry/create_entries.py +++ b/agents-api/agents_api/models/entry/create_entries.py @@ -12,6 +12,7 @@ from ...common.utils.messages import content_to_json from ..utils import ( cozo_query, + mark_session_updated_query, partialclass, rewrap_exceptions, verify_developer_id_query, @@ -41,6 +42,7 @@ def create_entries( developer_id: UUID, session_id: UUID, data: list[CreateEntryRequest], + mark_session_as_updated: bool = True, ) -> tuple[list[str], dict]: developer_id = str(developer_id) session_id = str(session_id) @@ -76,6 +78,8 @@ def create_entries( verify_developer_owns_resource_query( developer_id, "sessions", session_id=session_id ), + mark_session_as_updated + and mark_session_updated_query(developer_id, session_id), create_query, ] diff --git a/agents-api/agents_api/models/entry/delete_entries.py b/agents-api/agents_api/models/entry/delete_entries.py index a156275b0..f64bfbf73 100644 --- a/agents-api/agents_api/models/entry/delete_entries.py +++ b/agents-api/agents_api/models/entry/delete_entries.py @@ -9,6 +9,7 @@ from ...common.utils.datetime import utcnow from ..utils import ( cozo_query, + mark_session_updated_query, partialclass, rewrap_exceptions, verify_developer_id_query, @@ -37,7 +38,7 @@ @cozo_query @beartype def delete_entries_for_session( - *, developer_id: UUID, session_id: UUID + *, developer_id: UUID, session_id: UUID, mark_session_as_updated: bool = True ) -> tuple[list[str], dict]: """ Constructs and returns a datalog query for deleting entries associated with a given session ID from the 'cozodb' database. @@ -79,6 +80,8 @@ def delete_entries_for_session( verify_developer_owns_resource_query( developer_id, "sessions", session_id=session_id ), + mark_session_as_updated + and mark_session_updated_query(developer_id, session_id), delete_query, ] diff --git a/agents-api/agents_api/models/utils.py b/agents-api/agents_api/models/utils.py index 73c3f3e0b..2939b2208 100644 --- a/agents-api/agents_api/models/utils.py +++ b/agents-api/agents_api/models/utils.py @@ -64,6 +64,28 @@ class NewCls(cls): return NewCls +def mark_session_updated_query(developer_id: UUID | str, session_id: UUID | str) -> str: + return f""" + input[developer_id, session_id] <- [[ + to_uuid("{str(developer_id)}"), + to_uuid("{str(session_id)}"), + ]] + + ?[developer_id, session_id, updated_at] := + input[developer_id, session_id], + *sessions {{ + session_id, + }}, + updated_at = [floor(now()), true] + + :update sessions {{ + developer_id, + session_id, + updated_at, + }} + """ + + def verify_developer_id_query(developer_id: UUID | str) -> str: return f""" matched[count(developer_id)] := diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index 368fe2ea2..225e25163 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -1,17 +1,24 @@ from typing import Annotated -from uuid import UUID +from uuid import UUID, uuid4 -from fastapi import Depends -from pydantic import UUID4 +from fastapi import BackgroundTasks, Depends from starlette.status import HTTP_201_CREATED from ...autogen.openapi_model import ( ChatInput, ChatResponse, + CreateEntryRequest, + DocReference, History, ) +from ...clients.embed import embed +from ...clients.litellm import acompletion +from ...common.protocol.developers import Developer +from ...common.protocol.sessions import ChatContext from ...common.utils.template import render_template -from ...dependencies.developer_id import get_developer_id +from ...dependencies.developer_id import get_developer_data +from ...models.docs.search_docs_hybrid import search_docs_hybrid +from ...models.entry.create_entries import create_entries from ...models.entry.get_history import get_history from ...models.session.prepare_chat_context import prepare_chat_context from .router import router @@ -23,48 +30,91 @@ tags=["sessions", "chat"], ) async def chat( - x_developer_id: Annotated[UUID4, Depends(get_developer_id)], + developer: Annotated[Developer, Depends(get_developer_data)], session_id: UUID, data: ChatInput, + background_tasks: BackgroundTasks, ) -> ChatResponse: # First get the chat context - chat_context = prepare_chat_context( - developer_id=x_developer_id, - agent_id=data.agent_id, + chat_context: ChatContext = prepare_chat_context( + developer_id=developer.id, session_id=session_id, ) + assert isinstance(chat_context, ChatContext) # Merge the settings and prepare environment - request_settings = data.settings - chat_context.merge_settings(request_settings) - + chat_context.merge_settings(data) + settings: dict = chat_context.settings.model_dump() env: dict = chat_context.get_chat_environment() # Get the session history history: History = get_history( - developer_id=x_developer_id, + developer_id=developer.id, session_id=session_id, allowed_sources=["api_request", "api_response", "tool_response", "summarizer"], ) # Keep leaf nodes only relations = history.relations - past_entries = [ + past_messages = [ entry.model_dump() for entry in history.entries if entry.id not in {r.head for r in relations} ] - past_messages = render_template(past_entries, variables=env) + new_raw_messages = [msg.model_dump() for msg in data.messages] - messages = past_messages + [msg.model_dump() for msg in data.messages] + # Search matching docs + [query_embedding, *_] = await embed( + inputs=[ + f"{msg.get('name') or msg['role']}: {msg['content']}" + for msg in new_raw_messages + ], + join_inputs=True, + ) + query_text = new_raw_messages[-1]["content"] - # TODO: Implement the chat logic here - print(messages) + doc_references: list[DocReference] = search_docs_hybrid( + developer_id=developer.id, + owner_type="agent", + owner_id=chat_context.get_active_agent().id, + query=query_text, + query_embedding=query_embedding, + ) + + # Render the messages + env["docs"] = doc_references + new_messages = render_template(new_raw_messages, variables=env) + messages = past_messages + new_messages # Get the response from the model + model_response = await acompletion( + messages=messages, + **settings, + user=str(developer.id), + tags=developer.tags, + ) # Save the input and the response to the session history + new_entries = [CreateEntryRequest(**msg) for msg in new_messages] + background_tasks.add_task( + create_entries, + developer_id=developer.id, + session_id=session_id, + data=new_entries, + mark_session_as_updated=True, + ) # Return the response - raise NotImplementedError() + response_json = model_response.model_dump() + response_json.pop("id", None) + + chat_response: ChatResponse = ChatResponse( + **response_json, + id=uuid4(), + created_at=model_response.created, + jobs=[], + docs=doc_references, + ) + + return chat_response