Skip to content

Commit

Permalink
refactor(agents-api): Move gather_messages to its own model
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 14, 2024
1 parent b1fc2a4 commit 3ab7758
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 86 deletions.
22 changes: 22 additions & 0 deletions agents-api/agents_api/models/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Module: agents_api/models/chat
This package groups the model-layer helpers used by the session chat flow.
It re-exports:
- gather_messages: collects the past messages of a session and, when recall is enabled, the document references that match the new input messages.
- prepare_chat_context: used by the chat route to obtain the chat context of a session.
- get_cached_response / set_cached_response: presumably read and store cached chat responses (see the respective modules for their exact contract).
"""

# ruff: noqa: F401, F403, F405

from .gather_messages import gather_messages
from .get_cached_response import get_cached_response
from .prepare_chat_context import prepare_chat_context
from .set_cached_response import set_cached_response
82 changes: 82 additions & 0 deletions agents-api/agents_api/models/chat/gather_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from agents_api.autogen.Chat import ChatInput

from ...autogen.openapi_model import DocReference, History
from ...clients import embed
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ..docs.search_docs_hybrid import search_docs_hybrid
from ..entry.get_history import get_history
from ..utils import (
partialclass,
rewrap_exceptions,
)


@rewrap_exceptions(
    {
        QueryException: partialclass(HTTPException, status_code=400),
        ValidationError: partialclass(HTTPException, status_code=400),
        TypeError: partialclass(HTTPException, status_code=400),
    }
)
@beartype
async def gather_messages(
    *,
    developer: Developer,
    session_id: UUID,
    chat_context: ChatContext,
    chat_input: ChatInput,
):
    """Collect the past messages of a session and, optionally, matching docs.

    Args:
        developer: Developer on whose behalf the history and docs are queried
            (only ``developer.id`` is used).
        session_id: Session whose history is loaded.
        chat_context: Source of the active agent and the users whose docs
            are searched.
        chat_input: Incoming chat request; ``messages`` and ``recall`` are read.

    Returns:
        A ``(past_messages, doc_references)`` tuple. ``past_messages`` is a
        list of dumped history entries; ``doc_references`` is ``[]`` when
        ``chat_input.recall`` is falsy, otherwise the result of
        ``search_docs_hybrid``.

    Raises:
        AssertionError: If ``chat_input.messages`` is empty.
        HTTPException: ``QueryException``, ``ValidationError`` and ``TypeError``
            are mapped to a 400 response class by the ``rewrap_exceptions``
            decorator configuration above.
    """
    new_raw_messages = [msg.model_dump() for msg in chat_input.messages]
    recall = chat_input.recall

    assert len(new_raw_messages) > 0

    # Get the session history
    history: History = get_history(
        developer_id=developer.id,
        session_id=session_id,
        allowed_sources=["api_request", "api_response", "tool_response", "summarizer"],
    )

    # Keep leaf nodes only.
    # Build the set of relation heads once instead of once per entry.
    head_ids = {relation.head for relation in history.relations}
    past_messages = [
        entry.model_dump() for entry in history.entries if entry.id not in head_ids
    ]

    if not recall:
        return past_messages, []

    # Search matching docs: all new messages are embedded together
    # (join_inputs=True) and the first returned embedding is used as the query.
    [query_embedding, *_] = await embed.embed(
        inputs=[
            f"{msg.get('name') or msg['role']}: {msg['content']}"
            for msg in new_raw_messages
        ],
        join_inputs=True,
    )
    # The text part of the hybrid search uses only the latest new message.
    query_text = new_raw_messages[-1]["content"]

    # List all the applicable owners to search docs from
    active_agent_id = chat_context.get_active_agent().id
    user_ids = [user.id for user in chat_context.users]
    owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)]

    doc_references: list[DocReference] = search_docs_hybrid(
        developer_id=developer.id,
        owners=owners,
        query=query_text,
        query_embedding=query_embedding,
    )

    return past_messages, doc_references
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ...autogen.openapi_model import make_session
from ...common.protocol.sessions import ChatContext
from ..session.prepare_session_data import prepare_session_data
from ..utils import (
cozo_query,
fix_uuid_if_present,
Expand All @@ -16,7 +17,6 @@
verify_developer_owns_resource_query,
wrap_in_class,
)
from .prepare_session_data import prepare_session_data


@rewrap_exceptions(
Expand Down
6 changes: 1 addition & 5 deletions agents-api/agents_api/models/entry/create_entries.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from uuid import UUID, uuid4

from beartype import beartype
Expand Down Expand Up @@ -34,6 +33,7 @@
"id": UUID(d.pop("entry_id")),
**d,
},
_kind="inserted",
)
@cozo_query
@beartype
Expand All @@ -55,10 +55,6 @@ def create_entries(
item["entry_id"] = item.pop("id", None) or str(uuid4())
item["created_at"] = (item.get("created_at") or utcnow()).timestamp()

if not item.get("token_count"):
item["token_count"] = len(json.dumps(item)) // 3.5
item["tokenizer"] = "character_count"

cols, rows = cozo_process_mutate_data(data_dicts)

# Construct a datalog query to insert the processed entries into the 'cozodb' database.
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/entry/get_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def get_history(
content,
source,
token_count,
tokenizer,
created_at,
timestamp,
},
Expand All @@ -75,6 +76,7 @@ def get_history(
"content": content,
"source": source,
"token_count": token_count,
"tokenizer": tokenizer,
"created_at": created_at,
"timestamp": timestamp
}
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/entry/list_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def list_entries(
content,
source,
token_count,
tokenizer,
created_at,
timestamp,
] := *entries {{
Expand All @@ -75,6 +76,7 @@ def list_entries(
content,
source,
token_count,
tokenizer,
created_at,
timestamp,
}},
Expand Down
3 changes: 0 additions & 3 deletions agents-api/agents_api/models/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@
from .create_or_update_session import create_or_update_session
from .create_session import create_session
from .delete_session import delete_session
from .get_cached_response import get_cached_response
from .get_session import get_session
from .list_sessions import list_sessions
from .patch_session import patch_session
from .prepare_chat_context import prepare_chat_context
from .prepare_session_data import prepare_session_data
from .set_cached_response import set_cached_response
from .update_session import update_session
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,17 @@ def wrap_in_class(
cls: Type[ModelT] | Callable[..., ModelT],
one: bool = False,
transform: Callable[[dict], dict] | None = None,
_kind: str | None = None,
):
def decorator(func: Callable[P, pd.DataFrame]):
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ModelT | list[ModelT]:
df = func(*args, **kwargs)

# Convert df to list of dicts
if _kind:
df = df[df["_kind"] == _kind]

data = df.to_dict(orient="records")

nonlocal transform
Expand Down
88 changes: 17 additions & 71 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,76 +9,20 @@
ChatResponse,
ChunkChatResponse,
CreateEntryRequest,
DocReference,
History,
MessageChatResponse,
)
from ...clients import embed, litellm
from ...clients import litellm
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ...common.utils.datetime import utcnow
from ...common.utils.template import render_template
from ...dependencies.developer_id import get_developer_data
from ...models.docs.search_docs_hybrid import search_docs_hybrid
from ...models.chat.gather_messages import gather_messages
from ...models.chat.prepare_chat_context import prepare_chat_context
from ...models.entry.create_entries import create_entries
from ...models.entry.get_history import get_history
from ...models.session.prepare_chat_context import prepare_chat_context
from .router import router


async def get_messages(
    *,
    developer: Developer,
    session_id: UUID,
    new_raw_messages: list[dict],
    chat_context: ChatContext,
    recall: bool,
):
    """Collect the past messages of a session and, optionally, matching docs.

    Args:
        developer: Developer on whose behalf the history and docs are queried
            (only ``developer.id`` is used).
        session_id: Session whose history is loaded.
        new_raw_messages: Dumped new input messages; must be non-empty. Each
            dict is read for ``role``, ``content`` and optionally ``name``.
        chat_context: Source of the active agent and the users whose docs
            are searched.
        recall: When falsy, no doc search is performed.

    Returns:
        A ``(past_messages, doc_references)`` tuple. ``doc_references`` is
        ``[]`` when ``recall`` is falsy, otherwise the result of
        ``search_docs_hybrid``.

    Raises:
        AssertionError: If ``new_raw_messages`` is empty.
    """
    assert len(new_raw_messages) > 0

    # Get the session history
    history: History = get_history(
        developer_id=developer.id,
        session_id=session_id,
        allowed_sources=["api_request", "api_response", "tool_response", "summarizer"],
    )

    # Keep leaf nodes only.
    # Build the set of relation heads once instead of once per entry.
    head_ids = {relation.head for relation in history.relations}
    past_messages = [
        entry.model_dump() for entry in history.entries if entry.id not in head_ids
    ]

    if not recall:
        return past_messages, []

    # Search matching docs: all new messages are embedded together
    # (join_inputs=True) and the first returned embedding is used as the query.
    [query_embedding, *_] = await embed.embed(
        inputs=[
            f"{msg.get('name') or msg['role']}: {msg['content']}"
            for msg in new_raw_messages
        ],
        join_inputs=True,
    )
    # The text part of the hybrid search uses only the latest new message.
    query_text = new_raw_messages[-1]["content"]

    # List all the applicable owners to search docs from
    active_agent_id = chat_context.get_active_agent().id
    user_ids = [user.id for user in chat_context.users]
    owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)]

    doc_references: list[DocReference] = search_docs_hybrid(
        developer_id=developer.id,
        owners=owners,
        query=query_text,
        query_embedding=query_embedding,
    )

    return past_messages, doc_references


@router.post(
"/sessions/{session_id}/chat",
status_code=HTTP_201_CREATED,
Expand All @@ -87,7 +31,7 @@ async def get_messages(
async def chat(
developer: Annotated[Developer, Depends(get_developer_data)],
session_id: UUID,
input: ChatInput,
chat_input: ChatInput,
background_tasks: BackgroundTasks,
) -> ChatResponse:
# First get the chat context
Expand All @@ -97,18 +41,17 @@ async def chat(
)

# Merge the settings and prepare environment
chat_context.merge_settings(input)
chat_context.merge_settings(chat_input)
settings: dict = chat_context.settings.model_dump()
env: dict = chat_context.get_chat_environment()
new_raw_messages = [msg.model_dump() for msg in input.messages]
new_raw_messages = [msg.model_dump() for msg in chat_input.messages]

# Render the messages
past_messages, doc_references = await get_messages(
past_messages, doc_references = await gather_messages(
developer=developer,
session_id=session_id,
new_raw_messages=new_raw_messages,
chat_context=chat_context,
recall=input.recall,
chat_input=chat_input,
)

env["docs"] = doc_references
Expand All @@ -118,7 +61,7 @@ async def chat(
# Get the tools
tools = settings.get("tools") or chat_context.get_active_tools()

# Truncate the messages if necessary
# TODO: Truncate the messages if necessary
if chat_context.session.context_overflow == "truncate":
# messages = messages[-settings["max_tokens"] :]
raise NotImplementedError("Truncation is not yet implemented")
Expand All @@ -133,11 +76,12 @@ async def chat(
)

# Save the input and the response to the session history
if input.save:
# TODO: Count the number of tokens before saving it to the session

if chat_input.save:
new_entries = [
CreateEntryRequest(**msg, source="api_request") for msg in new_messages
CreateEntryRequest.from_model_input(
model=settings["model"], **msg, source="api_request"
)
for msg in new_messages
]

background_tasks.add_task(
Expand All @@ -156,7 +100,9 @@ async def chat(
raise NotImplementedError("Adaptive context is not yet implemented")

# Return the response
chat_response_class = ChunkChatResponse if input.stream else MessageChatResponse
chat_response_class = (
ChunkChatResponse if chat_input.stream else MessageChatResponse
)
chat_response: ChatResponse = chat_response_class(
id=uuid4(),
created_at=utcnow(),
Expand Down
12 changes: 6 additions & 6 deletions agents-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 3ab7758

Please sign in to comment.