Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Adaptive context (via recursive summarization) #306

Merged
merged 45 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
997b696
research: Recursive summarization experiments
creatorrr May 1, 2024
ccf7bae
fix: Minor fix to rec-sum notebook
creatorrr May 1, 2024
cd89123
wip: Rec-sum notebook
creatorrr May 2, 2024
60eda68
entity and trim step prompt + notebook. summarise setup.
alt-glitch May 3, 2024
1062ed4
summarization step added.
alt-glitch May 3, 2024
a6ce7b9
wip
creatorrr May 7, 2024
086366d
chore: Move rec_sum subpackage
whiterabbit1983 May 21, 2024
39985ea
feat: Summarize messages recursively
whiterabbit1983 May 22, 2024
94199d3
feat: Use custom model for summarization
whiterabbit1983 May 24, 2024
dd301d9
chore: Remove commented out code
whiterabbit1983 May 24, 2024
97fbd70
fix: Serialize pandas series objects
whiterabbit1983 May 24, 2024
373f003
fix: Choose correct way to generate based on model name
whiterabbit1983 May 24, 2024
395646c
fix: Add old entities message as a child for the new one instead of d…
whiterabbit1983 May 27, 2024
dcbeafa
chore: Add tenacity dependency
whiterabbit1983 May 28, 2024
3011c5d
fix: Strip closing ct:trimmed tag
whiterabbit1983 May 28, 2024
83d63f6
fix: Strip closing ct:summarized-messages tag
whiterabbit1983 May 28, 2024
c1e4404
fix: Add a list of entries instead of only one
whiterabbit1983 May 28, 2024
f777c4e
fix: Convert strings to UUID
whiterabbit1983 May 29, 2024
1346bb6
fix: Stringify message content
whiterabbit1983 May 29, 2024
0e207bf
fix: Convert new entry content to a list of JSON
whiterabbit1983 May 29, 2024
55bc9d6
fix: Strip ct:entities tag
whiterabbit1983 May 29, 2024
d7cf690
fix: Do not add new entry explicitly
whiterabbit1983 May 29, 2024
e9de8db
fix: Update assertion
whiterabbit1983 May 30, 2024
9998986
feat: Truncate context window based on session settings (#381)
whiterabbit1983 Jun 1, 2024
e4da1d3
feat: Calculate tokens for image content
whiterabbit1983 Jun 3, 2024
7a6be6f
feat: Update SDKs to support adaptive context options
whiterabbit1983 Jun 3, 2024
e8be05b
fix: Truncate entries
whiterabbit1983 Jun 5, 2024
3181fca
fix: Make truncation a background task
whiterabbit1983 Jun 5, 2024
2e20a4c
fix: Add truncation workflow to registry
whiterabbit1983 Jun 5, 2024
c824516
fix: Fix deleting query
whiterabbit1983 Jun 5, 2024
4e7798a
fix: Remove truncated entries
whiterabbit1983 Jun 5, 2024
e696951
fix: Convert role to string only if needed
whiterabbit1983 Jun 5, 2024
f38bf3e
fix: Replace delete by rm
whiterabbit1983 Jun 5, 2024
d313315
fix: Fix entries deleting logic
whiterabbit1983 Jun 5, 2024
9b58614
fix: Set name to None if absent
whiterabbit1983 Jun 5, 2024
08f6295
fix: Make deleting query accept UUID as a string
whiterabbit1983 Jun 6, 2024
8f8b604
fix: Convert UUIDs to strings
whiterabbit1983 Jun 6, 2024
9f74d34
fix: Fix query parameter name
whiterabbit1983 Jun 6, 2024
97d265d
fix: Convert input to array of arrays
whiterabbit1983 Jun 6, 2024
2c4ceb6
fix: Make entries indices zro-based
whiterabbit1983 Jun 6, 2024
ddbf7d0
fix: Customize summarization model via environment variable
whiterabbit1983 Jun 6, 2024
9a4935a
chore: Re-arrange operations on dialog entries
whiterabbit1983 Jun 8, 2024
7460d54
Merge branch 'dev' into f/rec-sum-experiments
Jun 11, 2024
43990ab
deps: poetry lock on agents-api and python sdk
Jun 11, 2024
c59587f
Merge branch 'dev' into f/rec-sum-experiments
creatorrr Jun 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions agents-api/agents_api/activities/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import logging


logger = logging.getLogger(__name__)
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
h = logging.StreamHandler()
fmt = logging.Formatter("[%(asctime)s/%(levelname)s] - %(message)s")
h.setFormatter(fmt)
logger.addHandler(h)
55 changes: 41 additions & 14 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import asyncio
from uuid import UUID
from typing import Callable
from textwrap import dedent
Expand All @@ -11,7 +12,10 @@
)
from agents_api.common.protocol.entries import Entry
from ..model_registry import JULEP_MODELS
from ..env import summarization_model_name, model_inference_url, model_api_key
from ..env import model_inference_url, model_api_key, summarization_model_name
from agents_api.rec_sum.entities import get_entities
from agents_api.rec_sum.summarize import summarize_messages
from agents_api.rec_sum.trim import trim_messages


example_previous_memory = """
Expand Down Expand Up @@ -160,28 +164,51 @@ async def run_prompt(
@activity.defn
async def summarization(session_id: str) -> None:
session_id = UUID(session_id)
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
entries = [
Entry(**row)
for _, row in get_toplevel_entries_query(session_id=session_id).iterrows()
]
entries = []
entities_entry_ids = []
for _, row in get_toplevel_entries_query(session_id=session_id).iterrows():
if row["role"] == "system" and row.get("name") == "entities":
entities_entry_ids.append(UUID(row["entry_id"], version=4))
else:
entries.append(row)

assert len(entries) > 0, "no need to summarize on empty entries list"

response = await run_prompt(
dialog=entries, previous_memories=[], model=f"openai/{summarization_model_name}"
trimmed_messages, entities = await asyncio.gather(
trim_messages(entries, model=summarization_model_name),
get_entities(entries, model=summarization_model_name),
)

new_entry = Entry(
summarized = await summarize_messages(trimmed_messages)
ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
new_entities_entry = Entry(
session_id=session_id,
source="summarizer",
role="system",
name="information",
content=response,
timestamp=entries[-1].timestamp + 0.01,
name="entities",
content=entities["content"],
timestamp=entries[0]["timestamp"] + ts_delta,
)
whiterabbit1983 marked this conversation as resolved.
Show resolved Hide resolved

creatorrr marked this conversation as resolved.
Show resolved Hide resolved
entries_summarization_query(
session_id=session_id,
new_entry=new_entry,
old_entry_ids=[e.id for e in entries],
new_entry=new_entities_entry,
old_entry_ids=entities_entry_ids,
)

for msg in summarized:
new_entry = Entry(
session_id=session_id,
source="summarizer",
role="system",
name="information",
content=msg["content"],
timestamp=entries[-1]["timestamp"] + 0.01,
)

entries_summarization_query(
session_id=session_id,
new_entry=new_entry,
old_entry_ids=[
UUID(entries[idx]["entry_id"], version=4) for idx in msg["summarizes"]
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
],
)
51 changes: 51 additions & 0 deletions agents-api/agents_api/activities/truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from uuid import UUID
from temporalio import activity
from agents_api.models.entry.entries_summarization import get_toplevel_entries_query
from agents_api.models.entry.delete_entries import delete_entries
from agents_api.autogen.openapi_model import Role
from agents_api.common.protocol.entries import Entry


def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[Entry]:
if not len(messages):
return messages
creatorrr marked this conversation as resolved.
Show resolved Hide resolved

result: list[Entry] = []
token_cnt, offset = 0, 0
if messages[0].role == Role.system:
token_cnt, offset = messages[0].token_count, 1

for m in reversed(messages[offset:]):
token_cnt += m.token_count
if token_cnt < token_count_threshold:
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
continue
else:
result.append(m)

return result


@activity.defn
async def truncation(session_id: str, token_count_threshold: int) -> None:
session_id = UUID(session_id)

delete_entries(
get_extra_entries(
[
Entry(
entry_id=row["entry_id"],
session_id=session_id,
source=row["source"],
role=Role(row["role"]),
name=row["name"],
content=row["content"],
created_at=row["created_at"],
timestamp=row["timestamp"],
)
for _, row in get_toplevel_entries_query(
session_id=session_id
).iterrows()
],
token_count_threshold,
),
)
34 changes: 33 additions & 1 deletion agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-05-28T03:07:50+00:00
# timestamp: 2024-06-03T12:10:34+00:00

from __future__ import annotations

Expand Down Expand Up @@ -124,6 +124,14 @@ class Session(BaseModel):
"""
Render system and assistant message content as jinja templates
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class CreateSessionRequest(BaseModel):
Expand Down Expand Up @@ -151,6 +159,14 @@ class CreateSessionRequest(BaseModel):
"""
Render system and assistant message content as jinja templates
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class UpdateSessionRequest(BaseModel):
Expand All @@ -166,6 +182,14 @@ class UpdateSessionRequest(BaseModel):
"""
Optional metadata
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class UpdateUserRequest(BaseModel):
Expand Down Expand Up @@ -753,6 +777,14 @@ class PatchSessionRequest(BaseModel):
"""
Optional metadata
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class PartialFunctionDef(BaseModel):
Expand Down
13 changes: 13 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,16 @@ async def run_embed_docs_task(
task_queue="memory-task-queue",
id=str(job_id),
)


async def run_truncation_task(
token_count_threshold: int, session_id: UUID, job_id: UUID
):
client = await get_client()

await client.execute_workflow(
"TruncationWorkflow",
args=[str(session_id), token_count_threshold],
task_queue="memory-task-queue",
id=str(job_id),
)
16 changes: 12 additions & 4 deletions agents-api/agents_api/common/protocol/entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@
Role,
ChatMLImageContentPart,
ChatMLTextContentPart,
Detail,
)
from agents_api.common.utils.datetime import utcnow

EntrySource = Literal["api_request", "api_response", "internal", "summarizer"]
Tokenizer = Literal["character_count"]


LOW_IMAGE_TOKEN_COUNT = 85
HIGH_IMAGE_TOKEN_COUNT = 85 + 4 * 170


class Entry(BaseModel):
"""Represents an entry in the system, encapsulating all necessary details such as ID, session ID, source, role, and content among others."""

Expand Down Expand Up @@ -44,12 +49,15 @@ def token_count(self) -> int:
elif isinstance(self.content, dict):
content_length = len(json.dumps(self.content))
elif isinstance(self.content, list):
text = ""
for part in self.content:
# TODO: how to calc token count for images?
if isinstance(part, ChatMLTextContentPart):
text += part.text
content_length = len(text)
content_length += len(part.text)
elif isinstance(part, ChatMLImageContentPart):
content_length += (
LOW_IMAGE_TOKEN_COUNT
if part.image_url.detail == Detail.low
else HIGH_IMAGE_TOKEN_COUNT
)

# Divide the content length by 3.5 to estimate token count based on character count.
return int(content_length // 3.5)
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ class SessionData(BaseModel):
metadata: Dict = {}
user_metadata: Optional[Dict] = None
agent_metadata: Dict = {}
token_budget: int | None = None
context_overflow: str | None = None
15 changes: 15 additions & 0 deletions agents-api/agents_api/common/utils/messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import cast
from agents_api.autogen.openapi_model import (
ChatMLTextContentPart,
ChatMLImageContentPart,
Expand All @@ -23,3 +24,17 @@ def content_to_json(
result = [{"type": "text", "text": json.dumps(content, indent=4)}]

return result


def stringify_content(
msg: str | list[ChatMLTextContentPart] | list[ChatMLImageContentPart] | dict,
) -> str:
content = ""
if isinstance(msg, list):
content = " ".join([part.text for part in msg if part.type == "text"])
elif isinstance(msg, str):
content = msg
elif isinstance(msg, dict) and msg["type"] == "text":
content = cast(str, msg["text"])

return content
8 changes: 0 additions & 8 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@
model_api_key: str = env.str("MODEL_API_KEY", default=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The removal of summarization_tokens_threshold from the environment configuration will cause runtime errors in other parts of the codebase where it is still being used. Consider either reintroducing this environment variable or updating the dependent code to handle its absence.

model_inference_url: str = env.str("MODEL_INFERENCE_URL", default=None)
openai_api_key: str = env.str("OPENAI_API_KEY", default="")
summarization_ratio_threshold: float = env.float(
"MAX_TOKENS_RATIO_TO_SUMMARIZE", default=0.5
)
summarization_tokens_threshold: int = env.int(
"SUMMARIZATION_TOKENS_THRESHOLD", default=2048
)
summarization_model_name: str = env.str(
"SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo"
)
Expand Down Expand Up @@ -78,8 +72,6 @@
debug=debug,
cozo_host=cozo_host,
cozo_auth=cozo_auth,
summarization_ratio_threshold=summarization_ratio_threshold,
summarization_tokens_threshold=summarization_tokens_threshold,
worker_url=worker_url,
sentry_dsn=sentry_dsn,
temporal_endpoint=temporal_endpoint,
Expand Down
54 changes: 54 additions & 0 deletions agents-api/agents_api/models/entry/delete_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@


from ..utils import cozo_query
from ...common.protocol.entries import Entry
from ...common.utils.messages import content_to_json


@cozo_query
Expand Down Expand Up @@ -61,3 +63,55 @@ def delete_entries_query(session_id: UUID) -> tuple[str, dict]:
}"""

return (query, {"session_id": str(session_id)})


@cozo_query
def delete_entries(entries: list[Entry]) -> tuple[str, dict]:
entry_keys = [
[
f'to_uuid("{e.id}")',
f'to_uuid("{e.session_id}")',
e.source,
e.role.value if hasattr(e.role, "value") else e.role,
e.name,
content_to_json(e.content),
e.token_count,
e.tokenizer,
e.created_at,
e.timestamp,
]
for e in entries
]

query = """
{
?[
entry_id,
session_id,
source,
role,
name,
content,
token_count,
tokenizer,
created_at,
timestamp,
] <- $entry_keys

:delete entries {
entry_id,
session_id,
source,
role,
name,
content,
token_count,
tokenizer,
created_at,
timestamp,
}

:returning
}"""

return (query, {"entry_keys": entry_keys})
Loading
Loading