From 3ab775839ab1f7d678b293a2b5efb16aaeac8607 Mon Sep 17 00:00:00 2001
From: Diwank Tomer
Date: Wed, 14 Aug 2024 18:54:42 -0400
Subject: [PATCH] refactor(agents-api): Move gather_messages to its own model

Signed-off-by: Diwank Tomer
---
 agents-api/agents_api/models/chat/__init__.py | 22 +++++
 .../agents_api/models/chat/gather_messages.py | 82 +++++++++++++++++
 .../{session => chat}/get_cached_response.py  |  0
 .../{session => chat}/prepare_chat_context.py |  2 +-
 .../{session => chat}/set_cached_response.py  |  0
 .../agents_api/models/entry/create_entries.py |  6 +-
 .../agents_api/models/entry/get_history.py    |  2 +
 .../agents_api/models/entry/list_entries.py   |  2 +
 .../agents_api/models/session/__init__.py     |  3 -
 agents-api/agents_api/models/utils.py         |  4 +
 .../agents_api/routers/sessions/chat.py       | 88 ++++--------------
 agents-api/poetry.lock                        | 12 +--
 12 files changed, 137 insertions(+), 86 deletions(-)
 create mode 100644 agents-api/agents_api/models/chat/__init__.py
 create mode 100644 agents-api/agents_api/models/chat/gather_messages.py
 rename agents-api/agents_api/models/{session => chat}/get_cached_response.py (100%)
 rename agents-api/agents_api/models/{session => chat}/prepare_chat_context.py (98%)
 rename agents-api/agents_api/models/{session => chat}/set_cached_response.py (100%)

diff --git a/agents-api/agents_api/models/chat/__init__.py b/agents-api/agents_api/models/chat/__init__.py
new file mode 100644
index 000000000..428b72572
--- /dev/null
+++ b/agents-api/agents_api/models/chat/__init__.py
@@ -0,0 +1,22 @@
+"""
+Module: agents_api/models/chat
+
+This module is responsible for managing chat-related operations within the application, particularly preparing everything a session needs before a model response can be generated. It serves as a core component of the chat system, covering message gathering, chat-context preparation, and response caching.
+
+Main functionalities include:
+- Gathering the session history and recalling matching document references for a chat turn.
+- Preparing the chat context (session, agents, users, and tools) for a session.
+- Retrieving previously cached model responses.
+- Storing model responses in the cache for later reuse.
+
+The module interacts with other parts of the application, such as the entry and docs modules, to assemble the inputs needed to serve a chat request. Its role is crucial in enabling the chat features exposed by the sessions router.
+
+This documentation aims to provide clear, concise, and sufficient context for new developers or contributors to understand the module's role without needing to dive deep into the code immediately.
+""" + +# ruff: noqa: F401, F403, F405 + +from .gather_messages import gather_messages +from .get_cached_response import get_cached_response +from .prepare_chat_context import prepare_chat_context +from .set_cached_response import set_cached_response diff --git a/agents-api/agents_api/models/chat/gather_messages.py b/agents-api/agents_api/models/chat/gather_messages.py new file mode 100644 index 000000000..2a3c0eca1 --- /dev/null +++ b/agents-api/agents_api/models/chat/gather_messages.py @@ -0,0 +1,82 @@ +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException +from pycozo.client import QueryException +from pydantic import ValidationError + +from agents_api.autogen.Chat import ChatInput + +from ...autogen.openapi_model import DocReference, History +from ...clients import embed +from ...common.protocol.developers import Developer +from ...common.protocol.sessions import ChatContext +from ..docs.search_docs_hybrid import search_docs_hybrid +from ..entry.get_history import get_history +from ..utils import ( + partialclass, + rewrap_exceptions, +) + + +@rewrap_exceptions( + { + QueryException: partialclass(HTTPException, status_code=400), + ValidationError: partialclass(HTTPException, status_code=400), + TypeError: partialclass(HTTPException, status_code=400), + } +) +@beartype +async def gather_messages( + *, + developer: Developer, + session_id: UUID, + chat_context: ChatContext, + chat_input: ChatInput, +): + new_raw_messages = [msg.model_dump() for msg in chat_input.messages] + recall = chat_input.recall + + assert len(new_raw_messages) > 0 + + # Get the session history + history: History = get_history( + developer_id=developer.id, + session_id=session_id, + allowed_sources=["api_request", "api_response", "tool_response", "summarizer"], + ) + + # Keep leaf nodes only + relations = history.relations + past_messages = [ + entry.model_dump() + for entry in history.entries + if entry.id not in {r.head for r in relations} + ] + + if not recall: + return past_messages, [] + + # Search matching docs + [query_embedding, *_] = await embed.embed( + inputs=[ + f"{msg.get('name') or msg['role']}: {msg['content']}" + for msg in new_raw_messages + ], + join_inputs=True, + ) + query_text = new_raw_messages[-1]["content"] + + # List all the applicable owners to search docs from + active_agent_id = chat_context.get_active_agent().id + user_ids = [user.id for user in chat_context.users] + owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)] + + doc_references: list[DocReference] = search_docs_hybrid( + developer_id=developer.id, + owners=owners, + query=query_text, + query_embedding=query_embedding, + ) + + return past_messages, doc_references diff --git a/agents-api/agents_api/models/session/get_cached_response.py b/agents-api/agents_api/models/chat/get_cached_response.py similarity index 100% rename from agents-api/agents_api/models/session/get_cached_response.py rename to agents-api/agents_api/models/chat/get_cached_response.py diff --git a/agents-api/agents_api/models/session/prepare_chat_context.py b/agents-api/agents_api/models/chat/prepare_chat_context.py similarity index 98% rename from agents-api/agents_api/models/session/prepare_chat_context.py rename to agents-api/agents_api/models/chat/prepare_chat_context.py index 83e6c6f8b..0e076bc20 100644 --- a/agents-api/agents_api/models/session/prepare_chat_context.py +++ b/agents-api/agents_api/models/chat/prepare_chat_context.py @@ -7,6 +7,7 @@ from ...autogen.openapi_model import 
make_session from ...common.protocol.sessions import ChatContext +from ..session.prepare_session_data import prepare_session_data from ..utils import ( cozo_query, fix_uuid_if_present, @@ -16,7 +17,6 @@ verify_developer_owns_resource_query, wrap_in_class, ) -from .prepare_session_data import prepare_session_data @rewrap_exceptions( diff --git a/agents-api/agents_api/models/session/set_cached_response.py b/agents-api/agents_api/models/chat/set_cached_response.py similarity index 100% rename from agents-api/agents_api/models/session/set_cached_response.py rename to agents-api/agents_api/models/chat/set_cached_response.py diff --git a/agents-api/agents_api/models/entry/create_entries.py b/agents-api/agents_api/models/entry/create_entries.py index 01193d395..31c8b4d01 100644 --- a/agents-api/agents_api/models/entry/create_entries.py +++ b/agents-api/agents_api/models/entry/create_entries.py @@ -1,4 +1,3 @@ -import json from uuid import UUID, uuid4 from beartype import beartype @@ -34,6 +33,7 @@ "id": UUID(d.pop("entry_id")), **d, }, + _kind="inserted", ) @cozo_query @beartype @@ -55,10 +55,6 @@ def create_entries( item["entry_id"] = item.pop("id", None) or str(uuid4()) item["created_at"] = (item.get("created_at") or utcnow()).timestamp() - if not item.get("token_count"): - item["token_count"] = len(json.dumps(item)) // 3.5 - item["tokenizer"] = "character_count" - cols, rows = cozo_process_mutate_data(data_dicts) # Construct a datalog query to insert the processed entries into the 'cozodb' database. diff --git a/agents-api/agents_api/models/entry/get_history.py b/agents-api/agents_api/models/entry/get_history.py index 49eb7b929..fed658ea5 100644 --- a/agents-api/agents_api/models/entry/get_history.py +++ b/agents-api/agents_api/models/entry/get_history.py @@ -62,6 +62,7 @@ def get_history( content, source, token_count, + tokenizer, created_at, timestamp, }, @@ -75,6 +76,7 @@ def get_history( "content": content, "source": source, "token_count": token_count, + "tokenizer": tokenizer, "created_at": created_at, "timestamp": timestamp } diff --git a/agents-api/agents_api/models/entry/list_entries.py b/agents-api/agents_api/models/entry/list_entries.py index 0c47d9a74..da2341c4c 100644 --- a/agents-api/agents_api/models/entry/list_entries.py +++ b/agents-api/agents_api/models/entry/list_entries.py @@ -65,6 +65,7 @@ def list_entries( content, source, token_count, + tokenizer, created_at, timestamp, ] := *entries {{ @@ -75,6 +76,7 @@ def list_entries( content, source, token_count, + tokenizer, created_at, timestamp, }}, diff --git a/agents-api/agents_api/models/session/__init__.py b/agents-api/agents_api/models/session/__init__.py index b4092611f..bc5f7fbb4 100644 --- a/agents-api/agents_api/models/session/__init__.py +++ b/agents-api/agents_api/models/session/__init__.py @@ -14,11 +14,8 @@ from .create_or_update_session import create_or_update_session from .create_session import create_session from .delete_session import delete_session -from .get_cached_response import get_cached_response from .get_session import get_session from .list_sessions import list_sessions from .patch_session import patch_session -from .prepare_chat_context import prepare_chat_context from .prepare_session_data import prepare_session_data -from .set_cached_response import set_cached_response from .update_session import update_session diff --git a/agents-api/agents_api/models/utils.py b/agents-api/agents_api/models/utils.py index 411349ad3..ee2ba3fdd 100644 --- a/agents-api/agents_api/models/utils.py +++ 
b/agents-api/agents_api/models/utils.py @@ -232,6 +232,7 @@ def wrap_in_class( cls: Type[ModelT] | Callable[..., ModelT], one: bool = False, transform: Callable[[dict], dict] | None = None, + _kind: str | None = None, ): def decorator(func: Callable[P, pd.DataFrame]): @wraps(func) @@ -239,6 +240,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> ModelT | list[ModelT]: df = func(*args, **kwargs) # Convert df to list of dicts + if _kind: + df = df[df["_kind"] == _kind] + data = df.to_dict(orient="records") nonlocal transform diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index e6103c15e..8d0355de2 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -9,76 +9,20 @@ ChatResponse, ChunkChatResponse, CreateEntryRequest, - DocReference, - History, MessageChatResponse, ) -from ...clients import embed, litellm +from ...clients import litellm from ...common.protocol.developers import Developer from ...common.protocol.sessions import ChatContext from ...common.utils.datetime import utcnow from ...common.utils.template import render_template from ...dependencies.developer_id import get_developer_data -from ...models.docs.search_docs_hybrid import search_docs_hybrid +from ...models.chat.gather_messages import gather_messages +from ...models.chat.prepare_chat_context import prepare_chat_context from ...models.entry.create_entries import create_entries -from ...models.entry.get_history import get_history -from ...models.session.prepare_chat_context import prepare_chat_context from .router import router -async def get_messages( - *, - developer: Developer, - session_id: UUID, - new_raw_messages: list[dict], - chat_context: ChatContext, - recall: bool, -): - assert len(new_raw_messages) > 0 - - # Get the session history - history: History = get_history( - developer_id=developer.id, - session_id=session_id, - allowed_sources=["api_request", "api_response", "tool_response", "summarizer"], - ) - - # Keep leaf nodes only - relations = history.relations - past_messages = [ - entry.model_dump() - for entry in history.entries - if entry.id not in {r.head for r in relations} - ] - - if not recall: - return past_messages, [] - - # Search matching docs - [query_embedding, *_] = await embed.embed( - inputs=[ - f"{msg.get('name') or msg['role']}: {msg['content']}" - for msg in new_raw_messages - ], - join_inputs=True, - ) - query_text = new_raw_messages[-1]["content"] - - # List all the applicable owners to search docs from - active_agent_id = chat_context.get_active_agent().id - user_ids = [user.id for user in chat_context.users] - owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)] - - doc_references: list[DocReference] = search_docs_hybrid( - developer_id=developer.id, - owners=owners, - query=query_text, - query_embedding=query_embedding, - ) - - return past_messages, doc_references - - @router.post( "/sessions/{session_id}/chat", status_code=HTTP_201_CREATED, @@ -87,7 +31,7 @@ async def get_messages( async def chat( developer: Annotated[Developer, Depends(get_developer_data)], session_id: UUID, - input: ChatInput, + chat_input: ChatInput, background_tasks: BackgroundTasks, ) -> ChatResponse: # First get the chat context @@ -97,18 +41,17 @@ async def chat( ) # Merge the settings and prepare environment - chat_context.merge_settings(input) + chat_context.merge_settings(chat_input) settings: dict = chat_context.settings.model_dump() env: dict = 
chat_context.get_chat_environment() - new_raw_messages = [msg.model_dump() for msg in input.messages] + new_raw_messages = [msg.model_dump() for msg in chat_input.messages] # Render the messages - past_messages, doc_references = await get_messages( + past_messages, doc_references = await gather_messages( developer=developer, session_id=session_id, - new_raw_messages=new_raw_messages, chat_context=chat_context, - recall=input.recall, + chat_input=chat_input, ) env["docs"] = doc_references @@ -118,7 +61,7 @@ async def chat( # Get the tools tools = settings.get("tools") or chat_context.get_active_tools() - # Truncate the messages if necessary + # TODO: Truncate the messages if necessary if chat_context.session.context_overflow == "truncate": # messages = messages[-settings["max_tokens"] :] raise NotImplementedError("Truncation is not yet implemented") @@ -133,11 +76,12 @@ async def chat( ) # Save the input and the response to the session history - if input.save: - # TODO: Count the number of tokens before saving it to the session - + if chat_input.save: new_entries = [ - CreateEntryRequest(**msg, source="api_request") for msg in new_messages + CreateEntryRequest.from_model_input( + model=settings["model"], **msg, source="api_request" + ) + for msg in new_messages ] background_tasks.add_task( @@ -156,7 +100,9 @@ async def chat( raise NotImplementedError("Adaptive context is not yet implemented") # Return the response - chat_response_class = ChunkChatResponse if input.stream else MessageChatResponse + chat_response_class = ( + ChunkChatResponse if chat_input.stream else MessageChatResponse + ) chat_response: ChatResponse = chat_response_class( id=uuid4(), created_at=utcnow(), diff --git a/agents-api/poetry.lock b/agents-api/poetry.lock index b8c0a42b7..6a29c9f4a 100644 --- a/agents-api/poetry.lock +++ b/agents-api/poetry.lock @@ -2149,13 +2149,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" [[package]] name = "langchain-core" -version = "0.2.30" +version = "0.2.31" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_core-0.2.30-py3-none-any.whl", hash = "sha256:ea7eccb9566dd51b2b74bd292c4239d843a77cdba8ffae2b5edf7000d70d4194"}, - {file = "langchain_core-0.2.30.tar.gz", hash = "sha256:552ec586698140062cd299a83bad7e308f925b496d306b62529579c6fb122f7a"}, + {file = "langchain_core-0.2.31-py3-none-any.whl", hash = "sha256:b4daf5ddc23c0c3d8c5fd1a6c118f95fb5d0f96067b43f2c5935e1cd572e4374"}, + {file = "langchain_core-0.2.31.tar.gz", hash = "sha256:afb2089d4c10842d2477dc5cfa9ae9feb415c1421c6ef9aa608fea879ee41769"}, ] [package.dependencies] @@ -2255,13 +2255,13 @@ dev = ["Sphinx (>=5.1.1)", "black (==23.12.1)", "build (>=0.10.0)", "coverage (> [[package]] name = "litellm" -version = "1.43.9" +version = "1.43.12" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.43.9-py3-none-any.whl", hash = "sha256:54253281139e61f130b7e1a613a11f7a5ee896c2ee8536b0ca9a5ffbfce4c5f0"}, - {file = "litellm-1.43.9.tar.gz", hash = "sha256:c397a14c9b851f007f09c99e5a28606f7f122fdb4ae954931220f60e9edc6918"}, + {file = "litellm-1.43.12-py3-none-any.whl", hash = "sha256:f2c5f498a079df6eb8448ac41704367a389ea679a22e195c79b7963ede5cc462"}, + {file = "litellm-1.43.12.tar.gz", hash = "sha256:719eca58904942465dfd827e9d8f317112996ef481db71f9562f5263a553c74a"}, ] 
[package.dependencies]