Skip to content

Commit

Permalink
feat: Summarize messages recursively
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed May 22, 2024
1 parent e77381e commit a41a85b
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 162 deletions.
85 changes: 67 additions & 18 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
#!/usr/bin/env python3

import asyncio
from uuid import UUID
from typing import Callable
from textwrap import dedent
from temporalio import activity
from litellm import acompletion
from agents_api.models.entry.add_entries import add_entries_query
from agents_api.models.entry.entries_summarization import (
get_toplevel_entries_query,
entries_summarization_query,
)
from agents_api.common.protocol.entries import Entry
from ..env import summarization_model_name
from agents_api.rec_sum.entities import get_entities
from agents_api.rec_sum.summarize import summarize_messages
from agents_api.rec_sum.trim import trim_messages


example_previous_memory = """
Expand Down Expand Up @@ -148,31 +152,76 @@ async def run_prompt(
return parser(content.strip() if content is not None else "")


# @activity.defn
# async def summarization(session_id: str) -> None:
# session_id = UUID(session_id)
# entries = [
# Entry(**row)
# for _, row in get_toplevel_entries_query(session_id=session_id).iterrows()
# ]

# assert len(entries) > 0, "no need to summarize on empty entries list"

# response = await run_prompt(
# dialog=entries, previous_memories=[], model=summarization_model_name
# )

# new_entry = Entry(
# session_id=session_id,
# source="summarizer",
# role="system",
# name="information",
# content=response,
# timestamp=entries[-1].timestamp + 0.01,
# )

# entries_summarization_query(
# session_id=session_id,
# new_entry=new_entry,
# old_entry_ids=[e.id for e in entries],
# )


@activity.defn
async def summarization(session_id: str) -> None:
    """Recursively summarize a session's top-level history.

    Fetches the session's top-level entries, trims them and extracts
    entities concurrently, then summarizes the trimmed messages. The
    extracted entities are stored as a new system entry, and each
    summary entry replaces the original entries it covers.

    Args:
        session_id: The session's UUID passed as a string (Temporal
            activity arguments must be serializable).

    Raises:
        AssertionError: If the session has no top-level entries.
    """
    session_id = UUID(session_id)
    entries = [
        row for _, row in get_toplevel_entries_query(session_id=session_id).iterrows()
    ]

    assert len(entries) > 0, "no need to summarize on empty entries list"

    # Trimming and entity extraction are independent of each other,
    # so run them concurrently.
    trimmed_messages, entities = await asyncio.gather(
        trim_messages(entries),
        get_entities(entries),
    )
    summarized = await summarize_messages(trimmed_messages)

    # Timestamp the entities entry halfway between the first two entries.
    # Guard the single-entry session, which otherwise raises IndexError
    # on entries[1] even though the assert above only requires one entry.
    if len(entries) > 1:
        ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2
    else:
        ts_delta = 0.01

    add_entries_query(
        Entry(
            session_id=session_id,
            source="summarizer",
            role="system",
            name="entities",
            content=entities["content"],
            timestamp=entries[0]["timestamp"] + ts_delta,
        )
    )

    # Each summarized message carries the indices of the entries it
    # replaces; store the summary and mark those entries as summarized.
    for msg in summarized:
        new_entry = Entry(
            session_id=session_id,
            source="summarizer",
            role="system",
            name="information",
            content=msg["content"],
            timestamp=entries[-1]["timestamp"] + 0.01,
        )

        entries_summarization_query(
            session_id=session_id,
            new_entry=new_entry,
            old_entry_ids=[entries[idx]["entry_id"] for idx in msg["summarizes"]],
        )
3 changes: 0 additions & 3 deletions agents-api/agents_api/rec_sum/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .entities import get_entities
from .summarize import summarize_messages
from .trim import trim_messages
17 changes: 5 additions & 12 deletions agents-api/agents_api/rec_sum/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,21 @@
module_directory = Path(__file__).parent


def _load_json_fixture(filename: str):
    """Load a JSON example/fixture file that ships alongside this module."""
    with open(f"{module_directory}/{filename}", "r") as f:
        return json.load(f)


# Example chats and expected results used as few-shot prompts by the
# entities / trim / summarize tasks.
entities_example_chat = _load_json_fixture("entities_example_chat.json")
trim_example_chat = _load_json_fixture("trim_example_chat.json")
trim_example_result = _load_json_fixture("trim_example_result.json")
summarize_example_chat = _load_json_fixture("summarize_example_chat.json")
summarize_example_result = _load_json_fixture("summarize_example_result.json")


40 changes: 10 additions & 30 deletions agents-api/agents_api/rec_sum/entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from tenacity import retry, stop_after_attempt, wait_fixed
from tenacity import retry, stop_after_attempt

from .data import entities_example_chat
from .generate import generate
Expand Down Expand Up @@ -41,38 +41,18 @@
- See the example to get a better idea of the task."""


make_entities_prompt = lambda session, user="a user", assistant="gpt-4-turbo", **_: [f"""\
You are given a session history of a chat between {user or "a user"} and {assistant or "gpt-4-turbo"}. The session is formatted in the ChatML JSON format (from OpenAI).
{entities_instructions}
<ct:example-session>
{json.dumps(entities_example_chat, indent=2)}
</ct:example-session>
<ct:example-plan>
{entities_example_plan}
</ct:example-plan>
<ct:example-entities>
{entities_example_result}
</ct:example-entities>""",

f"""\
Begin! Write the entities as a Markdown formatted list. First write your plan inside <ct:plan></ct:plan> and then the extracted entities between <ct:entities></ct:entities>.
<ct:session>
{json.dumps(session, indent=2)}
</ct:session>"""]

def make_entities_prompt(session, user="a user", assistant="gpt-4-turbo", **_):
    """Build the [system, user] prompt pair for entity extraction.

    Returns two strings: a system prompt with instructions plus a worked
    example, and a user message wrapping the JSON-dumped session. The
    `or` fallbacks cover callers that pass None explicitly; extra
    keyword arguments are accepted and ignored.
    """
    return [
        f"You are given a session history of a chat between {user or 'a user'} and {assistant or 'gpt-4-turbo'}. The session is formatted in the ChatML JSON format (from OpenAI).\n\n{entities_instructions}\n\n<ct:example-session>\n{json.dumps(entities_example_chat, indent=2)}\n</ct:example-session>\n\n<ct:example-plan>\n{entities_example_plan}\n</ct:example-plan>\n\n<ct:example-entities>\n{entities_example_result}\n</ct:example-entities>",
        f"Begin! Write the entities as a Markdown formatted list. First write your plan inside <ct:plan></ct:plan> and then the extracted entities between <ct:entities></ct:entities>.\n\n<ct:session>\n{json.dumps(session, indent=2)}\n\n</ct:session>",
    ]


@retry(stop=stop_after_attempt(2))
async def get_entities(
chat_session,
model="gpt-4-turbo",
stop=["</ct:entities"],
model="gpt-4-turbo",
stop=["</ct:entities"],
temperature=0.7,
**kwargs,
):
Expand All @@ -84,7 +64,7 @@ async def get_entities(
and chat_session[0].get("name") != "entities"
):
chat_session = chat_session[1:]

names = get_names_from_session(chat_session)
system_prompt, user_message = make_entities_prompt(chat_session, **names)
messages = [chatml.system(system_prompt), chatml.user(user_message)]
Expand All @@ -100,5 +80,5 @@ async def get_entities(
result["content"] = result["content"].split("<ct:entities>")[-1].strip()
result["role"] = "system"
result["name"] = "entities"

return chatml.make(**result)
9 changes: 4 additions & 5 deletions agents-api/agents_api/rec_sum/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@

@retry(wait=wait_fixed(2), stop=stop_after_attempt(5))
async def generate(
    messages: list[dict],
    client: AsyncClient = client,
    model: str = "gpt-4-turbo",
    **kwargs,
) -> dict:
    """Run a chat completion and return the response message as a dict.

    Retried up to 5 times with a fixed 2-second wait between attempts.

    Args:
        messages: ChatML-style message dicts to send.
        client: Client used for the request; defaults to the
            module-level client (bound once at import time).
        model: Model name passed through to the API.
        **kwargs: Extra options forwarded to ``chat.completions.create``.

    Returns:
        The first choice's message unpacked into a plain dict so callers
        can treat it like a ChatML message.
    """
    result = await client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )

    return result.choices[0].message.__dict__

48 changes: 14 additions & 34 deletions agents-api/agents_api/rec_sum/summarize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from tenacity import retry, stop_after_attempt, wait_fixed
from tenacity import retry, stop_after_attempt

from .data import summarize_example_chat, summarize_example_result
from .generate import generate
Expand All @@ -24,7 +24,6 @@
- We can safely summarize message 34's essay into just the salient points only."""



summarize_instructions = """\
Your goal is to compactify the history by coalescing redundant information in messages into their summary in order to reduce its size and save costs.
Expand All @@ -36,53 +35,32 @@
- VERY IMPORTANT: Add the indices of messages that are being summarized so that those messages can then be removed from the session otherwise, there'll be no way to identify which messages to remove. See example for more details."""



make_summarize_prompt = lambda session, user="a user", assistant="gpt-4-turbo", **_: [f"""\
You are given a session history of a chat between {user or "a user"} and {assistant or "gpt-4-turbo"}. The session is formatted in the ChatML JSON format (from OpenAI).
{summarize_instructions}
<ct:example-session>
{json.dumps(add_indices(summarize_example_chat), indent=2)}
</ct:example-session>
<ct:example-plan>
{summarize_example_plan}
</ct:example-plan>
<ct:example-summarized-messages>
{json.dumps(summarize_example_result, indent=2)}
</ct:example-summarized-messages>""",

f"""\
Begin! Write the summarized messages as a json list just like the example above. First write your plan inside <ct:plan></ct:plan> and then your answer between <ct:summarized-messages></ct:summarized-messages>. Don't forget to add the indices of the messages being summarized alongside each summary.
<ct:session>
{json.dumps(add_indices(session), indent=2)}
</ct:session>"""]

def make_summarize_prompt(session, user="a user", assistant="gpt-4-turbo", **_):
    """Build the [system, user] prompt pair for message summarization.

    Returns two strings: a system prompt with instructions plus a worked
    example (messages annotated with indices via add_indices), and a
    user message wrapping the index-annotated, JSON-dumped session. The
    `or` fallbacks cover callers that pass None explicitly; extra
    keyword arguments are accepted and ignored.
    """
    return [
        f"You are given a session history of a chat between {user or 'a user'} and {assistant or 'gpt-4-turbo'}. The session is formatted in the ChatML JSON format (from OpenAI).\n\n{summarize_instructions}\n\n<ct:example-session>\n{json.dumps(add_indices(summarize_example_chat), indent=2)}\n</ct:example-session>\n\n<ct:example-plan>\n{summarize_example_plan}\n</ct:example-plan>\n\n<ct:example-summarized-messages>\n{json.dumps(summarize_example_result, indent=2)}\n</ct:example-summarized-messages>",
        f"Begin! Write the summarized messages as a json list just like the example above. First write your plan inside <ct:plan></ct:plan> and then your answer between <ct:summarized-messages></ct:summarized-messages>. Don't forget to add the indices of the messages being summarized alongside each summary.\n\n<ct:session>\n{json.dumps(add_indices(session), indent=2)}\n\n</ct:session>",
    ]


@retry(stop=stop_after_attempt(2))
async def summarize_messages(
chat_session,
model="gpt-4-turbo",
stop=["</ct:summarized"],
model="gpt-4-turbo",
stop=["</ct:summarized"],
temperature=0.8,
**kwargs,
):
assert len(chat_session) > 2, "Session is too short"

offset = 0

# Remove the system prompt if present
if (
chat_session[0]["role"] == "system"
and chat_session[0].get("name") != "entities"
):
chat_session = chat_session[1:]

# The indices are not matched up correctly
offset = 1

Expand All @@ -98,7 +76,9 @@ async def summarize_messages(
)

assert "<ct:summarized-messages>" in result["content"]
summarized_messages = json.loads(result["content"].split("<ct:summarized-messages>")[-1].strip())
summarized_messages = json.loads(
result["content"].split("<ct:summarized-messages>")[-1].strip()
)

assert all((msg.get("summarizes") is not None for msg in summarized_messages))

Expand All @@ -107,5 +87,5 @@ async def summarize_messages(
{**msg, "summarizes": [i + offset for i in msg["summarizes"]]}
for msg in summarized_messages
]

return summarized_messages
Loading

0 comments on commit a41a85b

Please sign in to comment.