diff --git a/agents-api/agents_api/activities/summarization.py b/agents-api/agents_api/activities/summarization.py
index aa7fa4740..b3ed45ecc 100644
--- a/agents-api/agents_api/activities/summarization.py
+++ b/agents-api/agents_api/activities/summarization.py
@@ -1,86 +1,90 @@
 #!/usr/bin/env python3
 
-import pandas as pd
+import asyncio
+from uuid import UUID, uuid4
+
 from beartype import beartype
 from temporalio import activity
 
-# from agents_api.models.entry.entries_summarization import (
-#     entries_summarization_query,
-#     get_toplevel_entries_query,
-# )
-
-
-# TODO: Implement entry summarization queries
-# SCRUM-3
-def entries_summarization_query(*args, **kwargs) -> pd.DataFrame:
-    return pd.DataFrame()
-
-
-def get_toplevel_entries_query(*args, **kwargs) -> pd.DataFrame:
-    return pd.DataFrame()
-
-
-# TODO: Implement entry summarization activities
-# SCRUM-4
+from agents_api.autogen.openapi_model import Entry
+from agents_api.common.utils.datetime import utcnow
+from agents_api.env import summarization_model_name
+from agents_api.models.entry.entries_summarization import (
+    entries_summarization_query,
+    get_toplevel_entries_query,
+)
+from agents_api.rec_sum.entities import get_entities
+from agents_api.rec_sum.summarize import summarize_messages
+from agents_api.rec_sum.trim import trim_messages
 
 
 @activity.defn
 @beartype
 async def summarization(session_id: str) -> None:
-    raise NotImplementedError()
-
-    # session_id = UUID(session_id)
-    # entries = []
-    # entities_entry_ids = []
-    # for _, row in get_toplevel_entries_query(session_id=session_id).iterrows():
-    #     if row["role"] == "system" and row.get("name") == "entities":
-    #         entities_entry_ids.append(UUID(row["entry_id"], version=4))
-    #     else:
-    #         entries.append(row)
-
-    # assert len(entries) > 0, "no need to summarize on empty entries list"
-
-    # summarized, entities = await asyncio.gather(
-    #     summarize_messages(entries, model=summarization_model_name),
-    #     get_entities(entries, model=summarization_model_name),
-    # )
-    # trimmed_messages = await trim_messages(summarized, model=summarization_model_name)
-    # ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2
-    # new_entities_entry = Entry(
-    #     session_id=session_id,
-    #     source="summarizer",
-    #     role="system",
-    #     name="entities",
-    #     content=entities["content"],
-    #     timestamp=entries[0]["timestamp"] + ts_delta,
-    # )
-
-    # entries_summarization_query(
-    #     session_id=session_id,
-    #     new_entry=new_entities_entry,
-    #     old_entry_ids=entities_entry_ids,
-    # )
-
-    # trimmed_map = {
-    #     m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None
-    # }
-
-    # for idx, msg in enumerate(summarized):
-    #     new_entry = Entry(
-    #         session_id=session_id,
-    #         source="summarizer",
-    #         role="system",
-    #         name="information",
-    #         content=trimmed_map.get(idx, msg["content"]),
-    #         timestamp=entries[-1]["timestamp"] + 0.01,
-    #     )
-
-    #     entries_summarization_query(
-    #         session_id=session_id,
-    #         new_entry=new_entry,
-    #         old_entry_ids=[
-    #             UUID(entries[idx - 1]["entry_id"], version=4)
-    #             for idx in msg["summarizes"]
-    #         ],
-    #     )
+    session_id = UUID(session_id)
+    entries = []
+    entities_entry_ids = []
+    for _, row in get_toplevel_entries_query(session_id=session_id).iterrows():
+        if row["role"] == "system" and row.get("name") == "entities":
+            entities_entry_ids.append(UUID(row["entry_id"], version=4))
+        else:
+            entries.append(row)
+
+    assert len(entries) > 0, "no need to summarize on empty entries list"
+
+    summarized, entities = await asyncio.gather(
+        summarize_messages(entries, model=summarization_model_name),
+        get_entities(entries, model=summarization_model_name),
+    )
+    trimmed_messages = await trim_messages(summarized, model=summarization_model_name)
+    ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2
+    # TODO: set tokenizer, double check token_count calculation
+    new_entities_entry = Entry(
+        id=uuid4(),
+        session_id=session_id,
+        source="summarizer",
+        role="system",
+        name="entities",
+        content=entities["content"],
+        timestamp=entries[0]["timestamp"] + ts_delta,
+        token_count=sum([len(c) // 3.5 for c in entities["content"]]),
+        created_at=utcnow(),
+        tokenizer="",
+    )
+
+    entries_summarization_query(
+        session_id=session_id,
+        new_entry=new_entities_entry,
+        old_entry_ids=entities_entry_ids,
+    )
+
+    trimmed_map = {
+        m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None
+    }
+
+    for idx, msg in enumerate(summarized):
+        content = trimmed_map.get(idx, msg["content"])
+        # TODO: set tokenizer
+        # TODO: calc token count
+        new_entry = Entry(
+            id=uuid4(),
+            session_id=session_id,
+            source="summarizer",
+            role="system",
+            name="information",
+            content=content,
+            timestamp=entries[-1]["timestamp"] + 0.01,
+            token_count=sum([len(c) // 3.5 for c in content]),
+            created_at=utcnow(),
+            tokenizer="",
+        )
+
+        entries_summarization_query(
+            session_id=session_id,
+            new_entry=new_entry,
+            old_entry_ids=[
+                UUID(entries[idx - 1]["entry_id"], version=4)
+                for idx in msg["summarizes"]
+            ],
+        )
diff --git a/agents-api/agents_api/activities/truncation.py b/agents-api/agents_api/activities/truncation.py
index 4742ee6d4..745674fab 100644
--- a/agents-api/agents_api/activities/truncation.py
+++ b/agents-api/agents_api/activities/truncation.py
@@ -4,57 +4,57 @@
 from temporalio import activity
 
 from agents_api.autogen.openapi_model import Entry
-
-# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query
-
-# TODO: Reimplement truncation queries
-# SCRUM-5
+from agents_api.models.entry.delete_entries import delete_entries
+from agents_api.models.entry.entries_summarization import get_toplevel_entries_query
 
 
 def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[UUID]:
-    raise NotImplementedError()
+    result: list[UUID] = []
 
     if not len(messages):
         return messages
 
-    _token_cnt, _offset = 0, 0
-    # if messages[0].role == Role.system:
-    #     token_cnt, offset = messages[0].token_count, 1
+    token_cnt, offset = 0, 0
+    if messages[0].role == "system":
+        token_cnt, offset = messages[0].token_count, 1
 
-    # for m in reversed(messages[offset:]):
-    #     token_cnt += m.token_count
-    #     if token_cnt < token_count_threshold:
-    #         continue
-    #     else:
-    #         result.append(m.id)
+    for m in reversed(messages[offset:]):
+        token_cnt += m.token_count
+        if token_cnt >= token_count_threshold:
+            result.append(m.id)
 
-    # return result
+    return result
 
 
-# TODO: Reimplement truncation activities
-# SCRUM-6
 @activity.defn
 @beartype
-async def truncation(session_id: str, token_count_threshold: int) -> None:
+async def truncation(
+    developer_id: str, session_id: str, token_count_threshold: int
+) -> None:
     session_id = UUID(session_id)
-
-    # delete_entries(
-    #     get_extra_entries(
-    #         [
-    #             Entry(
-    #                 entry_id=row["entry_id"],
-    #                 session_id=session_id,
-    #                 source=row["source"],
-    #                 role=Role(row["role"]),
-    #                 name=row["name"],
-    #                 content=row["content"],
-    #                 created_at=row["created_at"],
-    #                 timestamp=row["timestamp"],
-    #             )
-    #             for _, row in get_toplevel_entries_query(
-    #                 session_id=session_id
-    #             ).iterrows()
-    #         ],
-    #         token_count_threshold,
-    #     ),
-    # )
+    developer_id = UUID(developer_id)
+
+    delete_entries(
+        developer_id=developer_id,
+        session_id=session_id,
+        entry_ids=get_extra_entries(
+            [
+                Entry(
+                    id=row["entry_id"],
+                    session_id=session_id,
+                    source=row["source"],
+                    role=row["role"],
+                    name=row["name"],
+                    content=row["content"],
+                    created_at=row["created_at"],
+                    timestamp=row["timestamp"],
+                    tokenizer=row["tokenizer"],
+                    token_count=row["token_count"],
+                )
+                for _, row in get_toplevel_entries_query(
+                    session_id=session_id
+                ).iterrows()
+            ],
+            token_count_threshold,
+        ),
+    )
diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py
index 5f14b84f6..75d6f796b 100644
--- a/agents-api/agents_api/clients/temporal.py
+++ b/agents-api/agents_api/clients/temporal.py
@@ -70,3 +70,31 @@ async def get_workflow_handle(
     )
 
     return handle
+
+
+async def run_truncation_task(
+    token_count_threshold: int, developer_id: UUID, session_id: UUID, job_id: UUID
+):
+    from ..workflows.truncation import TruncationWorkflow
+
+    client = await get_client()
+
+    await client.start_workflow(
+        TruncationWorkflow.run,
+        args=[str(developer_id), str(session_id), token_count_threshold],
+        task_queue="memory-task-queue",
+        id=str(job_id),
+    )
+
+
+async def run_summarization_task(session_id: UUID, job_id: UUID):
+    from ..workflows.summarization import SummarizationWorkflow
+
+    client = await get_client()
+
+    await client.start_workflow(
+        SummarizationWorkflow.run,
+        args=[str(session_id)],
+        task_queue="memory-task-queue",
+        id=str(job_id),
+    )
diff --git a/agents-api/agents_api/common/protocol/sessions.py b/agents-api/agents_api/common/protocol/sessions.py
index 121afe702..49030a4f3 100644
--- a/agents-api/agents_api/common/protocol/sessions.py
+++ b/agents-api/agents_api/common/protocol/sessions.py
@@ -97,9 +97,14 @@ def get_active_tools(self) -> list[Tool]:
 
             return []
 
         active_agent = self.get_active_agent()
-        active_toolset = next(
-            toolset for toolset in self.toolsets if toolset.agent_id == active_agent.id
-        )
+        try:
+            active_toolset = next(
+                toolset
+                for toolset in self.toolsets
+                if toolset.agent_id == active_agent.id
+            )
+        except StopIteration:
+            return []
 
         return active_toolset.tools
diff --git a/agents-api/agents_api/models/entry/entries_summarization.py b/agents-api/agents_api/models/entry/entries_summarization.py
new file mode 100644
index 000000000..0846d4420
--- /dev/null
+++ b/agents-api/agents_api/models/entry/entries_summarization.py
@@ -0,0 +1,133 @@
+"""This module contains functions for querying and summarizing entry data in the 'cozodb' database."""
+
+from uuid import UUID
+
+from beartype import beartype
+
+from ...autogen.openapi_model import Entry
+from ...common.utils.messages import content_to_json
+from ..utils import cozo_query
+
+
+@cozo_query
+@beartype
+def get_toplevel_entries_query(session_id: UUID) -> tuple[str, dict]:
+    """
+    Retrieves top-level entries from the database for a given session.
+
+    Parameters:
+    - session_id (UUID): The unique identifier for the session.
+
+    Returns:
+    - pd.DataFrame: A DataFrame containing the queried top-level entries.
+    """
+
+    query = """
+        # Construct a datalog query to retrieve entries not summarized by any other entry.
+        input[session_id] <- [[to_uuid($session_id)]]
+        # Define an input table with the session ID to filter entries related to the specific session.
+
+        # Query to retrieve top-level entries that are not summarized by any other entry, ensuring uniqueness.
+        ?[
+            entry_id,
+            session_id,
+            source,
+            role,
+            name,
+            content,
+            token_count,
+            created_at,
+            timestamp,
+            tokenizer,
+        ] :=
+            input[session_id],
+            *entries{
+                entry_id,
+                session_id,
+                source,
+                role,
+                name,
+                content,
+                token_count,
+                created_at,
+                timestamp,
+                tokenizer,
+            },
+            not *relations {
+                relation: "summary_of",
+                tail: entry_id,
+            }
+
+        :sort timestamp
+    """
+
+    return (query, {"session_id": str(session_id)})
+
+
+@cozo_query
+@beartype
+def entries_summarization_query(
+    session_id: UUID, new_entry: Entry, old_entry_ids: list[UUID]
+) -> tuple[str, dict]:
+    """
+    Inserts a new entry and its summarization relations into the database.
+
+    Parameters:
+    - session_id (UUID): The session identifier.
+    - new_entry (Entry): The new entry to be inserted.
+    - old_entry_ids (list[UUID]): List of entry IDs that the new entry summarizes.
+
+    Returns:
+    - pd.DataFrame: A DataFrame containing the result of the insertion operation.
+    """
+
+    # Prepare relations data for insertion, marking the new entry as a summary of the old entries.
+    relations = [
+        [str(new_entry.id), "summary_of", str(old_id)] for old_id in old_entry_ids
+    ]
+    # Create a list of relations indicating which entries the new entry summarizes.
+
+    # Convert the new entry's content into JSON format for storage.
+
+    entries = [
+        [
+            str(new_entry.id),
+            str(session_id),
+            new_entry.source,
+            new_entry.role,
+            new_entry.name or "",
+            content_to_json(new_entry.content),
+            new_entry.token_count,
+            new_entry.tokenizer,
+            new_entry.timestamp,
+        ]
+    ]
+
+    query = """
+    {
+        ?[entry_id, session_id, source, role, name, content, token_count, tokenizer, timestamp] <- $entries
+
+        :insert entries {
+            entry_id,
+            session_id,
+            source,
+            role,
+            name, =>
+            content,
+            token_count,
+            tokenizer,
+            timestamp,
+        }
+    }
+    {
+        ?[head, relation, tail] <- $relations
+
+        :insert relations {
+            head,
+            relation,
+            tail,
+        }
+    }
+    """
+
+    return (query, {"relations": relations, "entries": entries})
diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py
index b7bf96d3a..ebb7bfb8b 100644
--- a/agents-api/agents_api/routers/sessions/chat.py
+++ b/agents-api/agents_api/routers/sessions/chat.py
@@ -1,7 +1,9 @@
+from contextlib import suppress
 from typing import Annotated
 from uuid import UUID, uuid4
 
 from fastapi import BackgroundTasks, Depends
+from litellm.types.utils import ModelResponse
 from starlette.status import HTTP_201_CREATED
 
 from ...autogen.openapi_model import (
@@ -12,11 +14,13 @@
     MessageChatResponse,
 )
 from ...clients import litellm
+from ...clients.temporal import run_summarization_task, run_truncation_task
 from ...common.protocol.developers import Developer
 from ...common.protocol.sessions import ChatContext
 from ...common.utils.datetime import utcnow
 from ...common.utils.template import render_template
 from ...dependencies.developer_id import get_developer_data
+from ...exceptions import PromptTooBigError
 from ...models.chat.gather_messages import gather_messages
 from ...models.chat.prepare_chat_context import prepare_chat_context
 from ...models.entry.create_entries import create_entries
@@ -48,6 +52,9 @@ async def chat(
     settings: dict = chat_context.settings.model_dump()
     settings["model"] = f"openai/{settings['model']}"  # litellm proxy idiosyncracy
 
+    # Render the messages
+    new_raw_messages = [msg.model_dump() for msg in chat_input.messages]
+
     # Get the past messages and doc references
     past_messages, doc_references = await gather_messages(
         developer=developer,
@@ -80,27 +87,18 @@ async def chat(
     else:
         new_messages = new_raw_messages
 
+    for m in past_messages:
+        with suppress(KeyError):
+            del m["created_at"]
+        with suppress(KeyError):
+            del m["id"]
+
     # Combine the past messages with the new messages
     messages = past_messages + new_messages
 
     # Get the tools
     tools = settings.get("tools") or chat_context.get_active_tools()
 
-    # FIXME: Truncate chat messages in the chat context
-    # SCRUM-7
-    if chat_context.session.context_overflow == "truncate":
-        # messages = messages[-settings["max_tokens"] :]
-        raise NotImplementedError("Truncation is not yet implemented")
-
-    # FIXME: Hotfix for datetime not serializable. Needs investigation
-    messages = [
-        msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in messages
-    ]
-
-    messages = [
-        dict(role=m["role"], content=m["content"], user=m.get("user")) for m in messages
-    ]
-
     # Get the response from the model
     model_response = await litellm.acompletion(
         messages=messages,
@@ -136,13 +134,28 @@
     )
 
     # Adaptive context handling
-    jobs = []
-    if chat_context.session.context_overflow == "adaptive":
-        # FIXME: Start the adaptive context workflow
-        # SCRUM-8
-
-        # jobs = [await start_adaptive_context_workflow]
-        raise NotImplementedError("Adaptive context is not yet implemented")
+    # TODO: set this value for a streaming response
+    total_tokens = 0
+    if isinstance(model_response, ModelResponse):
+        total_tokens = model_response.usage.total_tokens
+
+    job_id = uuid4()
+
+    if (
+        chat_context.session.token_budget is not None
+        and total_tokens >= chat_context.session.token_budget
+    ):
+        if chat_context.session.context_overflow == "adaptive":
+            await run_summarization_task(session_id=session_id, job_id=job_id)
+        elif chat_context.session.context_overflow == "truncate":
+            await run_truncation_task(
+                token_count_threshold=chat_context.session.token_budget,
+                developer_id=developer.id,
+                session_id=session_id,
+                job_id=job_id,
+            )
+        else:
+            raise PromptTooBigError(total_tokens, chat_context.session.token_budget)
 
     # Return the response
     # FIXME: Implement streaming for chat
@@ -152,7 +165,7 @@
     chat_response: ChatResponse = chat_response_class(
         id=uuid4(),
         created_at=utcnow(),
-        jobs=jobs,
+        jobs=[job_id],
         docs=doc_references,
         usage=model_response.usage.model_dump(),
         choices=[choice.model_dump() for choice in model_response.choices],
diff --git a/agents-api/agents_api/routers/sessions/create_or_update_session.py b/agents-api/agents_api/routers/sessions/create_or_update_session.py
index 37e928a16..9ba3c970d 100644
--- a/agents-api/agents_api/routers/sessions/create_or_update_session.py
+++ b/agents-api/agents_api/routers/sessions/create_or_update_session.py
@@ -22,10 +22,8 @@ async def create_or_update_session(
     session_id: UUID,
     data: CreateOrUpdateSessionRequest,
 ) -> ResourceUpdatedResponse:
-    session_updated = create_session_query(
+    return create_session_query(
         developer_id=x_developer_id,
         session_id=session_id,
         data=data,
     )
-
-    return session_updated
diff --git a/agents-api/agents_api/routers/tasks/update_execution.py b/agents-api/agents_api/routers/tasks/update_execution.py
index 7b60d8bdf..08d41f6f2 100644
--- a/agents-api/agents_api/routers/tasks/update_execution.py
+++ b/agents-api/agents_api/routers/tasks/update_execution.py
@@ -35,7 +35,6 @@ async def update_execution(
             await wf_handle.cancel()
 
         case ResumeExecutionRequest():
-
             token_data = get_paused_execution_token(
                 developer_id=x_developer_id, execution_id=execution_id
             )
diff --git a/agents-api/agents_api/workflows/truncation.py b/agents-api/agents_api/workflows/truncation.py
index d3646ccbe..02649764b 100644
--- a/agents-api/agents_api/workflows/truncation.py
+++ b/agents-api/agents_api/workflows/truncation.py
@@ -12,9 +12,11 @@
 @workflow.defn
 class TruncationWorkflow:
     @workflow.run
-    async def run(self, session_id: str, token_count_threshold: int) -> None:
+    async def run(
+        self, developer_id: str, session_id: str, token_count_threshold: int
+    ) -> None:
         return await workflow.execute_activity(
             truncation,
-            args=[session_id, token_count_threshold],
+            args=[developer_id, session_id, token_count_threshold],
            schedule_to_close_timeout=timedelta(seconds=600),
         )
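
Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of the truncation selection logic introduced in truncation.py above. FakeEntry is a hypothetical stand-in for the real Entry model from agents_api.autogen.openapi_model; the function mirrors the patched get_extra_entries: an optional leading system message is exempted, the remaining entries are walked newest-to-oldest, and every entry at or past the point where the running token count crosses the threshold is marked for deletion.

# Sketch only: FakeEntry is a hypothetical stand-in, not the real Entry model.
from dataclasses import dataclass, field
from uuid import UUID, uuid4


@dataclass
class FakeEntry:
    role: str
    token_count: int
    id: UUID = field(default_factory=uuid4)


def get_extra_entries(messages: list[FakeEntry], token_count_threshold: int) -> list[UUID]:
    # Same selection logic as the patched agents_api.activities.truncation.
    result: list[UUID] = []
    if not messages:
        return result

    token_cnt, offset = 0, 0
    if messages[0].role == "system":
        # A leading system message is never truncated; count its tokens and skip it.
        token_cnt, offset = messages[0].token_count, 1

    # Newest-to-oldest walk: once the running count reaches the threshold,
    # every remaining (older) entry is marked for deletion.
    for m in reversed(messages[offset:]):
        token_cnt += m.token_count
        if token_cnt >= token_count_threshold:
            result.append(m.id)

    return result


msgs = [FakeEntry("system", 50)] + [FakeEntry("user", 100) for _ in range(5)]
# 50 (system) + 100 + 100 stays under 300; the third-newest user entry crosses
# the budget, so the three oldest user entries are selected for deletion.
assert len(get_extra_entries(msgs, token_count_threshold=300)) == 3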