Skip to content

Commit

Permalink
feat: Truncate entries
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Sep 12, 2024
1 parent c07987a commit df7d6bc
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 42 deletions.
79 changes: 39 additions & 40 deletions agents-api/agents_api/activities/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,56 @@
from temporalio import activity

from agents_api.autogen.openapi_model import Entry

# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query

# TODO: Reimplement truncation queries
# SCRUM-5
from agents_api.models.entry.delete_entries import delete_entries
from agents_api.models.entry.entries_summarization import get_toplevel_entries_query


def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[UUID]:
raise NotImplementedError()
result: list[UUID] = []

if not len(messages):
return messages

_token_cnt, _offset = 0, 0
# if messages[0].role == Role.system:
# token_cnt, offset = messages[0].token_count, 1
token_cnt, offset = 0, 0
if messages[0].role == "system":
token_cnt, offset = messages[0].token_count, 1

# for m in reversed(messages[offset:]):
# token_cnt += m.token_count
# if token_cnt < token_count_threshold:
# continue
# else:
# result.append(m.id)
for m in reversed(messages[offset:]):
token_cnt += m.token_count
if token_cnt >= token_count_threshold:
result.append(m.id)

# return result
return result


# TODO: Reimplement truncation activities
# SCRUM-6
@activity.defn
@beartype
async def truncation(session_id: str, token_count_threshold: int) -> None:
async def truncation(
developer_id: str, session_id: str, token_count_threshold: int
) -> None:
session_id = UUID(session_id)

# delete_entries(
# get_extra_entries(
# [
# Entry(
# entry_id=row["entry_id"],
# session_id=session_id,
# source=row["source"],
# role=Role(row["role"]),
# name=row["name"],
# content=row["content"],
# created_at=row["created_at"],
# timestamp=row["timestamp"],
# )
# for _, row in get_toplevel_entries_query(
# session_id=session_id
# ).iterrows()
# ],
# token_count_threshold,
# ),
# )
developer_id = UUID(developer_id)

delete_entries(
developer_id=developer_id,
session_id=session_id,
entry_ids=get_extra_entries(
[
Entry(
entry_id=row["entry_id"],
session_id=session_id,
source=row["source"],
role=row["role"],
name=row["name"],
content=row["content"],
created_at=row["created_at"],
timestamp=row["timestamp"],
tokenizer=row["tokenizer"],
)
for _, row in get_toplevel_entries_query(
session_id=session_id
).iterrows()
],
token_count_threshold,
),
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/entry/entries_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_toplevel_entries_query(session_id: UUID) -> tuple[str, dict]:
token_count,
created_at,
timestamp,
tokenizer,
] :=
input[session_id],
*entries{
Expand All @@ -50,6 +51,7 @@ def get_toplevel_entries_query(session_id: UUID) -> tuple[str, dict]:
token_count,
created_at,
timestamp,
tokenizer,
},
not *relations {
relation: "summary_of",
Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/workflows/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
@workflow.defn
class TruncationWorkflow:
@workflow.run
async def run(self, session_id: str, token_count_threshold: int) -> None:
async def run(
self, developer_id: str, session_id: str, token_count_threshold: int
) -> None:
return await workflow.execute_activity(
truncation,
args=[session_id, token_count_threshold],
args=[developer_id, session_id, token_count_threshold],
schedule_to_close_timeout=timedelta(seconds=600),
)

0 comments on commit df7d6bc

Please sign in to comment.