From dbb1b323b6af8ea8ef9e5306835101d5749319d0 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Wed, 14 Aug 2024 12:29:24 -0400 Subject: [PATCH] feat(agents-api): Add fixtures for testing workflows Signed-off-by: Diwank Tomer --- .../agents_api/activities/co_density.py | 115 -------- .../agents_api/activities/embed_docs.py | 18 +- agents-api/agents_api/activities/mem_mgmt.py | 12 +- .../agents_api/activities/summarization.py | 250 ++++-------------- .../activities/task_steps/__init__.py | 5 +- .../agents_api/activities/truncation.py | 63 ++--- agents-api/agents_api/clients/temporal.py | 36 ++- agents-api/agents_api/clients/worker/types.py | 87 +----- .../agents_api/common/protocol/entries.py | 52 ++-- .../agents_api/models/docs/embed_snippets.py | 2 +- agents-api/agents_api/rec_sum/entities.py | 2 +- agents-api/agents_api/rec_sum/generate.py | 6 +- agents-api/agents_api/rec_sum/summarize.py | 2 +- agents-api/agents_api/rec_sum/trim.py | 2 +- agents-api/agents_api/worker/__main__.py | 85 +----- agents-api/agents_api/worker/worker.py | 68 +++++ agents-api/agents_api/workflows/co_density.py | 20 -- agents-api/agents_api/workflows/mem_mgmt.py | 7 +- agents-api/tests/fixtures.py | 27 ++ agents-api/tests/test_activities.py | 52 +++- 20 files changed, 333 insertions(+), 578 deletions(-) delete mode 100644 agents-api/agents_api/activities/co_density.py create mode 100644 agents-api/agents_api/worker/worker.py delete mode 100644 agents-api/agents_api/workflows/co_density.py diff --git a/agents-api/agents_api/activities/co_density.py b/agents-api/agents_api/activities/co_density.py deleted file mode 100644 index 408cc398a..000000000 --- a/agents-api/agents_api/activities/co_density.py +++ /dev/null @@ -1,115 +0,0 @@ -from textwrap import dedent -from typing import Callable - -from temporalio import activity - -from agents_api.clients import litellm - -from .types import MemoryDensityTaskArgs - - -def make_prompt(args: MemoryDensityTaskArgs): - # Unpack - memory = args.memory - - # Template - template = dedent( - """\ - [[Memory from a Dialog]] - {memory} - - [[Instruction]] - You will generate increasingly concise, entity-dense summaries of the above Memory. - - Repeat the following 2 steps 5 times. - - Step 1: Identify 1-3 informative Entities (";" delimited) from the Memory which are missing from the previously generated summary. - Step 2: Write a new, denser summary of identical length which covers every entity and detail from the previous summary plus the Missing Entities. - - A Missing Entity is: - - Relevant: to the main story. - - Specific: descriptive yet concise (5 words or fewer). - - Novel: not in the previous summary. - - Faithful: present in the Memory. - - Anywhere: located anywhere in the Memory. - - Guidelines: - - The first summary should be long (4-5 sentences, ~80 words) yet highly non-specific, containing little information beyond the entities marked as missing. Use overly verbose language and fillers (e.g., "this article discusses") to reach ~80 words. - - Make every word count: rewrite the previous summary to improve flow and make space for additional entities. - - Make space with fusion, compression, and removal of uninformative phrases like "the memory discusses." - - The summaries should become highly dense and concise yet self-contained, e.g., easily understood without the Memory. - - Missing entities can appear anywhere in the new summary. - - Never drop entities from the previous summary. If space cannot be made, add fewer new entities. - - Remember, use the exact same number of words for each summary. - - Answer in JSON. The JSON should be a list (length 5) of dictionaries whose keys are "Missing_Entities", "Denser_Summary" and "Density_Score" (between 1-10, higher is better). - - [[Result]] - ```json - """ - ).strip() - - prompt = template.format(memory=memory) - - return prompt - - -async def run_prompt( - memory: str, - model: str = "gpt-4o", - max_tokens: int = 400, - temperature: float = 0.2, - parser: Callable[[str], str] = lambda x: x, -) -> str: - prompt = make_prompt(MemoryDensityTaskArgs(memory=memory)) - - response = await litellm.acompletion( - model=model, - messages=[ - { - "content": prompt, - "role": "user", - } - ], - max_tokens=max_tokens, - temperature=temperature, - stop=["<", "<|"], - stream=False, - ) - - content = response.choices[0].message.content - - return parser(content.strip() if content is not None else "") - - -@activity.defn -async def co_density(memory: str) -> None: - # session_id = UUID(session_id) - # entries = [ - # Entry(**row) - # for _, row in client.run( - # get_toplevel_entries_query(session_id=session_id) - # ).iterrows() - # ] - - # assert len(entries) > 0, "no need to summarize on empty entries list" - - await run_prompt(memory=memory) - - # new_entry = Entry( - # session_id=session_id, - # source="summarizer", - # role="system", - # name="information", - # content=response, - # timestamp=entries[-1].timestamp + 0.01, - # ) - - # client.run( - # entries_summarization_query( - # session_id=session_id, - # new_entry=new_entry, - # old_entry_ids=[e.id for e in entries], - # ) - # ) diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py index b8e01e65e..da7bb313f 100644 --- a/agents-api/agents_api/activities/embed_docs.py +++ b/agents-api/agents_api/activities/embed_docs.py @@ -1,27 +1,39 @@ -from pydantic import UUID4 +from uuid import UUID + from temporalio import activity from agents_api.clients import embed as embedder +from agents_api.clients.cozo import get_cozo_client from agents_api.models.docs.embed_snippets import embed_snippets as embed_snippets_query snippet_embed_instruction = "Encode this passage for retrieval: " @activity.defn -async def embed_docs(doc_id: UUID4, title: str, content: list[str]) -> None: +async def embed_docs( + developer_id: UUID, + doc_id: UUID, + title: str, + content: list[str], + include_title: bool = True, + cozo_client=None, +) -> None: indices, snippets = list(zip(*enumerate(content))) + embeddings = await embedder.embed( [ { "instruction": snippet_embed_instruction, - "text": title + "\n\n" + snippet, + "text": (title + "\n\n" + snippet) if include_title else snippet, } for snippet in snippets ] ) embed_snippets_query( + developer_id=developer_id, doc_id=doc_id, snippet_indices=indices, embeddings=embeddings, + client=cozo_client or get_cozo_client(), ) diff --git a/agents-api/agents_api/activities/mem_mgmt.py b/agents-api/agents_api/activities/mem_mgmt.py index f368a0b0b..ea4bb84d2 100644 --- a/agents-api/agents_api/activities/mem_mgmt.py +++ b/agents-api/agents_api/activities/mem_mgmt.py @@ -4,9 +4,9 @@ from temporalio import activity -from agents_api.clients import litellm - -from .types import ChatML, MemoryManagementTaskArgs +from ..autogen.openapi_model import InputChatMLMessage +from ..clients import litellm +from .types import MemoryManagementTaskArgs example_previous_memory = """ Speaker 1: Composes and listens to music. Likes to buy basketball shoes but doesn't wear them often. @@ -118,7 +118,7 @@ def make_prompt( async def run_prompt( - dialog: list[ChatML], + dialog: list[InputChatMLMessage], session_id: UUID, previous_memories: list[str] = [], model: str = "gpt-4o", @@ -156,7 +156,9 @@ async def run_prompt( @activity.defn async def mem_mgmt( - dialog: list[ChatML], session_id: UUID, previous_memories: list[str] = [] + dialog: list[InputChatMLMessage], + session_id: UUID, + previous_memories: list[str] = [], ) -> None: # session_id = UUID(session_id) # entries = [ diff --git a/agents-api/agents_api/activities/summarization.py b/agents-api/agents_api/activities/summarization.py index dc365380d..581dcdb00 100644 --- a/agents-api/agents_api/activities/summarization.py +++ b/agents-api/agents_api/activities/summarization.py @@ -8,8 +8,7 @@ import pandas as pd from temporalio import activity -from agents_api.common.protocol.entries import Entry - +# from agents_api.common.protocol.entries import Entry # from agents_api.models.entry.entries_summarization import ( # entries_summarization_query, # get_toplevel_entries_query, @@ -18,7 +17,6 @@ from agents_api.rec_sum.summarize import summarize_messages from agents_api.rec_sum.trim import trim_messages -from ..clients import litellm from ..env import summarization_model_name @@ -31,196 +29,60 @@ def get_toplevel_entries_query(*args, **kwargs): return pd.DataFrame() -# - -example_previous_memory = """ -Speaker 1: Composes and listens to music. Likes to buy basketball shoes but doesn't wear them often. -""".strip() - -example_dialog_context = """ -Speaker 1: Did you find a place to donate your shoes? -Speaker 2: I did! I was driving to the grocery store the other day, when I noticed a bin labeled "Donation for Shoes and Clothing." It was easier than I thought! How about you? Why do you have so many pairs of sandals? -Speaker 1: I don't understand myself! When I look them online I just have the urge to buy them, even when I know I don't need them. This addiction is getting worse and worse. -Speaker 2: I completely agree that buying shoes can become an addiction! Are there any ways you can make money from home while waiting for a job offer from a call center? -Speaker 1: Well I already got the job so I just need to learn using the software. When I was still searching for jobs, we actually do a yard sale to sell many of my random items that are never used and clearly aren't needed either. -Speaker 2: Congratulations on getting the job! I know it'll help you out so much. And of course, maybe I should turn to yard sales as well, for they can be a great way to make some extra cash! -Speaker 1: Do you have another job or do you compose music for a living? How does your shopping addiction go? -Speaker 2: As a matter of fact, I do have another job in addition to composing music. I'm actually a music teacher at a private school, and on the side, I compose music for friends and family. As far as my shopping addiction goes, it's getting better. I promised myself that I wouldn't buy myself any more shoes this year! -Speaker 1: Ah, I remember the time I promised myself the same thing on not buying random things anymore, never work so far. Good luck with yours! -Speaker 2: Thanks! I need the good luck wishes. I've been avoiding malls and shopping outlets. Maybe you can try the same! -Speaker 1: I can avoid them physically, but with my job enable me sitting in front of my computer for a long period of time, I already turn the shopping addiction into online-shopping addiction. lol. Wish me luck! -Speaker 2: Sure thing! You know, and speaking of spending time before a computer, I need to look up information about Precious Moments figurines. I'd still like to know what they are! -""".strip() - -example_updated_memory = """ -Speaker 1: -- Enjoys composing and listening to music. -- Recently got a job that requires the use of specialized software. -- Displays a shopping addiction, particularly for shoes, that has transitioned to online-shopping due to job nature. -- Previously attempted to mitigate shopping addiction without success. -- Had organized a yard sale to sell unused items when job searching. - -Speaker 2: -- Also enjoys buying shoes and admits to it being addictive. -- Works as a music teacher at a private school in addition to composing music. -- Takes active measures to control his shopping addiction, including avoiding malls. -- Is interested in Precious Moments figurines. -""".strip() - - -def make_prompt( - dialog: list[Entry], - previous_memories: list[str], - max_turns: int = 10, - num_sentences: int = 10, -): - # Template - template = dedent( - """\ - **Instructions** - You are an advanced AI language model with the ability to store and update a memory to keep track of key personality information for people. You will receive a memory and a dialogue between two people. - - Your goal is to update the memory by incorporating the new personality information for both participants while ensuring that the memory does not exceed {num_sentences} sentences. - - To successfully update the memory, follow these steps: - - 1. Carefully analyze the existing memory and extract the key personality information of the participants from it. - 2. Consider the dialogue provided to identify any new or changed personality traits of either participant that need to be incorporated into the memory. - 3. Combine the old and new personality information to create an updated representation of the participants' traits. - 4. Structure the updated memory in a clear and concise manner, ensuring that it does not exceed {num_sentences} sentences. - 5. Pay attention to the relevance and importance of the personality information, focusing on capturing the most significant aspects while maintaining the overall coherence of the memory. - - Remember, the memory should serve as a reference point to maintain continuity in the dialogue and help accurately set context in future conversations based on the personality traits of the participants. - - **Test Example** - [[Previous Memory]] - {example_previous_memory} - - [[Dialogue Context]] - {example_dialog_context} - - [[Updated Memory]] - {example_updated_memory} - - **Actual Run** - [[Previous Memory]] - {previous_memory} - - [[Dialogue Context]] - {dialog_context} - - [[Updated Memory]] - """ - ).strip() - - # Filter dialog (keep only user and assistant sections) - dialog = [entry for entry in dialog if entry.role != "system"] - - # Truncate to max_turns - dialog = dialog[-max_turns:] - - # Prepare dialog context - dialog_context = "\n".join( - [ - f'{e.name or ("User" if e.role == "user" else "Assistant")}: {e.content}' - for e in dialog - ] - ) - - prompt = template.format( - dialog_context=dialog_context, - previous_memory="\n".join(previous_memories), - num_sentences=num_sentences, - example_dialog_context=example_dialog_context, - example_previous_memory=example_previous_memory, - example_updated_memory=example_updated_memory, - ) - - return prompt - - -async def run_prompt( - dialog: list[Entry], - previous_memories: list[str], - model: str = "gpt-4o", - max_tokens: int = 400, - temperature: float = 0.1, - parser: Callable[[str], str] = lambda x: x, - **kwargs, -) -> str: - prompt = make_prompt(dialog, previous_memories, **kwargs) - response = await litellm.acompletion( - model=model, - messages=[ - { - "content": prompt, - "role": "user", - } - ], - max_tokens=max_tokens, - temperature=temperature, - stop=["<", "<|"], - stream=False, - ) - - content = response.choices[0].message.content - - return parser(content.strip() if content is not None else "") - - @activity.defn async def summarization(session_id: str) -> None: - session_id = UUID(session_id) - entries = [] - entities_entry_ids = [] - for _, row in get_toplevel_entries_query(session_id=session_id).iterrows(): - if row["role"] == "system" and row.get("name") == "entities": - entities_entry_ids.append(UUID(row["entry_id"], version=4)) - else: - entries.append(row) - - assert len(entries) > 0, "no need to summarize on empty entries list" - - summarized, entities = await asyncio.gather( - summarize_messages(entries, model=summarization_model_name), - get_entities(entries, model=summarization_model_name), - ) - trimmed_messages = await trim_messages(summarized, model=summarization_model_name) - ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2 - new_entities_entry = Entry( - session_id=session_id, - source="summarizer", - role="system", - name="entities", - content=entities["content"], - timestamp=entries[0]["timestamp"] + ts_delta, - ) - - entries_summarization_query( - session_id=session_id, - new_entry=new_entities_entry, - old_entry_ids=entities_entry_ids, - ) - - trimmed_map = { - m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None - } - - for idx, msg in enumerate(summarized): - new_entry = Entry( - session_id=session_id, - source="summarizer", - role="system", - name="information", - content=trimmed_map.get(idx, msg["content"]), - timestamp=entries[-1]["timestamp"] + 0.01, - ) - - entries_summarization_query( - session_id=session_id, - new_entry=new_entry, - old_entry_ids=[ - UUID(entries[idx - 1]["entry_id"], version=4) - for idx in msg["summarizes"] - ], - ) + raise NotImplementedError() + # session_id = UUID(session_id) + # entries = [] + # entities_entry_ids = [] + # for _, row in get_toplevel_entries_query(session_id=session_id).iterrows(): + # if row["role"] == "system" and row.get("name") == "entities": + # entities_entry_ids.append(UUID(row["entry_id"], version=4)) + # else: + # entries.append(row) + + # assert len(entries) > 0, "no need to summarize on empty entries list" + + # summarized, entities = await asyncio.gather( + # summarize_messages(entries, model=summarization_model_name), + # get_entities(entries, model=summarization_model_name), + # ) + # trimmed_messages = await trim_messages(summarized, model=summarization_model_name) + # ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2 + # new_entities_entry = Entry( + # session_id=session_id, + # source="summarizer", + # role="system", + # name="entities", + # content=entities["content"], + # timestamp=entries[0]["timestamp"] + ts_delta, + # ) + + # entries_summarization_query( + # session_id=session_id, + # new_entry=new_entities_entry, + # old_entry_ids=entities_entry_ids, + # ) + + # trimmed_map = { + # m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None + # } + + # for idx, msg in enumerate(summarized): + # new_entry = Entry( + # session_id=session_id, + # source="summarizer", + # role="system", + # name="information", + # content=trimmed_map.get(idx, msg["content"]), + # timestamp=entries[-1]["timestamp"] + 0.01, + # ) + + # entries_summarization_query( + # session_id=session_id, + # new_entry=new_entry, + # old_entry_ids=[ + # UUID(entries[idx - 1]["entry_id"], version=4) + # for idx in msg["summarizes"] + # ], + # ) diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index 494226a5b..13f1adcfe 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -17,7 +17,6 @@ from ...clients import ( litellm, # We dont directly import `acompletion` so we can mock it ) -from ...clients.worker.types import ChatML from ...common.protocol.tasks import ( StepContext, TransitionInfo, @@ -53,7 +52,9 @@ async def prompt_step(context: StepContext) -> dict: ) messages = [ - ChatML(role="user", content=m) if isinstance(m, str) else ChatML(**m) + InputChatMLMessage(role="user", content=m) + if isinstance(m, str) + else InputChatMLMessage(**m) for m in messages ] diff --git a/agents-api/agents_api/activities/truncation.py b/agents-api/agents_api/activities/truncation.py index 190190a79..353e4b570 100644 --- a/agents-api/agents_api/activities/truncation.py +++ b/agents-api/agents_api/activities/truncation.py @@ -2,10 +2,11 @@ from temporalio import activity -from agents_api.autogen.openapi_model import Role +# from agents_api.autogen.openapi_model import Role from agents_api.common.protocol.entries import Entry from agents_api.models.entry.delete_entries import delete_entries -from agents_api.models.entry.entries_summarization import get_toplevel_entries_query + +# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[UUID]: @@ -14,40 +15,40 @@ def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list result: list[UUID] = [] token_cnt, offset = 0, 0 - if messages[0].role == Role.system: - token_cnt, offset = messages[0].token_count, 1 + # if messages[0].role == Role.system: + # token_cnt, offset = messages[0].token_count, 1 - for m in reversed(messages[offset:]): - token_cnt += m.token_count - if token_cnt < token_count_threshold: - continue - else: - result.append(m.id) + # for m in reversed(messages[offset:]): + # token_cnt += m.token_count + # if token_cnt < token_count_threshold: + # continue + # else: + # result.append(m.id) - return result + # return result @activity.defn async def truncation(session_id: str, token_count_threshold: int) -> None: session_id = UUID(session_id) - delete_entries( - get_extra_entries( - [ - Entry( - entry_id=row["entry_id"], - session_id=session_id, - source=row["source"], - role=Role(row["role"]), - name=row["name"], - content=row["content"], - created_at=row["created_at"], - timestamp=row["timestamp"], - ) - for _, row in get_toplevel_entries_query( - session_id=session_id - ).iterrows() - ], - token_count_threshold, - ), - ) + # delete_entries( + # get_extra_entries( + # [ + # Entry( + # entry_id=row["entry_id"], + # session_id=session_id, + # source=row["source"], + # role=Role(row["role"]), + # name=row["name"], + # content=row["content"], + # created_at=row["created_at"], + # timestamp=row["timestamp"], + # ) + # for _, row in get_toplevel_entries_query( + # session_id=session_id + # ).iterrows() + # ], + # token_count_threshold, + # ), + # ) diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index 7e45b50d7..72a5056c8 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -13,7 +13,11 @@ from ..worker.codec import pydantic_data_converter -async def get_client(): +async def get_client( + worker_url: str = temporal_worker_url, + namespace: str = temporal_namespace, + data_converter=pydantic_data_converter, +): tls_config = False if temporal_private_key and temporal_client_cert: @@ -23,15 +27,17 @@ async def get_client(): ) return await Client.connect( - temporal_worker_url, - namespace=temporal_namespace, + worker_url, + namespace=namespace, tls=tls_config, - data_converter=pydantic_data_converter, + data_converter=data_converter, ) -async def run_summarization_task(session_id: UUID, job_id: UUID): - client = await get_client() +async def run_summarization_task( + session_id: UUID, job_id: UUID, client: Client | None = None +): + client = client or (await get_client()) await client.execute_workflow( "SummarizationWorkflow", @@ -42,9 +48,13 @@ async def run_summarization_task(session_id: UUID, job_id: UUID): async def run_embed_docs_task( - doc_id: UUID, title: str, content: list[str], job_id: UUID + doc_id: UUID, + title: str, + content: list[str], + job_id: UUID, + client: Client | None = None, ): - client = await get_client() + client = client or (await get_client()) await client.execute_workflow( "EmbedDocsWorkflow", @@ -55,9 +65,12 @@ async def run_embed_docs_task( async def run_truncation_task( - token_count_threshold: int, session_id: UUID, job_id: UUID + token_count_threshold: int, + session_id: UUID, + job_id: UUID, + client: Client | None = None, ): - client = await get_client() + client = client or (await get_client()) await client.execute_workflow( "TruncationWorkflow", @@ -72,8 +85,9 @@ async def run_task_execution_workflow( job_id: UUID, start: tuple[str, int] = ("main", 0), previous_inputs: list[dict] = [], + client: Client | None = None, ): - client = await get_client() + client = client or (await get_client()) return await client.start_workflow( "TaskExecutionWorkflow", diff --git a/agents-api/agents_api/clients/worker/types.py b/agents-api/agents_api/clients/worker/types.py index 02b6add6c..3bf063083 100644 --- a/agents-api/agents_api/clients/worker/types.py +++ b/agents-api/agents_api/clients/worker/types.py @@ -1,108 +1,41 @@ -from typing import Callable, Literal, Optional, Protocol +from typing import Literal from uuid import UUID from pydantic import BaseModel from agents_api.autogen.openapi_model import ( - ChatMLImageContentPart, - ChatMLTextContentPart, + InputChatMLMessage, ) -class PromptModule(Protocol): - stop: list[str] - temperature: float - parser: Callable[[str], str] - make_prompt: Callable[..., str] - - -class ChatML(BaseModel): - role: Literal["system", "user", "assistant", "function_call"] - content: str | dict | list[ChatMLTextContentPart] | list[ChatMLImageContentPart] - - name: Optional[str] = None - entry_id: Optional[UUID] = None - - processed: bool = False - parent_id: Optional[UUID] = None - session_id: Optional[UUID] = None - timestamp: Optional[float] = None - token_count: Optional[int] = None - - -class BaseTask(BaseModel): ... - - -class BaseTaskArgs(BaseModel): ... - - -class MemoryManagementTaskArgs(BaseTaskArgs): +class MemoryManagementTaskArgs(BaseModel): session_id: UUID model: str - dialog: list[ChatML] + dialog: list[InputChatMLMessage] previous_memories: list[str] = [] -class MemoryManagementTask(BaseTask): +class MemoryManagementTask(BaseModel): name: Literal["memory_management.v1"] args: MemoryManagementTaskArgs -class MemoryDensityTaskArgs(BaseTaskArgs): +class MemoryDensityTaskArgs(BaseModel): memory: str -class MemoryDensityTask(BaseTask): +class MemoryDensityTask(BaseModel): name: Literal["memory_density.v1"] args: MemoryDensityTaskArgs -class MemoryRatingTaskArgs(BaseTaskArgs): +class MemoryRatingTaskArgs(BaseModel): memory: str -class MemoryRatingTask(BaseTask): +class MemoryRatingTask(BaseModel): name: Literal["memory_rating.v1"] args: MemoryRatingTaskArgs -class DialogInsightsTaskArgs(BaseTaskArgs): - dialog: list[ChatML] - person1: str - person2: str - - -class DialogInsightsTask(BaseTask): - name: Literal["dialog_insights.v1"] - args: DialogInsightsTaskArgs - - -class RelationshipSummaryTaskArgs(BaseTaskArgs): - statements: list[str] - person1: str - person2: str - - -class RelationshipSummaryTask(BaseTask): - name: Literal["relationship_summary.v1"] - args: RelationshipSummaryTaskArgs - - -class SalientQuestionsTaskArgs(BaseTaskArgs): - statements: list[str] - num: int = 3 - - -class SalientQuestionsTask(BaseTask): - name: Literal["salient_questions.v1"] - args: SalientQuestionsTaskArgs - - -CombinedTask = ( - MemoryManagementTask - | MemoryDensityTask - | MemoryRatingTask - | DialogInsightsTask - | RelationshipSummaryTask - | SalientQuestionsTask -) +CombinedTask = MemoryManagementTask | MemoryDensityTask | MemoryRatingTask diff --git a/agents-api/agents_api/common/protocol/entries.py b/agents-api/agents_api/common/protocol/entries.py index 6ef7f70f2..18d63f583 100644 --- a/agents-api/agents_api/common/protocol/entries.py +++ b/agents-api/agents_api/common/protocol/entries.py @@ -27,28 +27,30 @@ class Entry(BaseEntry): token_count: int tokenizer: str = Field(default="character_count") - @computed_field - @property - def token_count(self) -> int: - """Calculates the token count based on the content's character count. The tokenizer 'character_count' divides the length of the content by 3.5 to estimate the token count. Raises NotImplementedError for unknown tokenizers.""" - if self.tokenizer == "character_count": - content_length = 0 - if isinstance(self.content, str): - content_length = len(self.content) - elif isinstance(self.content, dict): - content_length = len(json.dumps(self.content)) - elif isinstance(self.content, list): - for part in self.content: - if isinstance(part, ChatMLTextContentPart): - content_length += len(part.text) - elif isinstance(part, ChatMLImageContentPart): - content_length += ( - LOW_IMAGE_TOKEN_COUNT - if part.image_url.detail == "low" - else HIGH_IMAGE_TOKEN_COUNT - ) - - # Divide the content length by 3.5 to estimate token count based on character count. - return int(content_length // 3.5) - - raise NotImplementedError(f"Unknown tokenizer: {self.tokenizer}") + # TODO: Replace this with a proper implementation. + + # @computed_field + # @property + # def token_count(self) -> int: + # """Calculates the token count based on the content's character count. The tokenizer 'character_count' divides the length of the content by 3.5 to estimate the token count. Raises NotImplementedError for unknown tokenizers.""" + # if self.tokenizer == "character_count": + # content_length = 0 + # if isinstance(self.content, str): + # content_length = len(self.content) + # elif isinstance(self.content, dict): + # content_length = len(json.dumps(self.content)) + # elif isinstance(self.content, list): + # for part in self.content: + # if isinstance(part, ChatMLTextContentPart): + # content_length += len(part.text) + # elif isinstance(part, ChatMLImageContentPart): + # content_length += ( + # LOW_IMAGE_TOKEN_COUNT + # if part.image_url.detail == "low" + # else HIGH_IMAGE_TOKEN_COUNT + # ) + + # # Divide the content length by 3.5 to estimate token count based on character count. + # return int(content_length // 3.5) + + # raise NotImplementedError(f"Unknown tokenizer: {self.tokenizer}") diff --git a/agents-api/agents_api/models/docs/embed_snippets.py b/agents-api/agents_api/models/docs/embed_snippets.py index 5dd4b4457..89750192a 100644 --- a/agents-api/agents_api/models/docs/embed_snippets.py +++ b/agents-api/agents_api/models/docs/embed_snippets.py @@ -37,7 +37,7 @@ def embed_snippets( *, developer_id: UUID, doc_id: UUID, - snippet_indices: list[int] | tuple[int], + snippet_indices: list[int] | tuple[int, ...], embeddings: list[list[float]], embedding_size: int = 1024, ) -> tuple[list[str], dict]: diff --git a/agents-api/agents_api/rec_sum/entities.py b/agents-api/agents_api/rec_sum/entities.py index 11346447c..9992063ff 100644 --- a/agents-api/agents_api/rec_sum/entities.py +++ b/agents-api/agents_api/rec_sum/entities.py @@ -57,7 +57,7 @@ def make_entities_prompt( @retry(stop=stop_after_attempt(2)) async def get_entities( chat_session, - model="gpt-4-turbo", + model="gpt-4o", stop=[" dict: - result = await acompletion( + result = await litellm.acompletion( model=model, messages=messages, **kwargs, diff --git a/agents-api/agents_api/rec_sum/summarize.py b/agents-api/agents_api/rec_sum/summarize.py index 97f39905b..f98f35094 100644 --- a/agents-api/agents_api/rec_sum/summarize.py +++ b/agents-api/agents_api/rec_sum/summarize.py @@ -44,7 +44,7 @@ def make_summarize_prompt(session, user="a user", assistant="gpt-4-turbo", **_): @retry(stop=stop_after_attempt(2)) async def summarize_messages( chat_session, - model="gpt-4-turbo", + model="gpt-4o", stop=[" None: - return await workflow.execute_activity( - co_density, - memory, - schedule_to_close_timeout=timedelta(seconds=600), - ) diff --git a/agents-api/agents_api/workflows/mem_mgmt.py b/agents-api/agents_api/workflows/mem_mgmt.py index 2db9f95da..31c973741 100644 --- a/agents-api/agents_api/workflows/mem_mgmt.py +++ b/agents-api/agents_api/workflows/mem_mgmt.py @@ -7,14 +7,17 @@ with workflow.unsafe.imports_passed_through(): from ..activities.mem_mgmt import mem_mgmt - from ..activities.types import ChatML + from ..autogen.openapi_model import InputChatMLMessage @workflow.defn class MemMgmtWorkflow: @workflow.run async def run( - self, dialog: list[ChatML], session_id: str, previous_memories: list[str] + self, + dialog: list[InputChatMLMessage], + session_id: str, + previous_memories: list[str], ) -> None: return await workflow.execute_activity( mem_mgmt, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 1b3b1000a..d8acec61c 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -6,6 +6,7 @@ from litellm.types.utils import Choices, ModelResponse from pycozo import Client as CozoClient from temporalio.client import WorkflowHandle +from temporalio.testing import ActivityEnvironment, WorkflowEnvironment from ward import fixture from agents_api.autogen.openapi_model import ( @@ -33,6 +34,7 @@ from agents_api.models.user.create_user import create_user from agents_api.models.user.delete_user import delete_user from agents_api.web import app +from agents_api.worker.worker import create_worker EMBEDDING_SIZE: int = 1024 @@ -49,6 +51,31 @@ def cozo_client(migrations_dir: str = "./migrations"): return client +@fixture(scope="test") +def activity_environment(): + return ActivityEnvironment() + + +@fixture(scope="global") +async def workflow_environment(): + wf_env = await WorkflowEnvironment.start_local() + return wf_env + + +@fixture(scope="global") +async def temporal_client(wf_env=workflow_environment): + return wf_env.client + + +@fixture(scope="global") +async def temporal_worker(temporal_client=temporal_client): + worker = await create_worker(client=temporal_client) + + async with worker as running_worker: + yield running_worker + await running_worker.shutdown() + + @fixture(scope="global") def test_developer_id(cozo_client=cozo_client): developer_id = uuid4() diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index 2723911cc..a5e35760b 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -1,13 +1,59 @@ -# import time -# import uuid +from ward import test -# from ward import test +from agents_api.activities.embed_docs import embed_docs + +from .fixtures import ( + cozo_client, + patch_embed_acompletion, + temporal_client, + temporal_worker, + test_developer_id, + test_doc, + workflow_environment, +) # from agents_api.activities.truncation import get_extra_entries # from agents_api.autogen.openapi_model import Role # from agents_api.common.protocol.entries import Entry +@test("activity: embed_docs") +async def _( + cozo_client=cozo_client, + developer_id=test_developer_id, + doc=test_doc, + mocks=patch_embed_acompletion, +): + (embed, _) = mocks + + title = "title" + content = ["content 1"] + include_title = True + + await embed_docs( + developer_id=developer_id, + doc_id=doc.id, + title=title, + content=content, + include_title=include_title, + cozo_client=cozo_client, + ) + + embed.assert_called_once() + + +@test("activity: check that workflow environment and worker are started correctly") +async def _( + workflow_environment=workflow_environment, + worker=temporal_worker, + client=temporal_client, +): + async with workflow_environment as wf_env: + assert wf_env is not None + assert worker is not None + assert worker.is_running + + # @test("get extra entries, do not strip system message") # def _(): # session_ids = [uuid.uuid4()] * 3