diff --git a/agents-api/agents_api/activities/truncation.py b/agents-api/agents_api/activities/truncation.py index 4742ee6d4..f5a90ee2b 100644 --- a/agents-api/agents_api/activities/truncation.py +++ b/agents-api/agents_api/activities/truncation.py @@ -4,57 +4,56 @@ from temporalio import activity from agents_api.autogen.openapi_model import Entry - -# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query - -# TODO: Reimplement truncation queries -# SCRUM-5 +from agents_api.models.entry.delete_entries import delete_entries +from agents_api.models.entry.entries_summarization import get_toplevel_entries_query def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[UUID]: - raise NotImplementedError() + result: list[UUID] = [] if not len(messages): return messages - _token_cnt, _offset = 0, 0 - # if messages[0].role == Role.system: - # token_cnt, offset = messages[0].token_count, 1 + token_cnt, offset = 0, 0 + if messages[0].role == "system": + token_cnt, offset = messages[0].token_count, 1 - # for m in reversed(messages[offset:]): - # token_cnt += m.token_count - # if token_cnt < token_count_threshold: - # continue - # else: - # result.append(m.id) + for m in reversed(messages[offset:]): + token_cnt += m.token_count + if token_cnt >= token_count_threshold: + result.append(m.id) - # return result + return result -# TODO: Reimplement truncation activities -# SCRUM-6 @activity.defn @beartype -async def truncation(session_id: str, token_count_threshold: int) -> None: +async def truncation( + developer_id: str, session_id: str, token_count_threshold: int +) -> None: session_id = UUID(session_id) - - # delete_entries( - # get_extra_entries( - # [ - # Entry( - # entry_id=row["entry_id"], - # session_id=session_id, - # source=row["source"], - # role=Role(row["role"]), - # name=row["name"], - # content=row["content"], - # created_at=row["created_at"], - # timestamp=row["timestamp"], - # ) - # for _, row in get_toplevel_entries_query( - # session_id=session_id - # ).iterrows() - # ], - # token_count_threshold, - # ), - # ) + developer_id = UUID(developer_id) + + delete_entries( + developer_id=developer_id, + session_id=session_id, + entry_ids=get_extra_entries( + [ + Entry( + entry_id=row["entry_id"], + session_id=session_id, + source=row["source"], + role=row["role"], + name=row["name"], + content=row["content"], + created_at=row["created_at"], + timestamp=row["timestamp"], + tokenizer=row["tokenizer"], + ) + for _, row in get_toplevel_entries_query( + session_id=session_id + ).iterrows() + ], + token_count_threshold, + ), + ) diff --git a/agents-api/agents_api/models/entry/entries_summarization.py b/agents-api/agents_api/models/entry/entries_summarization.py index 899b07730..8fa70fde2 100644 --- a/agents-api/agents_api/models/entry/entries_summarization.py +++ b/agents-api/agents_api/models/entry/entries_summarization.py @@ -38,6 +38,7 @@ def get_toplevel_entries_query(session_id: UUID) -> tuple[str, dict]: token_count, created_at, timestamp, + tokenizer, ] := input[session_id], *entries{ @@ -50,6 +51,7 @@ def get_toplevel_entries_query(session_id: UUID) -> tuple[str, dict]: token_count, created_at, timestamp, + tokenizer, }, not *relations { relation: "summary_of", diff --git a/agents-api/agents_api/workflows/truncation.py b/agents-api/agents_api/workflows/truncation.py index d3646ccbe..02649764b 100644 --- a/agents-api/agents_api/workflows/truncation.py +++ b/agents-api/agents_api/workflows/truncation.py @@ -12,9 +12,11 @@ @workflow.defn class TruncationWorkflow: @workflow.run - async def run(self, session_id: str, token_count_threshold: int) -> None: + async def run( + self, developer_id: str, session_id: str, token_count_threshold: int + ) -> None: return await workflow.execute_activity( truncation, - args=[session_id, token_count_threshold], + args=[developer_id, session_id, token_count_threshold], schedule_to_close_timeout=timedelta(seconds=600), )