Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add dialog entries summarization #496

Draft
wants to merge 15 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 79 additions & 75 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,90 @@
#!/usr/bin/env python3


import pandas as pd
import asyncio
from uuid import UUID, uuid4

from beartype import beartype
from temporalio import activity

# from agents_api.models.entry.entries_summarization import (
# entries_summarization_query,
# get_toplevel_entries_query,
# )


# TODO: Implement entry summarization queries
# SCRUM-3
def entries_summarization_query(*args, **kwargs) -> pd.DataFrame:
return pd.DataFrame()


def get_toplevel_entries_query(*args, **kwargs) -> pd.DataFrame:
return pd.DataFrame()


# TODO: Implement entry summarization activities
# SCRUM-4
from agents_api.autogen.openapi_model import Entry
from agents_api.common.utils.datetime import utcnow
from agents_api.env import summarization_model_name
from agents_api.models.entry.entries_summarization import (
entries_summarization_query,
get_toplevel_entries_query,
)
from agents_api.rec_sum.entities import get_entities
from agents_api.rec_sum.summarize import summarize_messages
from agents_api.rec_sum.trim import trim_messages


@activity.defn
@beartype
async def summarization(session_id: str) -> None:
raise NotImplementedError()

# session_id = UUID(session_id)
# entries = []
# entities_entry_ids = []
# for _, row in get_toplevel_entries_query(session_id=session_id).iterrows():
# if row["role"] == "system" and row.get("name") == "entities":
# entities_entry_ids.append(UUID(row["entry_id"], version=4))
# else:
# entries.append(row)

# assert len(entries) > 0, "no need to summarize on empty entries list"

# summarized, entities = await asyncio.gather(
# summarize_messages(entries, model=summarization_model_name),
# get_entities(entries, model=summarization_model_name),
# )
# trimmed_messages = await trim_messages(summarized, model=summarization_model_name)
# ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2
# new_entities_entry = Entry(
# session_id=session_id,
# source="summarizer",
# role="system",
# name="entities",
# content=entities["content"],
# timestamp=entries[0]["timestamp"] + ts_delta,
# )

# entries_summarization_query(
# session_id=session_id,
# new_entry=new_entities_entry,
# old_entry_ids=entities_entry_ids,
# )

# trimmed_map = {
# m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None
# }

# for idx, msg in enumerate(summarized):
# new_entry = Entry(
# session_id=session_id,
# source="summarizer",
# role="system",
# name="information",
# content=trimmed_map.get(idx, msg["content"]),
# timestamp=entries[-1]["timestamp"] + 0.01,
# )

# entries_summarization_query(
# session_id=session_id,
# new_entry=new_entry,
# old_entry_ids=[
# UUID(entries[idx - 1]["entry_id"], version=4)
# for idx in msg["summarizes"]
# ],
# )
session_id = UUID(session_id)
entries = []
entities_entry_ids = []
for _, row in get_toplevel_entries_query(session_id=session_id).iterrows():
if row["role"] == "system" and row.get("name") == "entities":
entities_entry_ids.append(UUID(row["entry_id"], version=4))
else:
entries.append(row)

assert len(entries) > 0, "no need to summarize on empty entries list"

summarized, entities = await asyncio.gather(
summarize_messages(entries, model=summarization_model_name),
get_entities(entries, model=summarization_model_name),
)
trimmed_messages = await trim_messages(summarized, model=summarization_model_name)
ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure entries has at least two elements before calculating ts_delta to avoid potential IndexError.

# TODO: set tokenizer, double check token_count calculation
new_entities_entry = Entry(
id=uuid4(),
session_id=session_id,
source="summarizer",
role="system",
name="entities",
content=entities["content"],
timestamp=entries[0]["timestamp"] + ts_delta,
token_count=sum([len(c) // 3.5 for c in entities["content"]]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The token_count calculation using len(c) // 3.5 is an approximation and may not be accurate. Consider using a tokenizer to get the exact token count.

created_at=utcnow(),
tokenizer="",
)

entries_summarization_query(
session_id=session_id,
new_entry=new_entities_entry,
old_entry_ids=entities_entry_ids,
)

trimmed_map = {
m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None
}

for idx, msg in enumerate(summarized):
content = trimmed_map.get(idx, msg["content"])
# TODO: set tokenizer
# TODO: calc token count
new_entry = Entry(
id=uuid4(),
session_id=session_id,
source="summarizer",
role="system",
name="information",
content=content,
timestamp=entries[-1]["timestamp"] + 0.01,
token_count=sum([len(c) // 3.5 for c in content]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The token count calculation using len(c) // 3.5 is a rough estimate and may not be accurate. Consider using a tokenizer to calculate the token count more precisely. Also, ensure the tokenizer field is set appropriately.

created_at=utcnow(),
tokenizer="",
)

entries_summarization_query(
session_id=session_id,
new_entry=new_entry,
old_entry_ids=[
UUID(entries[idx - 1]["entry_id"], version=4)
for idx in msg["summarizes"]
],
)
80 changes: 40 additions & 40 deletions agents-api/agents_api/activities/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,57 @@
from temporalio import activity

from agents_api.autogen.openapi_model import Entry

# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query

# TODO: Reimplement truncation queries
# SCRUM-5
from agents_api.models.entry.delete_entries import delete_entries
from agents_api.models.entry.entries_summarization import get_toplevel_entries_query


def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[UUID]:
raise NotImplementedError()
result: list[UUID] = []

if not len(messages):
return messages
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function get_extra_entries should return an empty list of UUIDs instead of messages when messages is empty.


_token_cnt, _offset = 0, 0
# if messages[0].role == Role.system:
# token_cnt, offset = messages[0].token_count, 1
token_cnt, offset = 0, 0
if messages[0].role == "system":
token_cnt, offset = messages[0].token_count, 1

# for m in reversed(messages[offset:]):
# token_cnt += m.token_count
# if token_cnt < token_count_threshold:
# continue
# else:
# result.append(m.id)
for m in reversed(messages[offset:]):
token_cnt += m.token_count
if token_cnt >= token_count_threshold:
result.append(m.id)

# return result
return result


# TODO: Reimplement truncation activities
# SCRUM-6
@activity.defn
@beartype
async def truncation(session_id: str, token_count_threshold: int) -> None:
async def truncation(
developer_id: str, session_id: str, token_count_threshold: int
) -> None:
session_id = UUID(session_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding error handling for UUID conversion to handle invalid UUID strings gracefully.


# delete_entries(
# get_extra_entries(
# [
# Entry(
# entry_id=row["entry_id"],
# session_id=session_id,
# source=row["source"],
# role=Role(row["role"]),
# name=row["name"],
# content=row["content"],
# created_at=row["created_at"],
# timestamp=row["timestamp"],
# )
# for _, row in get_toplevel_entries_query(
# session_id=session_id
# ).iterrows()
# ],
# token_count_threshold,
# ),
# )
developer_id = UUID(developer_id)

delete_entries(
developer_id=developer_id,
session_id=session_id,
entry_ids=get_extra_entries(
[
Entry(
id=row["entry_id"],
session_id=session_id,
source=row["source"],
role=row["role"],
name=row["name"],
content=row["content"],
created_at=row["created_at"],
timestamp=row["timestamp"],
tokenizer=row["tokenizer"],
token_count=row["token_count"],
)
for _, row in get_toplevel_entries_query(
session_id=session_id
).iterrows()
],
token_count_threshold,
),
)
28 changes: 28 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,31 @@ async def get_workflow_handle(
)

return handle


async def run_truncation_task(
token_count_threshold: int, developer_id: UUID, session_id: UUID, job_id: UUID
):
from ..workflows.truncation import TruncationWorkflow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider moving the import statement for TruncationWorkflow to the top of the file to avoid repeated imports every time the function is called, which can affect performance. This applies to the import in run_summarization_task as well.


client = await get_client()

await client.start_workflow(
TruncationWorkflow.run,
args=[str(developer_id), str(session_id), token_count_threshold],
task_queue="memory-task-queue",
id=str(job_id),
)


async def run_summarization_task(session_id: UUID, job_id: UUID):
from ..workflows.summarization import SummarizationWorkflow

client = await get_client()

await client.start_workflow(
SummarizationWorkflow.run,
args=[str(session_id)],
task_queue="memory-task-queue",
id=str(job_id),
)
11 changes: 8 additions & 3 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,14 @@ def get_active_tools(self) -> list[Tool]:
return []

active_agent = self.get_active_agent()
active_toolset = next(
toolset for toolset in self.toolsets if toolset.agent_id == active_agent.id
)
try:
active_toolset = next(
toolset
for toolset in self.toolsets
if toolset.agent_id == active_agent.id
)
except StopIteration:
return []

return active_toolset.tools

Expand Down
Loading