Skip to content

Commit

Permalink
feat(agents-api): Adaptive context (via recursive summarization) (#306)
Browse files Browse the repository at this point in the history
* research: Recursive summarization experiments

Signed-off-by: Diwank Singh Tomer <[email protected]>

* fix: Minor fix to rec-sum notebook

Signed-off-by: Diwank Singh Tomer <[email protected]>

* wip: Rec-sum notebook

Signed-off-by: Diwank Singh Tomer <[email protected]>

* entity and trim step prompt + notebook. summarise setup.

* summarization step added.

* wip

Signed-off-by: Diwank Singh Tomer <[email protected]>

* chore: Move rec_sum subpackage

* feat: Summarize messages recursively

* feat: Use custom model for summarization

* chore: Remove commented out code

* fix: Serialize pandas series objects

* fix: Choose correct way to generate based on model name

* fix: Add old entities message as a child for the new one instead of deleting it

* chore: Add tenacity dependency

* fix: Strip closing ct:trimmed tag

* fix: Strip closing ct:summarized-messages tag

* fix: Add a list of entries instead of only one

* fix: Convert strings to UUID

* fix: Stringify message content

* fix: Convert new entry content to a list of JSON

* fix: Strip ct:entities tag

* fix: Do not add new entry explicitly

* fix: Update assertion

* feat: Truncate context window based on session settings (#381)

* feat: Calculate tokens for image content

* feat: Update SDKs to support adaptive context options

* fix: Truncate entries

* fix: Make truncation a background task

* fix: Add truncation workflow to registry

* fix: Fix deleting query

* fix: Remove truncated entries

* fix: Convert role to string only if needed

* fix: Replace delete by rm

* fix: Fix entries deleting logic

* fix: Set name to None if absent

* fix: Make deleting query accept UUID as a string

* fix: Convert UUIDs to strings

* fix: Fix query parameter name

* fix: Convert input to array of arrays

* fix: Make entries indices zro-based

* fix: Customize summarization model via environment variable

* chore: Re-arrange operations on dialog entries

* deps: poetry lock on agents-api and python sdk

Signed-off-by: Diwank Tomer <[email protected]>

---------

Signed-off-by: Diwank Singh Tomer <[email protected]>
Signed-off-by: Diwank Tomer <[email protected]>
Co-authored-by: Siddharth Balyan <[email protected]>
Co-authored-by: Dmitry Paramonov <[email protected]>
Co-authored-by: Diwank Tomer <[email protected]>
  • Loading branch information
4 people authored Jun 11, 2024
1 parent c6e6079 commit 5b06179
Show file tree
Hide file tree
Showing 57 changed files with 6,706 additions and 228 deletions.
8 changes: 8 additions & 0 deletions agents-api/agents_api/activities/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import logging


logger = logging.getLogger(__name__)
h = logging.StreamHandler()
fmt = logging.Formatter("[%(asctime)s/%(levelname)s] - %(message)s")
h.setFormatter(fmt)
logger.addHandler(h)
60 changes: 46 additions & 14 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import asyncio
from uuid import UUID
from typing import Callable
from textwrap import dedent
Expand All @@ -11,7 +12,10 @@
)
from agents_api.common.protocol.entries import Entry
from ..model_registry import JULEP_MODELS
from ..env import summarization_model_name, model_inference_url, model_api_key
from ..env import model_inference_url, model_api_key, summarization_model_name
from agents_api.rec_sum.entities import get_entities
from agents_api.rec_sum.summarize import summarize_messages
from agents_api.rec_sum.trim import trim_messages


example_previous_memory = """
Expand Down Expand Up @@ -160,28 +164,56 @@ async def run_prompt(
@activity.defn
async def summarization(session_id: str) -> None:
session_id = UUID(session_id)
entries = [
Entry(**row)
for _, row in get_toplevel_entries_query(session_id=session_id).iterrows()
]
entries = []
entities_entry_ids = []
for _, row in get_toplevel_entries_query(session_id=session_id).iterrows():
if row["role"] == "system" and row.get("name") == "entities":
entities_entry_ids.append(UUID(row["entry_id"], version=4))
else:
entries.append(row)

assert len(entries) > 0, "no need to summarize on empty entries list"

response = await run_prompt(
dialog=entries, previous_memories=[], model=f"openai/{summarization_model_name}"
summarized, entities = await asyncio.gather(
summarize_messages(entries, model=summarization_model_name),
get_entities(entries, model=summarization_model_name),
)

new_entry = Entry(
trimmed_messages = await trim_messages(summarized, model=summarization_model_name)
ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2
new_entities_entry = Entry(
session_id=session_id,
source="summarizer",
role="system",
name="information",
content=response,
timestamp=entries[-1].timestamp + 0.01,
name="entities",
content=entities["content"],
timestamp=entries[0]["timestamp"] + ts_delta,
)

entries_summarization_query(
session_id=session_id,
new_entry=new_entry,
old_entry_ids=[e.id for e in entries],
new_entry=new_entities_entry,
old_entry_ids=entities_entry_ids,
)

trimmed_map = {
m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None
}

for idx, msg in enumerate(summarized):
new_entry = Entry(
session_id=session_id,
source="summarizer",
role="system",
name="information",
content=trimmed_map.get(idx, msg["content"]),
timestamp=entries[-1]["timestamp"] + 0.01,
)

entries_summarization_query(
session_id=session_id,
new_entry=new_entry,
old_entry_ids=[
UUID(entries[idx - 1]["entry_id"], version=4)
for idx in msg["summarizes"]
],
)
51 changes: 51 additions & 0 deletions agents-api/agents_api/activities/truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from uuid import UUID
from temporalio import activity
from agents_api.models.entry.entries_summarization import get_toplevel_entries_query
from agents_api.models.entry.delete_entries import delete_entries
from agents_api.autogen.openapi_model import Role
from agents_api.common.protocol.entries import Entry


def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[UUID]:
if not len(messages):
return messages

result: list[UUID] = []
token_cnt, offset = 0, 0
if messages[0].role == Role.system:
token_cnt, offset = messages[0].token_count, 1

for m in reversed(messages[offset:]):
token_cnt += m.token_count
if token_cnt < token_count_threshold:
continue
else:
result.append(m.id)

return result


@activity.defn
async def truncation(session_id: str, token_count_threshold: int) -> None:
session_id = UUID(session_id)

delete_entries(
get_extra_entries(
[
Entry(
entry_id=row["entry_id"],
session_id=session_id,
source=row["source"],
role=Role(row["role"]),
name=row["name"],
content=row["content"],
created_at=row["created_at"],
timestamp=row["timestamp"],
)
for _, row in get_toplevel_entries_query(
session_id=session_id
).iterrows()
],
token_count_threshold,
),
)
32 changes: 32 additions & 0 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ class Session(BaseModel):
"""
Render system and assistant message content as jinja templates
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class CreateSessionRequest(BaseModel):
Expand Down Expand Up @@ -151,6 +159,14 @@ class CreateSessionRequest(BaseModel):
"""
Render system and assistant message content as jinja templates
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class UpdateSessionRequest(BaseModel):
Expand All @@ -166,6 +182,14 @@ class UpdateSessionRequest(BaseModel):
"""
Optional metadata
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class UpdateUserRequest(BaseModel):
Expand Down Expand Up @@ -753,6 +777,14 @@ class PatchSessionRequest(BaseModel):
"""
Optional metadata
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class PartialFunctionDef(BaseModel):
Expand Down
13 changes: 13 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,16 @@ async def run_embed_docs_task(
task_queue="memory-task-queue",
id=str(job_id),
)


async def run_truncation_task(
token_count_threshold: int, session_id: UUID, job_id: UUID
):
client = await get_client()

await client.execute_workflow(
"TruncationWorkflow",
args=[str(session_id), token_count_threshold],
task_queue="memory-task-queue",
id=str(job_id),
)
16 changes: 12 additions & 4 deletions agents-api/agents_api/common/protocol/entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@
Role,
ChatMLImageContentPart,
ChatMLTextContentPart,
Detail,
)
from agents_api.common.utils.datetime import utcnow

EntrySource = Literal["api_request", "api_response", "internal", "summarizer"]
Tokenizer = Literal["character_count"]


LOW_IMAGE_TOKEN_COUNT = 85
HIGH_IMAGE_TOKEN_COUNT = 85 + 4 * 170


class Entry(BaseModel):
"""Represents an entry in the system, encapsulating all necessary details such as ID, session ID, source, role, and content among others."""

Expand Down Expand Up @@ -44,12 +49,15 @@ def token_count(self) -> int:
elif isinstance(self.content, dict):
content_length = len(json.dumps(self.content))
elif isinstance(self.content, list):
text = ""
for part in self.content:
# TODO: how to calc token count for images?
if isinstance(part, ChatMLTextContentPart):
text += part.text
content_length = len(text)
content_length += len(part.text)
elif isinstance(part, ChatMLImageContentPart):
content_length += (
LOW_IMAGE_TOKEN_COUNT
if part.image_url.detail == Detail.low
else HIGH_IMAGE_TOKEN_COUNT
)

# Divide the content length by 3.5 to estimate token count based on character count.
return int(content_length // 3.5)
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ class SessionData(BaseModel):
metadata: Dict = {}
user_metadata: Optional[Dict] = None
agent_metadata: Dict = {}
token_budget: int | None = None
context_overflow: str | None = None
15 changes: 15 additions & 0 deletions agents-api/agents_api/common/utils/messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import cast
from agents_api.autogen.openapi_model import (
ChatMLTextContentPart,
ChatMLImageContentPart,
Expand All @@ -23,3 +24,17 @@ def content_to_json(
result = [{"type": "text", "text": json.dumps(content, indent=4)}]

return result


def stringify_content(
msg: str | list[ChatMLTextContentPart] | list[ChatMLImageContentPart] | dict,
) -> str:
content = ""
if isinstance(msg, list):
content = " ".join([part.text for part in msg if part.type == "text"])
elif isinstance(msg, str):
content = msg
elif isinstance(msg, dict) and msg["type"] == "text":
content = cast(str, msg["text"])

return content
8 changes: 0 additions & 8 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@
model_api_key: str = env.str("MODEL_API_KEY", default=None)
model_inference_url: str = env.str("MODEL_INFERENCE_URL", default=None)
openai_api_key: str = env.str("OPENAI_API_KEY", default="")
summarization_ratio_threshold: float = env.float(
"MAX_TOKENS_RATIO_TO_SUMMARIZE", default=0.5
)
summarization_tokens_threshold: int = env.int(
"SUMMARIZATION_TOKENS_THRESHOLD", default=2048
)
summarization_model_name: str = env.str(
"SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo"
)
Expand Down Expand Up @@ -78,8 +72,6 @@
debug=debug,
cozo_host=cozo_host,
cozo_auth=cozo_auth,
summarization_ratio_threshold=summarization_ratio_threshold,
summarization_tokens_threshold=summarization_tokens_threshold,
worker_url=worker_url,
sentry_dsn=sentry_dsn,
temporal_endpoint=temporal_endpoint,
Expand Down
52 changes: 52 additions & 0 deletions agents-api/agents_api/models/entry/delete_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,55 @@ def delete_entries_query(session_id: UUID) -> tuple[str, dict]:
}"""

return (query, {"session_id": str(session_id)})


@cozo_query
def delete_entries(entry_ids: list[UUID]) -> tuple[str, dict]:
query = """
{
input[entry_id_str] <- $entry_ids
?[
entry_id,
session_id,
source,
role,
name,
content,
token_count,
tokenizer,
created_at,
timestamp,
] :=
input[entry_id_str],
entry_id = to_uuid(entry_id_str),
*entries {
entry_id,
session_id,
source,
role,
name,
content,
token_count,
tokenizer,
created_at,
timestamp,
}
:delete entries {
entry_id,
session_id,
source,
role,
name,
content,
token_count,
tokenizer,
created_at,
timestamp,
}
:returning
}"""

return (query, {"entry_ids": [[str(id)] for id in entry_ids]})
Loading

0 comments on commit 5b06179

Please sign in to comment.