Skip to content

Commit

Permalink
feat(agents-api): Preliminary implementation of session.chat
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 11, 2024
1 parent aefe3ab commit 892226d
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 19 deletions.
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/entry/create_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ...common.utils.messages import content_to_json
from ..utils import (
cozo_query,
mark_session_updated_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
Expand Down Expand Up @@ -41,6 +42,7 @@ def create_entries(
developer_id: UUID,
session_id: UUID,
data: list[CreateEntryRequest],
mark_session_as_updated: bool = True,
) -> tuple[list[str], dict]:
developer_id = str(developer_id)
session_id = str(session_id)
Expand Down Expand Up @@ -76,6 +78,8 @@ def create_entries(
verify_developer_owns_resource_query(
developer_id, "sessions", session_id=session_id
),
mark_session_as_updated
and mark_session_updated_query(developer_id, session_id),
create_query,
]

Expand Down
5 changes: 4 additions & 1 deletion agents-api/agents_api/models/entry/delete_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ...common.utils.datetime import utcnow
from ..utils import (
cozo_query,
mark_session_updated_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
Expand Down Expand Up @@ -37,7 +38,7 @@
@cozo_query
@beartype
def delete_entries_for_session(
*, developer_id: UUID, session_id: UUID
*, developer_id: UUID, session_id: UUID, mark_session_as_updated: bool = True
) -> tuple[list[str], dict]:
"""
Constructs and returns a datalog query for deleting entries associated with a given session ID from the 'cozodb' database.
Expand Down Expand Up @@ -79,6 +80,8 @@ def delete_entries_for_session(
verify_developer_owns_resource_query(
developer_id, "sessions", session_id=session_id
),
mark_session_as_updated
and mark_session_updated_query(developer_id, session_id),
delete_query,
]

Expand Down
22 changes: 22 additions & 0 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,28 @@ class NewCls(cls):
return NewCls


def mark_session_updated_query(developer_id: UUID | str, session_id: UUID | str) -> str:
return f"""
input[developer_id, session_id] <- [[
to_uuid("{str(developer_id)}"),
to_uuid("{str(session_id)}"),
]]
?[developer_id, session_id, updated_at] :=
input[developer_id, session_id],
*sessions {{
session_id,
}},
updated_at = [floor(now()), true]
:update sessions {{
developer_id,
session_id,
updated_at,
}}
"""


def verify_developer_id_query(developer_id: UUID | str) -> str:
return f"""
matched[count(developer_id)] :=
Expand Down
86 changes: 68 additions & 18 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from typing import Annotated
from uuid import UUID
from uuid import UUID, uuid4

from fastapi import Depends
from pydantic import UUID4
from fastapi import BackgroundTasks, Depends
from starlette.status import HTTP_201_CREATED

from ...autogen.openapi_model import (
ChatInput,
ChatResponse,
CreateEntryRequest,
DocReference,
History,
)
from ...clients.embed import embed
from ...clients.litellm import acompletion
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ...common.utils.template import render_template
from ...dependencies.developer_id import get_developer_id
from ...dependencies.developer_id import get_developer_data
from ...models.docs.search_docs_hybrid import search_docs_hybrid
from ...models.entry.create_entries import create_entries
from ...models.entry.get_history import get_history
from ...models.session.prepare_chat_context import prepare_chat_context
from .router import router
Expand All @@ -23,48 +30,91 @@
tags=["sessions", "chat"],
)
async def chat(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
developer: Annotated[Developer, Depends(get_developer_data)],
session_id: UUID,
data: ChatInput,
background_tasks: BackgroundTasks,
) -> ChatResponse:
# First get the chat context
chat_context = prepare_chat_context(
developer_id=x_developer_id,
agent_id=data.agent_id,
chat_context: ChatContext = prepare_chat_context(
developer_id=developer.id,
session_id=session_id,
)
assert isinstance(chat_context, ChatContext)

# Merge the settings and prepare environment
request_settings = data.settings
chat_context.merge_settings(request_settings)

chat_context.merge_settings(data)
settings: dict = chat_context.settings.model_dump()
env: dict = chat_context.get_chat_environment()

# Get the session history
history: History = get_history(
developer_id=x_developer_id,
developer_id=developer.id,
session_id=session_id,
allowed_sources=["api_request", "api_response", "tool_response", "summarizer"],
)

# Keep leaf nodes only
relations = history.relations
past_entries = [
past_messages = [
entry.model_dump()
for entry in history.entries
if entry.id not in {r.head for r in relations}
]

past_messages = render_template(past_entries, variables=env)
new_raw_messages = [msg.model_dump() for msg in data.messages]

messages = past_messages + [msg.model_dump() for msg in data.messages]
# Search matching docs
[query_embedding, *_] = await embed(
inputs=[
f"{msg.get('name') or msg['role']}: {msg['content']}"
for msg in new_raw_messages
],
join_inputs=True,
)
query_text = new_raw_messages[-1]["content"]

# TODO: Implement the chat logic here
print(messages)
doc_references: list[DocReference] = search_docs_hybrid(
developer_id=developer.id,
owner_type="agent",
owner_id=chat_context.get_active_agent().id,
query=query_text,
query_embedding=query_embedding,
)

# Render the messages
env["docs"] = doc_references
new_messages = render_template(new_raw_messages, variables=env)
messages = past_messages + new_messages

# Get the response from the model
model_response = await acompletion(
messages=messages,
**settings,
user=str(developer.id),
tags=developer.tags,
)

# Save the input and the response to the session history
new_entries = [CreateEntryRequest(**msg) for msg in new_messages]
background_tasks.add_task(
create_entries,
developer_id=developer.id,
session_id=session_id,
data=new_entries,
mark_session_as_updated=True,
)

# Return the response
raise NotImplementedError()
response_json = model_response.model_dump()
response_json.pop("id", None)

chat_response: ChatResponse = ChatResponse(
**response_json,
id=uuid4(),
created_at=model_response.created,
jobs=[],
docs=doc_references,
)

return chat_response

0 comments on commit 892226d

Please sign in to comment.