From 0b74840ae0c4ba0fb13de8c5cbb51e98f340008c Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Tue, 13 Aug 2024 15:01:19 -0400 Subject: [PATCH 1/4] feat: Fix type checks Signed-off-by: Diwank Tomer --- agents-api/agents_api/activities/__init__.py | 1 - .../agents_api/activities/co_density.py | 7 ++++--- agents-api/agents_api/activities/demo.py | 10 ---------- .../agents_api/activities/dialog_insights.py | 7 ++++--- .../agents_api/activities/embed_docs.py | 4 ++-- agents-api/agents_api/activities/mem_mgmt.py | 7 ++++--- .../agents_api/activities/mem_rating.py | 6 +++--- .../activities/relationship_summary.py | 6 +++--- .../activities/salient_questions.py | 6 +++--- .../agents_api/activities/summarization.py | 6 +++--- .../activities/task_steps/__init__.py | 1 - agents-api/agents_api/model_registry.py | 4 ++-- .../models/agent/create_or_update_agent.py | 2 +- .../migrations/migrate_1704699172_init.py | 2 +- .../migrate_1704699595_developers.py | 4 ++-- .../migrations/migrate_1709631202_metadata.py | 4 ++-- ...igrate_1712405369_simplify_instructions.py | 4 ++-- agents-api/pytype.toml | 19 +++++++------------ agents-api/tests/test_entry_queries.py | 2 +- agents-api/tests/test_execution_queries.py | 2 +- agents-api/tests/test_session_queries.py | 2 +- 21 files changed, 46 insertions(+), 60 deletions(-) delete mode 100644 agents-api/agents_api/activities/demo.py diff --git a/agents-api/agents_api/activities/__init__.py b/agents-api/agents_api/activities/__init__.py index a804127fc..49722a7d5 100644 --- a/agents-api/agents_api/activities/__init__.py +++ b/agents-api/agents_api/activities/__init__.py @@ -2,7 +2,6 @@ The `activities` module within the agents-api package is designed to facilitate various activities related to agent interactions. This includes handling memory management, generating insights from dialogues, summarizing relationships, and more. Each file within the module offers specific functionality: - `co_density.py`: Conducts cognitive density analysis to generate concise, entity-dense summaries. -- `demo.py`: Provides a simple demonstration of defining an activity with Temporal. - `dialog_insights.py`: Extracts insights from dialogues, identifying details that participants might find interesting. - `mem_mgmt.py`: Manages memory by updating and incorporating new personality information from dialogues. - `mem_rating.py`: Rates memories based on their poignancy and importance. diff --git a/agents-api/agents_api/activities/co_density.py b/agents-api/agents_api/activities/co_density.py index 8d276b401..408cc398a 100644 --- a/agents-api/agents_api/activities/co_density.py +++ b/agents-api/agents_api/activities/co_density.py @@ -3,7 +3,8 @@ from temporalio import activity -from ..clients.model import julep_client +from agents_api.clients import litellm + from .types import MemoryDensityTaskArgs @@ -56,14 +57,14 @@ def make_prompt(args: MemoryDensityTaskArgs): async def run_prompt( memory: str, - model: str = "julep-ai/samantha-1-turbo", + model: str = "gpt-4o", max_tokens: int = 400, temperature: float = 0.2, parser: Callable[[str], str] = lambda x: x, ) -> str: prompt = make_prompt(MemoryDensityTaskArgs(memory=memory)) - response = await julep_client.chat.completions.create( + response = await litellm.acompletion( model=model, messages=[ { diff --git a/agents-api/agents_api/activities/demo.py b/agents-api/agents_api/activities/demo.py deleted file mode 100644 index a0edcde3c..000000000 --- a/agents-api/agents_api/activities/demo.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python3 - -from temporalio import activity - - -@activity.defn -async def say_hello(name: str) -> str: - message = f"Hello, {name}!" - print(message) - return message diff --git a/agents-api/agents_api/activities/dialog_insights.py b/agents-api/agents_api/activities/dialog_insights.py index d6b10ae01..1d5adec39 100644 --- a/agents-api/agents_api/activities/dialog_insights.py +++ b/agents-api/agents_api/activities/dialog_insights.py @@ -3,7 +3,8 @@ from temporalio import activity -from ..clients.model import julep_client +from agents_api.clients import litellm + from .types import ChatML, DialogInsightsTaskArgs @@ -57,7 +58,7 @@ async def run_prompt( dialog: list[ChatML], person1: str, person2: str, - model: str = "julep-ai/samantha-1-turbo", + model: str = "gpt-4o", max_tokens: int = 400, temperature: float = 0.4, parser: Callable[[str], str] = lambda x: x, @@ -66,7 +67,7 @@ async def run_prompt( DialogInsightsTaskArgs(dialog=dialog, person1=person1, person2=person2) ) - response = await julep_client.chat.completions.create( + response = await litellm.acompletion( model=model, messages=[ { diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py index b486c3af1..b8e01e65e 100644 --- a/agents-api/agents_api/activities/embed_docs.py +++ b/agents-api/agents_api/activities/embed_docs.py @@ -1,7 +1,7 @@ from pydantic import UUID4 from temporalio import activity -from agents_api.clients.embed import embed +from agents_api.clients import embed as embedder from agents_api.models.docs.embed_snippets import embed_snippets as embed_snippets_query snippet_embed_instruction = "Encode this passage for retrieval: " @@ -10,7 +10,7 @@ @activity.defn async def embed_docs(doc_id: UUID4, title: str, content: list[str]) -> None: indices, snippets = list(zip(*enumerate(content))) - embeddings = await embed( + embeddings = await embedder.embed( [ { "instruction": snippet_embed_instruction, diff --git a/agents-api/agents_api/activities/mem_mgmt.py b/agents-api/agents_api/activities/mem_mgmt.py index 4f661ca46..f368a0b0b 100644 --- a/agents-api/agents_api/activities/mem_mgmt.py +++ b/agents-api/agents_api/activities/mem_mgmt.py @@ -4,7 +4,8 @@ from temporalio import activity -from ..clients.model import julep_client +from agents_api.clients import litellm + from .types import ChatML, MemoryManagementTaskArgs example_previous_memory = """ @@ -120,7 +121,7 @@ async def run_prompt( dialog: list[ChatML], session_id: UUID, previous_memories: list[str] = [], - model: str = "julep-ai/samantha-1-turbo", + model: str = "gpt-4o", max_tokens: int = 400, temperature: float = 0.4, parser: Callable[[str], str] = lambda x: x, @@ -134,7 +135,7 @@ async def run_prompt( ) ) - response = await julep_client.chat.completions.create( + response = await litellm.acompletion( model=model, messages=[ { diff --git a/agents-api/agents_api/activities/mem_rating.py b/agents-api/agents_api/activities/mem_rating.py index bc35ac82d..222148f4c 100644 --- a/agents-api/agents_api/activities/mem_rating.py +++ b/agents-api/agents_api/activities/mem_rating.py @@ -3,7 +3,7 @@ from temporalio import activity -from ..clients.model import julep_client +from ..clients import litellm from .types import MemoryRatingTaskArgs @@ -40,14 +40,14 @@ def make_prompt(args: MemoryRatingTaskArgs): async def run_prompt( memory: str, - model: str = "julep-ai/samantha-1-turbo", + model: str = "gpt-4o", max_tokens: int = 400, temperature: float = 0.1, parser: Callable[[str], str] = lambda x: x, ) -> str: prompt = make_prompt(MemoryRatingTaskArgs(memory=memory)) - response = await julep_client.chat.completions.create( + response = await litellm.acompletion( model=model, messages=[ { diff --git a/agents-api/agents_api/activities/relationship_summary.py b/agents-api/agents_api/activities/relationship_summary.py index 5346040d3..997eaf40a 100644 --- a/agents-api/agents_api/activities/relationship_summary.py +++ b/agents-api/agents_api/activities/relationship_summary.py @@ -3,7 +3,7 @@ from temporalio import activity -from ..clients.model import julep_client +from ..clients import litellm from .types import RelationshipSummaryTaskArgs @@ -38,7 +38,7 @@ async def run_prompt( statements: list[str], person1: str, person2: str, - model: str = "julep-ai/samantha-1-turbo", + model: str = "gpt-4o", max_tokens: int = 400, temperature: float = 0.6, parser: Callable[[str], str] = lambda x: x, @@ -49,7 +49,7 @@ async def run_prompt( ) ) - response = await julep_client.chat.completions.create( + response = await litellm.acompletion( model=model, messages=[ { diff --git a/agents-api/agents_api/activities/salient_questions.py b/agents-api/agents_api/activities/salient_questions.py index 6a34409d6..0194e8c72 100644 --- a/agents-api/agents_api/activities/salient_questions.py +++ b/agents-api/agents_api/activities/salient_questions.py @@ -3,7 +3,7 @@ from temporalio import activity -from ..clients.model import julep_client +from ..clients import litellm from .types import SalientQuestionsTaskArgs @@ -33,14 +33,14 @@ def make_prompt(args: SalientQuestionsTaskArgs): async def run_prompt( statements: list[str], num: int = 3, - model: str = "julep-ai/samantha-1-turbo", + model: str = "gpt-4o", max_tokens: int = 400, temperature: float = 0.6, parser: Callable[[str], str] = lambda x: x, ) -> str: prompt = make_prompt(SalientQuestionsTaskArgs(statements=statements, num=num)) - response = await julep_client.chat.completions.create( + response = await litellm.acompletion( model=model, messages=[ { diff --git a/agents-api/agents_api/activities/summarization.py b/agents-api/agents_api/activities/summarization.py index 4d2b37f8c..dc365380d 100644 --- a/agents-api/agents_api/activities/summarization.py +++ b/agents-api/agents_api/activities/summarization.py @@ -18,7 +18,7 @@ from agents_api.rec_sum.summarize import summarize_messages from agents_api.rec_sum.trim import trim_messages -from ..clients.litellm import acompletion +from ..clients import litellm from ..env import summarization_model_name @@ -142,14 +142,14 @@ def make_prompt( async def run_prompt( dialog: list[Entry], previous_memories: list[str], - model: str = "julep-ai/samantha-1-turbo", + model: str = "gpt-4o", max_tokens: int = 400, temperature: float = 0.1, parser: Callable[[str], str] = lambda x: x, **kwargs, ) -> str: prompt = make_prompt(dialog, previous_memories, **kwargs) - response = await acompletion( + response = await litellm.acompletion( model=model, messages=[ { diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index 58a8ddcfa..a9818d515 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -6,7 +6,6 @@ from temporalio import activity from ...autogen.openapi_model import ( - CreateTransitionRequest, EvaluateStep, IfElseWorkflowStep, InputChatMLMessage, diff --git a/agents-api/agents_api/model_registry.py b/agents-api/agents_api/model_registry.py index 8b7bce4f0..99ae66ea3 100644 --- a/agents-api/agents_api/model_registry.py +++ b/agents-api/agents_api/model_registry.py @@ -87,8 +87,8 @@ OPENAI_MODELS = {**GPT4_MODELS, **TURBO_MODELS, **GPT3_5_MODELS, **GPT3_MODELS} LOCAL_MODELS = { - "julep-ai/samantha-1-turbo": 32768, - "julep-ai/samantha-1-turbo-awq": 32768, + "gpt-4o": 32768, + "gpt-4o-awq": 32768, "TinyLlama/TinyLlama_v1.1": 2048, "casperhansen/llama-3-8b-instruct-awq": 8192, "julep-ai/Hermes-2-Theta-Llama-3-8B": 8192, diff --git a/agents-api/agents_api/models/agent/create_or_update_agent.py b/agents-api/agents_api/models/agent/create_or_update_agent.py index 8cba24d2b..fb80f6dd7 100644 --- a/agents-api/agents_api/models/agent/create_or_update_agent.py +++ b/agents-api/agents_api/models/agent/create_or_update_agent.py @@ -48,7 +48,7 @@ def create_or_update_agent( - name (str): The name of the agent. - about (str): A description of the agent. - instructions (list[str], optional): A list of instructions for using the agent. Defaults to an empty list. - - model (str, optional): The model identifier for the agent. Defaults to "julep-ai/samantha-1-turbo". + - model (str, optional): The model identifier for the agent. Defaults to "gpt-4o". - metadata (dict, optional): A dictionary of metadata for the agent. Defaults to an empty dict. - default_settings (dict, optional): A dictionary of default settings for the agent. Defaults to an empty dict. - client (CozoClient, optional): The CozoDB client instance to use for the query. Defaults to a preconfigured client instance. diff --git a/agents-api/migrations/migrate_1704699172_init.py b/agents-api/migrations/migrate_1704699172_init.py index 1f131c8a1..3a427ad48 100644 --- a/agents-api/migrations/migrate_1704699172_init.py +++ b/agents-api/migrations/migrate_1704699172_init.py @@ -19,7 +19,7 @@ def up(client): => name: String, about: String, - model: String default 'julep-ai/samantha-1-turbo', + model: String default 'gpt-4o', created_at: Float default now(), updated_at: Float default now(), } diff --git a/agents-api/migrations/migrate_1704699595_developers.py b/agents-api/migrations/migrate_1704699595_developers.py index e2c183520..d22edb393 100644 --- a/agents-api/migrations/migrate_1704699595_developers.py +++ b/agents-api/migrations/migrate_1704699595_developers.py @@ -29,7 +29,7 @@ def up(client): => name: String, about: String, - model: String default 'julep-ai/samantha-1-turbo', + model: String default 'gpt-4o', created_at: Float default now(), updated_at: Float default now(), } @@ -99,7 +99,7 @@ def down(client): => name: String, about: String, - model: String default 'julep-ai/samantha-1-turbo', + model: String default 'gpt-4o', created_at: Float default now(), updated_at: Float default now(), } diff --git a/agents-api/migrations/migrate_1709631202_metadata.py b/agents-api/migrations/migrate_1709631202_metadata.py index b5c220cb3..36c1c8ec4 100644 --- a/agents-api/migrations/migrate_1709631202_metadata.py +++ b/agents-api/migrations/migrate_1709631202_metadata.py @@ -22,7 +22,7 @@ => name: String, about: String, - model: String default 'julep-ai/samantha-1-turbo', + model: String default 'gpt-4o', created_at: Float default now(), updated_at: Float default now(), metadata: Json default {}, @@ -45,7 +45,7 @@ => name: String, about: String, - model: String default 'julep-ai/samantha-1-turbo', + model: String default 'gpt-4o', created_at: Float default now(), updated_at: Float default now(), } diff --git a/agents-api/migrations/migrate_1712405369_simplify_instructions.py b/agents-api/migrations/migrate_1712405369_simplify_instructions.py index ee3a87da1..b3f8a289a 100644 --- a/agents-api/migrations/migrate_1712405369_simplify_instructions.py +++ b/agents-api/migrations/migrate_1712405369_simplify_instructions.py @@ -24,7 +24,7 @@ name: String, about: String, instructions: [String] default [], - model: String default 'julep-ai/samantha-1-turbo', + model: String default 'gpt-4o', created_at: Float default now(), updated_at: Float default now(), metadata: Json default {}, @@ -47,7 +47,7 @@ => name: String, about: String, - model: String default 'julep-ai/samantha-1-turbo', + model: String default 'gpt-4o', created_at: Float default now(), updated_at: Float default now(), metadata: Json default {}, diff --git a/agents-api/pytype.toml b/agents-api/pytype.toml index 1b95217a6..edd07e7d4 100644 --- a/agents-api/pytype.toml +++ b/agents-api/pytype.toml @@ -2,15 +2,10 @@ [tool.pytype] -# Space-separated list of files or directories to exclude. -exclude = [ - '**/*_test.py', - '**/test_*.py', -] - # Space-separated list of files or directories to process. inputs = [ 'agents_api', + 'tests', ] # Keep going past errors to analyze as many files as possible. @@ -30,7 +25,7 @@ platform = 'linux' pythonpath = '.' # Python version (major.minor) of the target code. -python_version = '3.10' +python_version = '3.11' # Bind 'self' in methods with non-transparent decorators. This flag is temporary # and will be removed once this behavior is enabled by default. @@ -38,7 +33,7 @@ bind_decorated_methods = true # Don't allow None to match bool. This flag is temporary and will be removed # once this behavior is enabled by default. -none_is_not_bool = false +none_is_not_bool = true # Enable parameter count checks for overriding methods with renamed arguments. # This flag is temporary and will be removed once this behavior is enabled by @@ -64,20 +59,20 @@ require_override_decorator = false precise_return = true # Experimental: Solve unknown types to label with structural types. -protocols = false +protocols = true # Experimental: Only load submodules that are explicitly imported. strict_import = true # Experimental: Enable exhaustive checking of function parameter types. -strict_parameter_checks = false +strict_parameter_checks = true # Experimental: Emit errors for comparisons between incompatible primitive # types. -strict_primitive_comparisons = false +strict_primitive_comparisons = true # Experimental: Check that variables are defined in all possible code paths. -strict_undefined_checks = false +strict_undefined_checks = true # Experimental: FOR TESTING ONLY. Use pytype/rewrite/. use_rewrite = false diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index be8e6362e..6161ad94c 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -14,7 +14,7 @@ from agents_api.models.entry.list_entries import list_entries from tests.fixtures import cozo_client, test_developer_id, test_session -MODEL = "julep-ai/samantha-1-turbo" +MODEL = "gpt-4o" @test("model: create entry") diff --git a/agents-api/tests/test_execution_queries.py b/agents-api/tests/test_execution_queries.py index b2ef1a04e..c4dec7c5f 100644 --- a/agents-api/tests/test_execution_queries.py +++ b/agents-api/tests/test_execution_queries.py @@ -15,7 +15,7 @@ from agents_api.models.execution.list_executions import list_executions from tests.fixtures import cozo_client, test_developer_id, test_execution, test_task -MODEL = "julep-ai/samantha-1-turbo" +MODEL = "gpt-4o" @test("model: create execution") diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 763e62bed..94b5bbbe4 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -21,7 +21,7 @@ test_user, ) -MODEL = "julep-ai/samantha-1-turbo" +MODEL = "gpt-4o" @test("model: create session") From 5007d7be0627ea3e85447707311e8aead385db12 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Tue, 13 Aug 2024 16:35:44 -0400 Subject: [PATCH 2/4] fix(agents-api): Fix chat endpoint behavior Signed-off-by: Diwank Tomer --- agents-api/agents_api/autogen/Chat.py | 20 ++++-- .../models/docs/search_docs_by_embedding.py | 66 ++++++++++++------- .../models/docs/search_docs_by_text.py | 63 ++++++++++-------- .../models/docs/search_docs_hybrid.py | 9 +-- .../agents_api/routers/docs/search_docs.py | 6 +- .../agents_api/routers/sessions/chat.py | 31 ++++++--- agents-api/poetry.lock | 20 +++--- agents-api/tests/fixtures.py | 2 +- agents-api/tests/test_chat_routes.py | 41 ++++++++++++ agents-api/tests/test_docs_queries.py | 28 +++++++- sdks/python/julep/api/client.py | 28 ++++---- sdks/python/julep/api/reference.md | 6 +- .../julep/api/types/chat_competion_usage.py | 6 +- sdks/python/poetry.lock | 14 ++-- sdks/ts/src/api/models/Chat_ChatInput.ts | 8 +-- sdks/ts/src/api/models/Chat_CompetionUsage.ts | 6 +- sdks/ts/src/api/schemas/$Chat_ChatInput.ts | 9 ++- .../src/api/schemas/$Chat_CompetionUsage.ts | 3 - typespec/chat/models.tsp | 15 ++--- 19 files changed, 238 insertions(+), 143 deletions(-) diff --git a/agents-api/agents_api/autogen/Chat.py b/agents-api/agents_api/autogen/Chat.py index 94832c4cb..4d157f7c5 100644 --- a/agents-api/agents_api/autogen/Chat.py +++ b/agents-api/agents_api/autogen/Chat.py @@ -114,15 +114,21 @@ class CompetionUsage(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - completion_tokens: Annotated[int, Field(json_schema_extra={"readOnly": True})] + completion_tokens: Annotated[ + int | None, Field(None, json_schema_extra={"readOnly": True}) + ] """ Number of tokens in the generated completion """ - prompt_tokens: Annotated[int, Field(json_schema_extra={"readOnly": True})] + prompt_tokens: Annotated[ + int | None, Field(None, json_schema_extra={"readOnly": True}) + ] """ Number of tokens in the prompt """ - total_tokens: Annotated[int, Field(json_schema_extra={"readOnly": True})] + total_tokens: Annotated[ + int | None, Field(None, json_schema_extra={"readOnly": True}) + ] """ Total number of tokens used in the request (prompt + completion) """ @@ -213,13 +219,13 @@ class ChatInput(ChatInputData): model_config = ConfigDict( populate_by_name=True, ) - recall: Annotated[bool, Field(False, json_schema_extra={"readOnly": True})] + remember: Annotated[bool, Field(False, json_schema_extra={"readOnly": True})] """ - Whether previous memories should be recalled or not (will be enabled in a future release) + DISABLED: Whether this interaction should form new memories or not (will be enabled in a future release) """ - remember: Annotated[bool, Field(False, json_schema_extra={"readOnly": True})] + recall: bool = True """ - Whether this interaction should form new memories or not (will be enabled in a future release) + Whether previous memories and docs should be recalled or not """ save: bool = True """ diff --git a/agents-api/agents_api/models/docs/search_docs_by_embedding.py b/agents-api/agents_api/models/docs/search_docs_by_embedding.py index 0acbf8f6a..3f7114a23 100644 --- a/agents-api/agents_api/models/docs/search_docs_by_embedding.py +++ b/agents-api/agents_api/models/docs/search_docs_by_embedding.py @@ -41,8 +41,7 @@ def search_docs_by_embedding( *, developer_id: UUID, - owner_type: Literal["user", "agent"], - owner_id: UUID, + owners: list[tuple[Literal["user", "agent"], UUID]], query_embedding: list[float], k: int = 3, confidence: float = 0.7, @@ -65,25 +64,29 @@ def search_docs_by_embedding( assert len(query_embedding) == embedding_size assert sum(query_embedding) - owner_id = str(owner_id) + owners: list[list[str]] = [ + [owner_type, str(owner_id)] for owner_type, owner_id in owners + ] # Calculate the search radius based on confidence level radius: float = 1.0 - confidence # Construct the datalog query for searching document snippets interim_query = f""" + owners[owner_type, owner_id] <- $owners input[ + owner_type, owner_id, query_embedding, - ] <- [[ - to_uuid($owner_id), - vec($query_embedding), - ]] + ] := + owners[owner_type, owner_id_str], + owner_id = to_uuid(owner_id_str), + query_embedding = vec($query_embedding) candidate[doc_id] := - input[owner_id, _], + input[owner_type, owner_id, _], *docs {{ - owner_type: $owner_type, + owner_type, owner_id, doc_id }} @@ -125,7 +128,7 @@ def search_docs_by_embedding( index, distance, ] := - input[owner_id, query], + input[_, __, query], candidate[doc_id], ~snippets:embedding_space {{ doc_id, @@ -151,6 +154,8 @@ def search_docs_by_embedding( snippet_data = [index, content] ?[ + owner_type, + owner_id, doc_id, snippet_data, distance, @@ -175,6 +180,8 @@ def search_docs_by_embedding( :limit {k} :create _interim {{ + owner_type, + owner_id, doc_id, snippet_data, distance, @@ -186,18 +193,23 @@ def search_docs_by_embedding( collect_query = """ m[ doc_id, + owner_type, + owner_id, collect(snippet), distance, title, - ] := *_interim { - doc_id, - snippet_data, - distance, - title, - }, snippet = { - "index": snippet_data->0, - "content": snippet_data->1, - } + ] := + *_interim { + owner_type, + owner_id, + doc_id, + snippet_data, + distance, + title, + }, snippet = { + "index": snippet_data->0, + "content": snippet_data->1, + } ?[ id, @@ -208,17 +220,22 @@ def search_docs_by_embedding( title, ] := m[ id, + owner_type, + owner_id, snippets, distance, title, - ], owner_type = $owner_type, owner_id = $owner_id + ] """ queries = [ verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} - ), + *[ + verify_developer_owns_resource_query( + developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} + ) + for owner_type, owner_id in owners + ], interim_query, collect_query, ] @@ -226,8 +243,7 @@ def search_docs_by_embedding( return ( queries, { - "owner_type": owner_type, - "owner_id": owner_id, + "owners": owners, "query_embedding": query_embedding, }, ) diff --git a/agents-api/agents_api/models/docs/search_docs_by_text.py b/agents-api/agents_api/models/docs/search_docs_by_text.py index 8befeb07d..a5e379f24 100644 --- a/agents-api/agents_api/models/docs/search_docs_by_text.py +++ b/agents-api/agents_api/models/docs/search_docs_by_text.py @@ -41,8 +41,7 @@ def search_docs_by_text( *, developer_id: UUID, - owner_type: Literal["user", "agent"], - owner_id: UUID, + owners: list[tuple[Literal["user", "agent"], UUID]], query: str, k: int = 3, ) -> tuple[list[str], dict]: @@ -50,28 +49,29 @@ def search_docs_by_text( Searches for document snippets in CozoDB by embedding query. Parameters: - - owner_type (Literal["user", "agent"]): The type of the owner of the documents. - - owner_id (UUID): The unique identifier of the owner. + - owners (list[tuple[Literal["user", "agent"], UUID]]): The type of the owner of the documents. - query (str): The query string. - k (int, optional): The number of nearest neighbors to retrieve. Defaults to 3. """ - owner_id = str(owner_id) + owners: list[list[str]] = [ + [owner_type, str(owner_id)] for owner_type, owner_id in owners + ] # Construct the datalog query for searching document snippets search_query = f""" + owners[owner_type, owner_id] <- $owners input[ + owner_type, owner_id, - query, - ] <- [[ - to_uuid($owner_id), - $query, - ]] + ] := + owners[owner_type, owner_id_str], + owner_id = to_uuid(owner_id_str) candidate[doc_id] := - input[owner_id, _], + input[owner_type, owner_id], *docs {{ - owner_type: $owner_type, + owner_type, owner_id, doc_id }} @@ -81,17 +81,16 @@ def search_docs_by_text( snippet_data, distance, ] := - input[owner_id, query], candidate[doc_id], ~snippets:lsh {{ doc_id, index, content | - query: query, + query: $query, k: {k}, }}, - distance = 10000000, # Very large distance to depict no distance + distance = 10000000, # Very large distance to depict no valid distance snippet_data = [index, content] search_result[ @@ -99,14 +98,13 @@ def search_docs_by_text( snippet_data, distance, ] := - input[owner_id, query], candidate[doc_id], ~snippets:fts {{ doc_id, index, content | - query: query, + query: $query, k: {k}, score_kind: 'tf_idf', bind_score: score, @@ -119,10 +117,12 @@ def search_docs_by_text( collect(snippet), distance, title, + owner_type, + owner_id, ] := candidate[doc_id], *docs {{ - owner_type: $owner_type, + owner_type, owner_id, doc_id, title, @@ -145,12 +145,16 @@ def search_docs_by_text( snippets, distance, title, - ] := m[ - id, - snippets, - distance, - title, - ], owner_type = $owner_type, owner_id = $owner_id + ] := + input[owner_type, owner_id], + m[ + id, + snippets, + distance, + title, + owner_type, + owner_id, + ] # Sort the results by distance to find the closest matches :sort distance @@ -159,13 +163,16 @@ def search_docs_by_text( queries = [ verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} - ), + *[ + verify_developer_owns_resource_query( + developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} + ) + for owner_type, owner_id in owners + ], search_query, ] return ( queries, - {"owner_type": owner_type, "owner_id": owner_id, "query": query}, + {"owners": owners, "query": query}, ) diff --git a/agents-api/agents_api/models/docs/search_docs_hybrid.py b/agents-api/agents_api/models/docs/search_docs_hybrid.py index 0a9cd2815..03fb44037 100644 --- a/agents-api/agents_api/models/docs/search_docs_hybrid.py +++ b/agents-api/agents_api/models/docs/search_docs_hybrid.py @@ -95,8 +95,7 @@ def dbsf_fuse( def search_docs_hybrid( *, developer_id: UUID, - owner_type: Literal["user", "agent"], - owner_id: UUID, + owners: list[tuple[Literal["user", "agent"], UUID]], query: str, query_embedding: list[float], k: int = 3, @@ -107,8 +106,7 @@ def search_docs_hybrid( # TODO: We should probably parallelize these queries text_results = search_docs_by_text( developer_id=developer_id, - owner_type=owner_type, - owner_id=owner_id, + owners=owners, query=query, k=2 * k, **text_search_options, @@ -116,8 +114,7 @@ def search_docs_hybrid( embedding_results = search_docs_by_embedding( developer_id=developer_id, - owner_type=owner_type, - owner_id=owner_id, + owners=owners, query_embedding=query_embedding, k=2 * k, **embed_search_options, diff --git a/agents-api/agents_api/routers/docs/search_docs.py b/agents-api/agents_api/routers/docs/search_docs.py index 0e5430a7a..ad19a3178 100644 --- a/agents-api/agents_api/routers/docs/search_docs.py +++ b/agents-api/agents_api/routers/docs/search_docs.py @@ -70,8 +70,7 @@ async def search_user_docs( start = time.time() docs = search_fn( developer_id=x_developer_id, - owner_type="user", - owner_id=user_id, + owners=[("user", user_id)], **params, ) @@ -98,8 +97,7 @@ async def search_agent_docs( start = time.time() docs = search_fn( developer_id=x_developer_id, - owner_type="agent", - owner_id=agent_id, + owners=[("agent", agent_id)], **params, ) diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index f0023cb93..afe7e3e2d 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -32,6 +32,7 @@ async def get_messages( session_id: UUID, new_raw_messages: list[dict], chat_context: ChatContext, + recall: bool, ): assert len(new_raw_messages) > 0 @@ -50,6 +51,9 @@ async def get_messages( if entry.id not in {r.head for r in relations} ] + if not recall: + return past_messages, [] + # Search matching docs [query_embedding, *_] = await embed.embed( inputs=[ @@ -60,10 +64,14 @@ async def get_messages( ) query_text = new_raw_messages[-1]["content"] + # List all the applicable owners to search docs from + active_agent_id = chat_context.get_active_agent().id + user_ids = [user.id for user in chat_context.users] + owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)] + doc_references: list[DocReference] = search_docs_hybrid( developer_id=developer.id, - owner_type="agent", - owner_id=chat_context.get_active_agent().id, + owners=owners, query=query_text, query_embedding=query_embedding, ) @@ -79,7 +87,7 @@ async def get_messages( async def chat( developer: Annotated[Developer, Depends(get_developer_data)], session_id: UUID, - data: ChatInput, + input: ChatInput, background_tasks: BackgroundTasks, ) -> ChatResponse: # First get the chat context @@ -89,10 +97,10 @@ async def chat( ) # Merge the settings and prepare environment - chat_context.merge_settings(data) + chat_context.merge_settings(input) settings: dict = chat_context.settings.model_dump() env: dict = chat_context.get_chat_environment() - new_raw_messages = [msg.model_dump() for msg in data.messages] + new_raw_messages = [msg.model_dump() for msg in input.messages] # Render the messages past_messages, doc_references = await get_messages( @@ -100,22 +108,27 @@ async def chat( session_id=session_id, new_raw_messages=new_raw_messages, chat_context=chat_context, + recall=input.recall, ) env["docs"] = doc_references new_messages = await render_template(new_raw_messages, variables=env) messages = past_messages + new_messages + # Get the tools + tools = settings.get("tools") or chat_context.get_active_tools() + # Get the response from the model model_response = await litellm.acompletion( messages=messages, + tools=tools, + user=str(developer.id), # For tracking usage + tags=developer.tags, # For filtering models in litellm **settings, - user=str(developer.id), - tags=developer.tags, ) # Save the input and the response to the session history - if data.save: + if input.save: new_entries = [ CreateEntryRequest(**msg, source="api_request") for msg in new_messages ] @@ -128,7 +141,7 @@ async def chat( ) # Return the response - chat_response_class = ChunkChatResponse if data.stream else MessageChatResponse + chat_response_class = ChunkChatResponse if input.stream else MessageChatResponse chat_response: ChatResponse = chat_response_class( id=uuid4(), created_at=utcnow(), diff --git a/agents-api/poetry.lock b/agents-api/poetry.lock index ba83ef4a2..b8c0a42b7 100644 --- a/agents-api/poetry.lock +++ b/agents-api/poetry.lock @@ -2255,13 +2255,13 @@ dev = ["Sphinx (>=5.1.1)", "black (==23.12.1)", "build (>=0.10.0)", "coverage (> [[package]] name = "litellm" -version = "1.43.7" +version = "1.43.9" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.43.7-py3-none-any.whl", hash = "sha256:88d9d8dcb4579839106941f1ce59143ab926af986a2206cce4bcda1ae153a78c"}, - {file = "litellm-1.43.7.tar.gz", hash = "sha256:b6ef8db0c7555d590957c37b228584efc5e9154b925ab0fffb112be26f1ab5ab"}, + {file = "litellm-1.43.9-py3-none-any.whl", hash = "sha256:54253281139e61f130b7e1a613a11f7a5ee896c2ee8536b0ca9a5ffbfce4c5f0"}, + {file = "litellm-1.43.9.tar.gz", hash = "sha256:c397a14c9b851f007f09c99e5a28606f7f122fdb4ae954931220f60e9edc6918"}, ] [package.dependencies] @@ -4219,18 +4219,18 @@ tornado = ["tornado (>=5)"] [[package]] name = "setuptools" -version = "72.1.0" +version = "72.2.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-72.1.0-py3-none-any.whl", hash = "sha256:5a03e1860cf56bb6ef48ce186b0e557fdba433237481a9a625176c2831be15d1"}, - {file = "setuptools-72.1.0.tar.gz", hash = "sha256:8d243eff56d095e5817f796ede6ae32941278f542e0f941867cc05ae52b162ec"}, + {file = "setuptools-72.2.0-py3-none-any.whl", hash = "sha256:f11dd94b7bae3a156a95ec151f24e4637fb4fa19c878e4d191bfb8b2d82728c4"}, + {file = "setuptools-72.2.0.tar.gz", hash = "sha256:80aacbf633704e9c8bfa1d99fa5dd4dc59573efcf9e4042c13d3bcef91ac2ef9"}, ] [package.extras] core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] @@ -4290,13 +4290,13 @@ files = [ [[package]] name = "soupsieve" -version = "2.5" +version = "2.6" description = "A modern CSS selector implementation for Beautiful Soup." optional = false python-versions = ">=3.8" files = [ - {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, - {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, + {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"}, + {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, ] [[package]] diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index fafb351f0..1b3b1000a 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -78,7 +78,7 @@ def test_developer(cozo_client=cozo_client, developer_id=test_developer_id): ) -@fixture(scope="global") +@fixture(scope="test") def patch_embed_acompletion(): mock_model_response = ModelResponse( id="fake_id", diff --git a/agents-api/tests/test_chat_routes.py b/agents-api/tests/test_chat_routes.py index 55d94b2a0..ccf91c89e 100644 --- a/agents-api/tests/test_chat_routes.py +++ b/agents-api/tests/test_chat_routes.py @@ -28,6 +28,46 @@ async def _( assert (await embed.embed())[0][0] == 1.0 +@test("chat: check that non-recall get_messages works") +async def _( + developer=test_developer, + client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, + session=test_session, + tool=test_tool, + user=test_user, + mocks=patch_embed_acompletion, +): + (embed, _) = mocks + + chat_context = prepare_chat_context( + developer_id=developer_id, + session_id=session.id, + client=client, + ) + + session_id = session.id + + new_raw_messages = [{"role": "user", "content": "hello"}] + + past_messages, doc_references = await get_messages( + developer=developer, + session_id=session_id, + new_raw_messages=new_raw_messages, + chat_context=chat_context, + recall=False, + ) + + assert isinstance(past_messages, list) + assert len(past_messages) >= 0 + assert isinstance(doc_references, list) + assert len(doc_references) == 0 + + # Check that embed was not called + embed.assert_not_called() + + @test("chat: check that get_messages works") async def _( developer=test_developer, @@ -56,6 +96,7 @@ async def _( session_id=session_id, new_raw_messages=new_raw_messages, chat_context=chat_context, + recall=True, ) assert isinstance(past_messages, list) diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index 4743ea45d..fcf7f9bd6 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -9,6 +9,7 @@ from agents_api.models.docs.get_doc import get_doc from agents_api.models.docs.list_docs import list_docs from agents_api.models.docs.search_docs_by_embedding import search_docs_by_embedding +from agents_api.models.docs.search_docs_by_text import search_docs_by_text from tests.fixtures import ( EMBEDDING_SIZE, cozo_client, @@ -82,7 +83,29 @@ def _( assert len(result) >= 1 -@test("model: search docs") +@test("model: search docs by text") +def _(client=cozo_client, agent=test_agent, developer_id=test_developer_id): + create_doc( + developer_id=developer_id, + owner_type="agent", + owner_id=agent.id, + data=CreateDocRequest( + title="Hello", content=["The world is a funny little thing"] + ), + client=client, + ) + + result = search_docs_by_text( + developer_id=developer_id, + owners=[("agent", agent.id)], + query="funny", + client=client, + ) + + assert len(result) >= 1 + + +@test("model: search docs by embedding") def _(client=cozo_client, agent=test_agent, developer_id=test_developer_id): doc = create_doc( developer_id=developer_id, @@ -106,8 +129,7 @@ def _(client=cozo_client, agent=test_agent, developer_id=test_developer_id): result = search_docs_by_embedding( developer_id=developer_id, - owner_type="agent", - owner_id=agent.id, + owners=[("agent", agent.id)], query_embedding=query_embedding, client=client, ) diff --git a/sdks/python/julep/api/client.py b/sdks/python/julep/api/client.py index 42c9f45a5..2030b74ce 100644 --- a/sdks/python/julep/api/client.py +++ b/sdks/python/julep/api/client.py @@ -2663,8 +2663,8 @@ def chat_route_generate( self, id: CommonUuid, *, - recall: bool, remember: bool, + recall: bool, save: bool, stream: bool, messages: typing.Sequence[EntriesInputChatMlMessage], @@ -2694,11 +2694,11 @@ def chat_route_generate( id : CommonUuid The session ID - recall : bool - Whether previous memories should be recalled or not (will be enabled in a future release) - remember : bool - Whether this interaction should form new memories or not (will be enabled in a future release) + DISABLED: Whether this interaction should form new memories or not (will be enabled in a future release) + + recall : bool + Whether previous memories and docs should be recalled or not save : bool Whether this interaction should be stored in the session history or not @@ -2782,8 +2782,8 @@ def chat_route_generate( content="content", ) ], - recall=True, remember=True, + recall=True, save=True, stream=True, ) @@ -2792,8 +2792,8 @@ def chat_route_generate( f"sessions/{jsonable_encoder(id)}/chat", method="POST", json={ - "recall": recall, "remember": remember, + "recall": recall, "save": save, "model": model, "stream": stream, @@ -6542,8 +6542,8 @@ async def chat_route_generate( self, id: CommonUuid, *, - recall: bool, remember: bool, + recall: bool, save: bool, stream: bool, messages: typing.Sequence[EntriesInputChatMlMessage], @@ -6573,11 +6573,11 @@ async def chat_route_generate( id : CommonUuid The session ID - recall : bool - Whether previous memories should be recalled or not (will be enabled in a future release) - remember : bool - Whether this interaction should form new memories or not (will be enabled in a future release) + DISABLED: Whether this interaction should form new memories or not (will be enabled in a future release) + + recall : bool + Whether previous memories and docs should be recalled or not save : bool Whether this interaction should be stored in the session history or not @@ -6666,8 +6666,8 @@ async def main() -> None: content="content", ) ], - recall=True, remember=True, + recall=True, save=True, stream=True, ) @@ -6679,8 +6679,8 @@ async def main() -> None: f"sessions/{jsonable_encoder(id)}/chat", method="POST", json={ - "recall": recall, "remember": remember, + "recall": recall, "save": save, "model": model, "stream": stream, diff --git a/sdks/python/julep/api/reference.md b/sdks/python/julep/api/reference.md index 5bf9bf81f..85fbaf94c 100644 --- a/sdks/python/julep/api/reference.md +++ b/sdks/python/julep/api/reference.md @@ -3790,8 +3790,8 @@ client.chat_route_generate( content="content", ) ], - recall=True, remember=True, + recall=True, save=True, stream=True, ) @@ -3818,7 +3818,7 @@ client.chat_route_generate(
-**recall:** `bool` — Whether previous memories should be recalled or not (will be enabled in a future release) +**remember:** `bool` — DISABLED: Whether this interaction should form new memories or not (will be enabled in a future release)
@@ -3826,7 +3826,7 @@ client.chat_route_generate(
-**remember:** `bool` — Whether this interaction should form new memories or not (will be enabled in a future release) +**recall:** `bool` — Whether previous memories and docs should be recalled or not
diff --git a/sdks/python/julep/api/types/chat_competion_usage.py b/sdks/python/julep/api/types/chat_competion_usage.py index f6f798330..23d6b7b3c 100644 --- a/sdks/python/julep/api/types/chat_competion_usage.py +++ b/sdks/python/julep/api/types/chat_competion_usage.py @@ -12,17 +12,17 @@ class ChatCompetionUsage(pydantic_v1.BaseModel): Usage statistics for the completion request """ - completion_tokens: int = pydantic_v1.Field() + completion_tokens: typing.Optional[int] = pydantic_v1.Field(default=None) """ Number of tokens in the generated completion """ - prompt_tokens: int = pydantic_v1.Field() + prompt_tokens: typing.Optional[int] = pydantic_v1.Field(default=None) """ Number of tokens in the prompt """ - total_tokens: int = pydantic_v1.Field() + total_tokens: typing.Optional[int] = pydantic_v1.Field(default=None) """ Total number of tokens used in the request (prompt + completion) """ diff --git a/sdks/python/poetry.lock b/sdks/python/poetry.lock index 38e2c1b19..060ffa16f 100644 --- a/sdks/python/poetry.lock +++ b/sdks/python/poetry.lock @@ -2865,18 +2865,18 @@ win32 = ["pywin32"] [[package]] name = "setuptools" -version = "72.1.0" +version = "72.2.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-72.1.0-py3-none-any.whl", hash = "sha256:5a03e1860cf56bb6ef48ce186b0e557fdba433237481a9a625176c2831be15d1"}, - {file = "setuptools-72.1.0.tar.gz", hash = "sha256:8d243eff56d095e5817f796ede6ae32941278f542e0f941867cc05ae52b162ec"}, + {file = "setuptools-72.2.0-py3-none-any.whl", hash = "sha256:f11dd94b7bae3a156a95ec151f24e4637fb4fa19c878e4d191bfb8b2d82728c4"}, + {file = "setuptools-72.2.0.tar.gz", hash = "sha256:80aacbf633704e9c8bfa1d99fa5dd4dc59573efcf9e4042c13d3bcef91ac2ef9"}, ] [package.extras] core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] @@ -2914,13 +2914,13 @@ files = [ [[package]] name = "soupsieve" -version = "2.5" +version = "2.6" description = "A modern CSS selector implementation for Beautiful Soup." optional = false python-versions = ">=3.8" files = [ - {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, - {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, + {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"}, + {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, ] [[package]] diff --git a/sdks/ts/src/api/models/Chat_ChatInput.ts b/sdks/ts/src/api/models/Chat_ChatInput.ts index 5de17283c..e03fa2914 100644 --- a/sdks/ts/src/api/models/Chat_ChatInput.ts +++ b/sdks/ts/src/api/models/Chat_ChatInput.ts @@ -9,13 +9,13 @@ import type { Common_logit_bias } from "./Common_logit_bias"; import type { Common_uuid } from "./Common_uuid"; export type Chat_ChatInput = Chat_ChatInputData & { /** - * Whether previous memories should be recalled or not (will be enabled in a future release) + * DISABLED: Whether this interaction should form new memories or not (will be enabled in a future release) */ - readonly recall: boolean; + readonly remember: boolean; /** - * Whether this interaction should form new memories or not (will be enabled in a future release) + * Whether previous memories and docs should be recalled or not */ - readonly remember: boolean; + recall: boolean; /** * Whether this interaction should be stored in the session history or not */ diff --git a/sdks/ts/src/api/models/Chat_CompetionUsage.ts b/sdks/ts/src/api/models/Chat_CompetionUsage.ts index 089b474f1..8422d5f36 100644 --- a/sdks/ts/src/api/models/Chat_CompetionUsage.ts +++ b/sdks/ts/src/api/models/Chat_CompetionUsage.ts @@ -9,13 +9,13 @@ export type Chat_CompetionUsage = { /** * Number of tokens in the generated completion */ - readonly completion_tokens: number; + readonly completion_tokens?: number; /** * Number of tokens in the prompt */ - readonly prompt_tokens: number; + readonly prompt_tokens?: number; /** * Total number of tokens used in the request (prompt + completion) */ - readonly total_tokens: number; + readonly total_tokens?: number; }; diff --git a/sdks/ts/src/api/schemas/$Chat_ChatInput.ts b/sdks/ts/src/api/schemas/$Chat_ChatInput.ts index 635d83534..538b90e76 100644 --- a/sdks/ts/src/api/schemas/$Chat_ChatInput.ts +++ b/sdks/ts/src/api/schemas/$Chat_ChatInput.ts @@ -10,16 +10,15 @@ export const $Chat_ChatInput = { }, { properties: { - recall: { + remember: { type: "boolean", - description: `Whether previous memories should be recalled or not (will be enabled in a future release)`, + description: `DISABLED: Whether this interaction should form new memories or not (will be enabled in a future release)`, isReadOnly: true, isRequired: true, }, - remember: { + recall: { type: "boolean", - description: `Whether this interaction should form new memories or not (will be enabled in a future release)`, - isReadOnly: true, + description: `Whether previous memories and docs should be recalled or not`, isRequired: true, }, save: { diff --git a/sdks/ts/src/api/schemas/$Chat_CompetionUsage.ts b/sdks/ts/src/api/schemas/$Chat_CompetionUsage.ts index 556114810..d8f34cb14 100644 --- a/sdks/ts/src/api/schemas/$Chat_CompetionUsage.ts +++ b/sdks/ts/src/api/schemas/$Chat_CompetionUsage.ts @@ -9,21 +9,18 @@ export const $Chat_CompetionUsage = { type: "number", description: `Number of tokens in the generated completion`, isReadOnly: true, - isRequired: true, format: "uint32", }, prompt_tokens: { type: "number", description: `Number of tokens in the prompt`, isReadOnly: true, - isRequired: true, format: "uint32", }, total_tokens: { type: "number", description: `Total number of tokens used in the request (prompt + completion)`, isReadOnly: true, - isRequired: true, format: "uint32", }, }, diff --git a/typespec/chat/models.tsp b/typespec/chat/models.tsp index da7c170ca..f52dae04c 100644 --- a/typespec/chat/models.tsp +++ b/typespec/chat/models.tsp @@ -34,14 +34,13 @@ enum FinishReason { /** Determines how the session accesses history and memories */ model MemoryAccessOptions { - /** Whether previous memories should be recalled or not (will be enabled in a future release) */ - @visibility("read") // DISABLED - recall: boolean = false; - - /** Whether this interaction should form new memories or not (will be enabled in a future release) */ + /** DISABLED: Whether this interaction should form new memories or not (will be enabled in a future release) */ @visibility("read") // DISABLED remember: boolean = false; + /** Whether previous memories and docs should be recalled or not */ + recall: boolean = true; + /** Whether this interaction should be stored in the session history or not */ save: boolean = true; } @@ -134,15 +133,15 @@ model ChatSettings extends DefaultChatSettings { model CompetionUsage { /** Number of tokens in the generated completion */ @visibility("read") - completion_tokens: uint32; + completion_tokens?: uint32; /** Number of tokens in the prompt */ @visibility("read") - prompt_tokens: uint32; + prompt_tokens?: uint32; /** Total number of tokens used in the request (prompt + completion) */ @visibility("read") - total_tokens: uint32; + total_tokens?: uint32; } model ChatInputData { From 31e07286b413f7ab7eef55c2e4d752a1490fea5a Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Tue, 13 Aug 2024 18:06:12 -0400 Subject: [PATCH 3/4] refactor(agents-api): Remove unnecessary stuff Signed-off-by: Diwank Tomer --- .../agents_api/activities/dialog_insights.py | 118 ------------------ .../activities/relationship_summary.py | 102 --------------- .../activities/salient_questions.py | 91 -------------- .../activities/task_steps/__init__.py | 6 +- agents-api/agents_api/activities/types.py | 98 ++------------- .../agents_api/routers/agents/create_agent.py | 1 - .../agents_api/routers/sessions/chat.py | 17 ++- .../agents_api/routers/users/delete_user.py | 1 - agents-api/agents_api/worker/__main__.py | 12 -- .../agents_api/workflows/dialog_insights.py | 21 ---- .../workflows/relationship_summary.py | 20 --- .../agents_api/workflows/salient_questions.py | 20 --- 12 files changed, 27 insertions(+), 480 deletions(-) delete mode 100644 agents-api/agents_api/activities/dialog_insights.py delete mode 100644 agents-api/agents_api/activities/relationship_summary.py delete mode 100644 agents-api/agents_api/activities/salient_questions.py delete mode 100644 agents-api/agents_api/workflows/dialog_insights.py delete mode 100644 agents-api/agents_api/workflows/relationship_summary.py delete mode 100644 agents-api/agents_api/workflows/salient_questions.py diff --git a/agents-api/agents_api/activities/dialog_insights.py b/agents-api/agents_api/activities/dialog_insights.py deleted file mode 100644 index 1d5adec39..000000000 --- a/agents-api/agents_api/activities/dialog_insights.py +++ /dev/null @@ -1,118 +0,0 @@ -from textwrap import dedent -from typing import Callable - -from temporalio import activity - -from agents_api.clients import litellm - -from .types import ChatML, DialogInsightsTaskArgs - - -def make_prompt( - args: DialogInsightsTaskArgs, - max_turns: int = 20, -): - # Unpack - dialog = args.dialog - person1 = args.person1 - person2 = args.person2 - - # Template - template = dedent( - """\ - [[Conversation]] - {dialog_context} - - --- - - Write down if there are any details from the conversation above that {person1} might have found interesting from {person2}'s perspective, in a full sentence. Write down point by point only the most important points. Answer must be in third person. - - Answer: " - """ - ).strip() - - # Filter dialog (keep only user and assistant sections) - dialog = [entry for entry in dialog if entry.role != "system"] - - # Truncate to max_turns - dialog = dialog[-max_turns:] - - # Prepare dialog context - dialog_context = "\n".join( - [ - f'{e.name or ("User" if e.role == "user" else "Assistant")}: {e.content}' - for e in dialog - ] - ) - - prompt = template.format( - dialog_context=dialog_context, - person1=person1, - person2=person2, - ) - - return prompt - - -async def run_prompt( - dialog: list[ChatML], - person1: str, - person2: str, - model: str = "gpt-4o", - max_tokens: int = 400, - temperature: float = 0.4, - parser: Callable[[str], str] = lambda x: x, -) -> str: - prompt = make_prompt( - DialogInsightsTaskArgs(dialog=dialog, person1=person1, person2=person2) - ) - - response = await litellm.acompletion( - model=model, - messages=[ - { - "content": prompt, - "role": "user", - } - ], - max_tokens=max_tokens, - temperature=temperature, - stop=["<", "<|"], - stream=False, - ) - - content = response.choices[0].message.content - - return parser(content.strip() if content is not None else "") - - -@activity.defn -async def dialog_insights(dialog: list[ChatML], person1: str, person2: str) -> None: - # session_id = UUID(session_id) - # entries = [ - # Entry(**row) - # for _, row in client.run( - # get_toplevel_entries_query(session_id=session_id) - # ).iterrows() - # ] - - # assert len(entries) > 0, "no need to summarize on empty entries list" - - await run_prompt(dialog, person1, person2) - - # new_entry = Entry( - # session_id=session_id, - # source="summarizer", - # role="system", - # name="information", - # content=response, - # timestamp=entries[-1].timestamp + 0.01, - # ) - - # client.run( - # entries_summarization_query( - # session_id=session_id, - # new_entry=new_entry, - # old_entry_ids=[e.id for e in entries], - # ) - # ) diff --git a/agents-api/agents_api/activities/relationship_summary.py b/agents-api/agents_api/activities/relationship_summary.py deleted file mode 100644 index 997eaf40a..000000000 --- a/agents-api/agents_api/activities/relationship_summary.py +++ /dev/null @@ -1,102 +0,0 @@ -from textwrap import dedent -from typing import Callable - -from temporalio import activity - -from ..clients import litellm -from .types import RelationshipSummaryTaskArgs - - -def make_prompt(args: RelationshipSummaryTaskArgs): - # Unpack - statements = args.statements - person1 = args.person1 - person2 = args.person2 - - # Template - template = dedent( - """\ - Statements: - - {statements_joined} - - Based on the statements above, summarize {person1} and {person2}'s relationship in a 2-3 sentences. What do they feel or know about each other? - - Answer: " - """ - ).strip() - - prompt = template.format( - statements_joined="\n- ".join(statements), - person1=person1, - person2=person2, - ) - - return prompt - - -async def run_prompt( - statements: list[str], - person1: str, - person2: str, - model: str = "gpt-4o", - max_tokens: int = 400, - temperature: float = 0.6, - parser: Callable[[str], str] = lambda x: x, -) -> str: - prompt = make_prompt( - RelationshipSummaryTaskArgs( - statements=statements, person1=person1, person2=person2 - ) - ) - - response = await litellm.acompletion( - model=model, - messages=[ - { - "content": prompt, - "role": "user", - } - ], - max_tokens=max_tokens, - temperature=temperature, - stop=["<", "<|"], - stream=False, - ) - - content = response.choices[0].message.content - - return parser(content.strip() if content is not None else "") - - -@activity.defn -async def relationship_summary( - statements: list[str], person1: str, person2: str -) -> None: - # session_id = UUID(session_id) - # entries = [ - # Entry(**row) - # for _, row in client.run( - # get_toplevel_entries_query(session_id=session_id) - # ).iterrows() - # ] - - # assert len(entries) > 0, "no need to summarize on empty entries list" - - await run_prompt(statements=statements, person1=person1, person2=person2) - - # new_entry = Entry( - # session_id=session_id, - # source="summarizer", - # role="system", - # name="information", - # content=response, - # timestamp=entries[-1].timestamp + 0.01, - # ) - - # client.run( - # entries_summarization_query( - # session_id=session_id, - # new_entry=new_entry, - # old_entry_ids=[e.id for e in entries], - # ) - # ) diff --git a/agents-api/agents_api/activities/salient_questions.py b/agents-api/agents_api/activities/salient_questions.py deleted file mode 100644 index 0194e8c72..000000000 --- a/agents-api/agents_api/activities/salient_questions.py +++ /dev/null @@ -1,91 +0,0 @@ -from textwrap import dedent -from typing import Callable - -from temporalio import activity - -from ..clients import litellm -from .types import SalientQuestionsTaskArgs - - -def make_prompt(args: SalientQuestionsTaskArgs): - # Unpack - statements = args.statements - num = args.num - - # Template - template = dedent( - """\ - Statements: - - {statements_joined} - - Given only the information above, what are the {num} most salient high-level questions we can answer about the subjects grounded in the statements? - - """ - ).strip() - - prompt = template.format( - statements_joined="\n- ".join(statements), - num=num, - ) - - return prompt - - -async def run_prompt( - statements: list[str], - num: int = 3, - model: str = "gpt-4o", - max_tokens: int = 400, - temperature: float = 0.6, - parser: Callable[[str], str] = lambda x: x, -) -> str: - prompt = make_prompt(SalientQuestionsTaskArgs(statements=statements, num=num)) - - response = await litellm.acompletion( - model=model, - messages=[ - { - "content": prompt, - "role": "user", - } - ], - max_tokens=max_tokens, - temperature=temperature, - stop=["<", "<|"], - stream=False, - ) - - content = response.choices[0].message.content - - return parser(content.strip() if content is not None else "") - - -@activity.defn -async def salient_questions(statements: list[str], num: int = 3) -> None: - # session_id = UUID(session_id) - # entries = [ - # Entry(**row) - # for _, row in client.run( - # get_toplevel_entries_query(session_id=session_id) - # ).iterrows() - # ] - - # assert len(entries) > 0, "no need to summarize on empty entries list" - - await run_prompt(statements=statements, num=num) - - # new_entry = Entry( - # session_id=session_id, - # source="summarizer", - # role="system", - # name="information", - # content=response, - # timestamp=entries[-1].timestamp + 0.01, - # ) - - # client.run( - # entries_summarization_query( - # session_id=session_id, - # new_entry=new_entry, - # old_entry_ids=[e.id for e in entries], - # ) - # ) diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index a9818d515..494226a5b 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -14,7 +14,9 @@ UpdateExecutionRequest, YieldStep, ) -from ...clients.litellm import acompletion +from ...clients import ( + litellm, # We dont directly import `acompletion` so we can mock it +) from ...clients.worker.types import ChatML from ...common.protocol.tasks import ( StepContext, @@ -57,7 +59,7 @@ async def prompt_step(context: StepContext) -> dict: settings: dict = context.definition.settings.model_dump() # Get settings and run llm - response = await acompletion( + response = await litellm.acompletion( messages=messages, **settings, ) diff --git a/agents-api/agents_api/activities/types.py b/agents-api/agents_api/activities/types.py index 37fd8015d..f550b5c75 100644 --- a/agents-api/agents_api/activities/types.py +++ b/agents-api/agents_api/activities/types.py @@ -1,111 +1,27 @@ -from typing import Any, Callable, Literal, Optional, Protocol, TypedDict +from typing import Literal from uuid import UUID from pydantic import BaseModel +from ..autogen.openapi_model import InputChatMLMessage -class PromptModule(Protocol): - stop: list[str] - temperature: float - parser: Callable[[str], str] - make_prompt: Callable[..., str] - -class ChatML(BaseModel): - role: Literal["system", "user", "assistant"] - content: str - - name: Optional[str] = None - entry_id: Optional[UUID] = None - - processed: bool = False - parent_id: Optional[UUID] = None - session_id: Optional[UUID] = None - timestamp: Optional[float] = None - token_count: Optional[int] = None - - -class BaseTask(BaseModel): ... - - -class BaseTaskArgs(BaseModel): ... - - -class AddPrinciplesTaskArgs(BaseTaskArgs): - scores: dict[str, Any] - full: bool = False - name: Optional[str] = None - user_id: Optional[UUID] = None - character_id: Optional[UUID] = None - - -class AddPrinciplesTask(BaseTask): - name: Literal["add_principles.v1"] - args: AddPrinciplesTaskArgs - - -class MemoryManagementTaskArgs(BaseTaskArgs): +class MemoryManagementTaskArgs(BaseModel): session_id: UUID model: str - dialog: list[ChatML] + dialog: list[InputChatMLMessage] previous_memories: list[str] = [] -class MemoryManagementTask(BaseTask): +class MemoryManagementTask(BaseModel): name: Literal["memory_management.v1"] args: MemoryManagementTaskArgs -class MemoryDensityTaskArgs(BaseTaskArgs): - memory: str - - -class MemoryDensityTask(BaseTask): - name: Literal["memory_density.v1"] - args: MemoryDensityTaskArgs - - -class MemoryRatingTaskArgs(BaseTaskArgs): +class MemoryRatingTaskArgs(BaseModel): memory: str -class MemoryRatingTask(BaseTask): +class MemoryRatingTask(BaseModel): name: Literal["memory_rating.v1"] args: MemoryRatingTaskArgs - - -class DialogInsightsTaskArgs(BaseTaskArgs): - dialog: list[ChatML] - person1: str - person2: str - - -class DialogInsightsTask(BaseTask): - name: Literal["dialog_insights.v1"] - args: DialogInsightsTaskArgs - - -class RelationshipSummaryTaskArgs(BaseTaskArgs): - statements: list[str] - person1: str - person2: str - - -class RelationshipSummaryTask(BaseTask): - name: Literal["relationship_summary.v1"] - args: RelationshipSummaryTaskArgs - - -class SalientQuestionsTaskArgs(BaseTaskArgs): - statements: list[str] - num: int = 3 - - -class SalientQuestionsTask(BaseTask): - name: Literal["salient_questions.v1"] - args: SalientQuestionsTaskArgs - - -class CombinedTask(TypedDict): - name: str - args: dict[Any, Any] diff --git a/agents-api/agents_api/routers/agents/create_agent.py b/agents-api/agents_api/routers/agents/create_agent.py index 56e2eadf7..d1cac0d6b 100644 --- a/agents-api/agents_api/routers/agents/create_agent.py +++ b/agents-api/agents_api/routers/agents/create_agent.py @@ -19,7 +19,6 @@ async def create_agent( x_developer_id: Annotated[UUID4, Depends(get_developer_id)], data: CreateAgentRequest, ) -> ResourceCreatedResponse: - print("create_agent", x_developer_id, data) agent = models.agent.create_agent( developer_id=x_developer_id, data=data, diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index afe7e3e2d..e6103c15e 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -118,6 +118,11 @@ async def chat( # Get the tools tools = settings.get("tools") or chat_context.get_active_tools() + # Truncate the messages if necessary + if chat_context.session.context_overflow == "truncate": + # messages = messages[-settings["max_tokens"] :] + raise NotImplementedError("Truncation is not yet implemented") + # Get the response from the model model_response = await litellm.acompletion( messages=messages, @@ -129,9 +134,12 @@ async def chat( # Save the input and the response to the session history if input.save: + # TODO: Count the number of tokens before saving it to the session + new_entries = [ CreateEntryRequest(**msg, source="api_request") for msg in new_messages ] + background_tasks.add_task( create_entries, developer_id=developer.id, @@ -140,12 +148,19 @@ async def chat( mark_session_as_updated=True, ) + # Adaptive context handling + jobs = [] + if chat_context.session.context_overflow == "adaptive": + # TODO: Start the adaptive context workflow + # jobs = [await start_adaptive_context_workflow] + raise NotImplementedError("Adaptive context is not yet implemented") + # Return the response chat_response_class = ChunkChatResponse if input.stream else MessageChatResponse chat_response: ChatResponse = chat_response_class( id=uuid4(), created_at=utcnow(), - jobs=[], + jobs=jobs, docs=doc_references, usage=model_response.usage.model_dump(), choices=[choice.model_dump() for choice in model_response.choices], diff --git a/agents-api/agents_api/routers/users/delete_user.py b/agents-api/agents_api/routers/users/delete_user.py index 3a63e42e9..fd1d02a94 100644 --- a/agents-api/agents_api/routers/users/delete_user.py +++ b/agents-api/agents_api/routers/users/delete_user.py @@ -14,5 +14,4 @@ async def delete_user( user_id: UUID4, x_developer_id: Annotated[UUID4, Depends(get_developer_id)] ) -> ResourceDeletedResponse: - print(user_id) return delete_user_query(developer_id=x_developer_id, user_id=user_id) diff --git a/agents-api/agents_api/worker/__main__.py b/agents-api/agents_api/worker/__main__.py index 544a28b4d..b84ed7992 100644 --- a/agents-api/agents_api/worker/__main__.py +++ b/agents-api/agents_api/worker/__main__.py @@ -11,12 +11,9 @@ from temporalio.worker import Worker from ..activities.co_density import co_density -from ..activities.dialog_insights import dialog_insights from ..activities.embed_docs import embed_docs from ..activities.mem_mgmt import mem_mgmt from ..activities.mem_rating import mem_rating -from ..activities.relationship_summary import relationship_summary -from ..activities.salient_questions import salient_questions from ..activities.summarization import summarization from ..activities.task_steps import ( evaluate_step, @@ -35,12 +32,9 @@ temporal_task_queue, ) from ..workflows.co_density import CoDensityWorkflow -from ..workflows.dialog_insights import DialogInsightsWorkflow from ..workflows.embed_docs import EmbedDocsWorkflow from ..workflows.mem_mgmt import MemMgmtWorkflow from ..workflows.mem_rating import MemRatingWorkflow -from ..workflows.relationship_summary import RelationshipSummaryWorkflow -from ..workflows.salient_questions import SalientQuestionsWorkflow from ..workflows.summarization import SummarizationWorkflow from ..workflows.task_execution import TaskExecutionWorkflow from ..workflows.truncation import TruncationWorkflow @@ -88,11 +82,8 @@ async def main(): workflows=[ SummarizationWorkflow, CoDensityWorkflow, - DialogInsightsWorkflow, MemMgmtWorkflow, MemRatingWorkflow, - RelationshipSummaryWorkflow, - SalientQuestionsWorkflow, EmbedDocsWorkflow, TaskExecutionWorkflow, TruncationWorkflow, @@ -101,11 +92,8 @@ async def main(): *task_activities, summarization, co_density, - dialog_insights, mem_mgmt, mem_rating, - relationship_summary, - salient_questions, embed_docs, truncation, ], diff --git a/agents-api/agents_api/workflows/dialog_insights.py b/agents-api/agents_api/workflows/dialog_insights.py deleted file mode 100644 index d7e40395e..000000000 --- a/agents-api/agents_api/workflows/dialog_insights.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python3 - - -from datetime import timedelta - -from temporalio import workflow - -with workflow.unsafe.imports_passed_through(): - from ..activities.dialog_insights import dialog_insights - from ..activities.types import ChatML - - -@workflow.defn -class DialogInsightsWorkflow: - @workflow.run - async def run(self, dialog: list[ChatML], person1: str, person2: str) -> None: - return await workflow.execute_activity( - dialog_insights, - [dialog, person1, person2], - schedule_to_close_timeout=timedelta(seconds=600), - ) diff --git a/agents-api/agents_api/workflows/relationship_summary.py b/agents-api/agents_api/workflows/relationship_summary.py deleted file mode 100644 index 0f2e5fb07..000000000 --- a/agents-api/agents_api/workflows/relationship_summary.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python3 - - -from datetime import timedelta - -from temporalio import workflow - -with workflow.unsafe.imports_passed_through(): - from ..activities.relationship_summary import relationship_summary - - -@workflow.defn -class RelationshipSummaryWorkflow: - @workflow.run - async def run(self, statements: list[str], person1: str, person2: str) -> None: - return await workflow.execute_activity( - relationship_summary, - [statements, person1, person2], - schedule_to_close_timeout=timedelta(seconds=600), - ) diff --git a/agents-api/agents_api/workflows/salient_questions.py b/agents-api/agents_api/workflows/salient_questions.py deleted file mode 100644 index 59f30dc37..000000000 --- a/agents-api/agents_api/workflows/salient_questions.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python3 - - -from datetime import timedelta - -from temporalio import workflow - -with workflow.unsafe.imports_passed_through(): - from ..activities.salient_questions import salient_questions - - -@workflow.defn -class SalientQuestionsWorkflow: - @workflow.run - async def run(self, statements: list[str], num: int = 3) -> None: - return await workflow.execute_activity( - salient_questions, - [statements, num], - schedule_to_close_timeout=timedelta(seconds=600), - ) From dbb1b323b6af8ea8ef9e5306835101d5749319d0 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Wed, 14 Aug 2024 12:29:24 -0400 Subject: [PATCH 4/4] feat(agents-api): Add fixtures for testing workflows Signed-off-by: Diwank Tomer --- .../agents_api/activities/co_density.py | 115 -------- .../agents_api/activities/embed_docs.py | 18 +- agents-api/agents_api/activities/mem_mgmt.py | 12 +- .../agents_api/activities/summarization.py | 250 ++++-------------- .../activities/task_steps/__init__.py | 5 +- .../agents_api/activities/truncation.py | 63 ++--- agents-api/agents_api/clients/temporal.py | 36 ++- agents-api/agents_api/clients/worker/types.py | 87 +----- .../agents_api/common/protocol/entries.py | 52 ++-- .../agents_api/models/docs/embed_snippets.py | 2 +- agents-api/agents_api/rec_sum/entities.py | 2 +- agents-api/agents_api/rec_sum/generate.py | 6 +- agents-api/agents_api/rec_sum/summarize.py | 2 +- agents-api/agents_api/rec_sum/trim.py | 2 +- agents-api/agents_api/worker/__main__.py | 85 +----- agents-api/agents_api/worker/worker.py | 68 +++++ agents-api/agents_api/workflows/co_density.py | 20 -- agents-api/agents_api/workflows/mem_mgmt.py | 7 +- agents-api/tests/fixtures.py | 27 ++ agents-api/tests/test_activities.py | 52 +++- 20 files changed, 333 insertions(+), 578 deletions(-) delete mode 100644 agents-api/agents_api/activities/co_density.py create mode 100644 agents-api/agents_api/worker/worker.py delete mode 100644 agents-api/agents_api/workflows/co_density.py diff --git a/agents-api/agents_api/activities/co_density.py b/agents-api/agents_api/activities/co_density.py deleted file mode 100644 index 408cc398a..000000000 --- a/agents-api/agents_api/activities/co_density.py +++ /dev/null @@ -1,115 +0,0 @@ -from textwrap import dedent -from typing import Callable - -from temporalio import activity - -from agents_api.clients import litellm - -from .types import MemoryDensityTaskArgs - - -def make_prompt(args: MemoryDensityTaskArgs): - # Unpack - memory = args.memory - - # Template - template = dedent( - """\ - [[Memory from a Dialog]] - {memory} - - [[Instruction]] - You will generate increasingly concise, entity-dense summaries of the above Memory. - - Repeat the following 2 steps 5 times. - - Step 1: Identify 1-3 informative Entities (";" delimited) from the Memory which are missing from the previously generated summary. - Step 2: Write a new, denser summary of identical length which covers every entity and detail from the previous summary plus the Missing Entities. - - A Missing Entity is: - - Relevant: to the main story. - - Specific: descriptive yet concise (5 words or fewer). - - Novel: not in the previous summary. - - Faithful: present in the Memory. - - Anywhere: located anywhere in the Memory. - - Guidelines: - - The first summary should be long (4-5 sentences, ~80 words) yet highly non-specific, containing little information beyond the entities marked as missing. Use overly verbose language and fillers (e.g., "this article discusses") to reach ~80 words. - - Make every word count: rewrite the previous summary to improve flow and make space for additional entities. - - Make space with fusion, compression, and removal of uninformative phrases like "the memory discusses." - - The summaries should become highly dense and concise yet self-contained, e.g., easily understood without the Memory. - - Missing entities can appear anywhere in the new summary. - - Never drop entities from the previous summary. If space cannot be made, add fewer new entities. - - Remember, use the exact same number of words for each summary. - - Answer in JSON. The JSON should be a list (length 5) of dictionaries whose keys are "Missing_Entities", "Denser_Summary" and "Density_Score" (between 1-10, higher is better). - - [[Result]] - ```json - """ - ).strip() - - prompt = template.format(memory=memory) - - return prompt - - -async def run_prompt( - memory: str, - model: str = "gpt-4o", - max_tokens: int = 400, - temperature: float = 0.2, - parser: Callable[[str], str] = lambda x: x, -) -> str: - prompt = make_prompt(MemoryDensityTaskArgs(memory=memory)) - - response = await litellm.acompletion( - model=model, - messages=[ - { - "content": prompt, - "role": "user", - } - ], - max_tokens=max_tokens, - temperature=temperature, - stop=["<", "<|"], - stream=False, - ) - - content = response.choices[0].message.content - - return parser(content.strip() if content is not None else "") - - -@activity.defn -async def co_density(memory: str) -> None: - # session_id = UUID(session_id) - # entries = [ - # Entry(**row) - # for _, row in client.run( - # get_toplevel_entries_query(session_id=session_id) - # ).iterrows() - # ] - - # assert len(entries) > 0, "no need to summarize on empty entries list" - - await run_prompt(memory=memory) - - # new_entry = Entry( - # session_id=session_id, - # source="summarizer", - # role="system", - # name="information", - # content=response, - # timestamp=entries[-1].timestamp + 0.01, - # ) - - # client.run( - # entries_summarization_query( - # session_id=session_id, - # new_entry=new_entry, - # old_entry_ids=[e.id for e in entries], - # ) - # ) diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py index b8e01e65e..da7bb313f 100644 --- a/agents-api/agents_api/activities/embed_docs.py +++ b/agents-api/agents_api/activities/embed_docs.py @@ -1,27 +1,39 @@ -from pydantic import UUID4 +from uuid import UUID + from temporalio import activity from agents_api.clients import embed as embedder +from agents_api.clients.cozo import get_cozo_client from agents_api.models.docs.embed_snippets import embed_snippets as embed_snippets_query snippet_embed_instruction = "Encode this passage for retrieval: " @activity.defn -async def embed_docs(doc_id: UUID4, title: str, content: list[str]) -> None: +async def embed_docs( + developer_id: UUID, + doc_id: UUID, + title: str, + content: list[str], + include_title: bool = True, + cozo_client=None, +) -> None: indices, snippets = list(zip(*enumerate(content))) + embeddings = await embedder.embed( [ { "instruction": snippet_embed_instruction, - "text": title + "\n\n" + snippet, + "text": (title + "\n\n" + snippet) if include_title else snippet, } for snippet in snippets ] ) embed_snippets_query( + developer_id=developer_id, doc_id=doc_id, snippet_indices=indices, embeddings=embeddings, + client=cozo_client or get_cozo_client(), ) diff --git a/agents-api/agents_api/activities/mem_mgmt.py b/agents-api/agents_api/activities/mem_mgmt.py index f368a0b0b..ea4bb84d2 100644 --- a/agents-api/agents_api/activities/mem_mgmt.py +++ b/agents-api/agents_api/activities/mem_mgmt.py @@ -4,9 +4,9 @@ from temporalio import activity -from agents_api.clients import litellm - -from .types import ChatML, MemoryManagementTaskArgs +from ..autogen.openapi_model import InputChatMLMessage +from ..clients import litellm +from .types import MemoryManagementTaskArgs example_previous_memory = """ Speaker 1: Composes and listens to music. Likes to buy basketball shoes but doesn't wear them often. @@ -118,7 +118,7 @@ def make_prompt( async def run_prompt( - dialog: list[ChatML], + dialog: list[InputChatMLMessage], session_id: UUID, previous_memories: list[str] = [], model: str = "gpt-4o", @@ -156,7 +156,9 @@ async def run_prompt( @activity.defn async def mem_mgmt( - dialog: list[ChatML], session_id: UUID, previous_memories: list[str] = [] + dialog: list[InputChatMLMessage], + session_id: UUID, + previous_memories: list[str] = [], ) -> None: # session_id = UUID(session_id) # entries = [ diff --git a/agents-api/agents_api/activities/summarization.py b/agents-api/agents_api/activities/summarization.py index dc365380d..581dcdb00 100644 --- a/agents-api/agents_api/activities/summarization.py +++ b/agents-api/agents_api/activities/summarization.py @@ -8,8 +8,7 @@ import pandas as pd from temporalio import activity -from agents_api.common.protocol.entries import Entry - +# from agents_api.common.protocol.entries import Entry # from agents_api.models.entry.entries_summarization import ( # entries_summarization_query, # get_toplevel_entries_query, @@ -18,7 +17,6 @@ from agents_api.rec_sum.summarize import summarize_messages from agents_api.rec_sum.trim import trim_messages -from ..clients import litellm from ..env import summarization_model_name @@ -31,196 +29,60 @@ def get_toplevel_entries_query(*args, **kwargs): return pd.DataFrame() -# - -example_previous_memory = """ -Speaker 1: Composes and listens to music. Likes to buy basketball shoes but doesn't wear them often. -""".strip() - -example_dialog_context = """ -Speaker 1: Did you find a place to donate your shoes? -Speaker 2: I did! I was driving to the grocery store the other day, when I noticed a bin labeled "Donation for Shoes and Clothing." It was easier than I thought! How about you? Why do you have so many pairs of sandals? -Speaker 1: I don't understand myself! When I look them online I just have the urge to buy them, even when I know I don't need them. This addiction is getting worse and worse. -Speaker 2: I completely agree that buying shoes can become an addiction! Are there any ways you can make money from home while waiting for a job offer from a call center? -Speaker 1: Well I already got the job so I just need to learn using the software. When I was still searching for jobs, we actually do a yard sale to sell many of my random items that are never used and clearly aren't needed either. -Speaker 2: Congratulations on getting the job! I know it'll help you out so much. And of course, maybe I should turn to yard sales as well, for they can be a great way to make some extra cash! -Speaker 1: Do you have another job or do you compose music for a living? How does your shopping addiction go? -Speaker 2: As a matter of fact, I do have another job in addition to composing music. I'm actually a music teacher at a private school, and on the side, I compose music for friends and family. As far as my shopping addiction goes, it's getting better. I promised myself that I wouldn't buy myself any more shoes this year! -Speaker 1: Ah, I remember the time I promised myself the same thing on not buying random things anymore, never work so far. Good luck with yours! -Speaker 2: Thanks! I need the good luck wishes. I've been avoiding malls and shopping outlets. Maybe you can try the same! -Speaker 1: I can avoid them physically, but with my job enable me sitting in front of my computer for a long period of time, I already turn the shopping addiction into online-shopping addiction. lol. Wish me luck! -Speaker 2: Sure thing! You know, and speaking of spending time before a computer, I need to look up information about Precious Moments figurines. I'd still like to know what they are! -""".strip() - -example_updated_memory = """ -Speaker 1: -- Enjoys composing and listening to music. -- Recently got a job that requires the use of specialized software. -- Displays a shopping addiction, particularly for shoes, that has transitioned to online-shopping due to job nature. -- Previously attempted to mitigate shopping addiction without success. -- Had organized a yard sale to sell unused items when job searching. - -Speaker 2: -- Also enjoys buying shoes and admits to it being addictive. -- Works as a music teacher at a private school in addition to composing music. -- Takes active measures to control his shopping addiction, including avoiding malls. -- Is interested in Precious Moments figurines. -""".strip() - - -def make_prompt( - dialog: list[Entry], - previous_memories: list[str], - max_turns: int = 10, - num_sentences: int = 10, -): - # Template - template = dedent( - """\ - **Instructions** - You are an advanced AI language model with the ability to store and update a memory to keep track of key personality information for people. You will receive a memory and a dialogue between two people. - - Your goal is to update the memory by incorporating the new personality information for both participants while ensuring that the memory does not exceed {num_sentences} sentences. - - To successfully update the memory, follow these steps: - - 1. Carefully analyze the existing memory and extract the key personality information of the participants from it. - 2. Consider the dialogue provided to identify any new or changed personality traits of either participant that need to be incorporated into the memory. - 3. Combine the old and new personality information to create an updated representation of the participants' traits. - 4. Structure the updated memory in a clear and concise manner, ensuring that it does not exceed {num_sentences} sentences. - 5. Pay attention to the relevance and importance of the personality information, focusing on capturing the most significant aspects while maintaining the overall coherence of the memory. - - Remember, the memory should serve as a reference point to maintain continuity in the dialogue and help accurately set context in future conversations based on the personality traits of the participants. - - **Test Example** - [[Previous Memory]] - {example_previous_memory} - - [[Dialogue Context]] - {example_dialog_context} - - [[Updated Memory]] - {example_updated_memory} - - **Actual Run** - [[Previous Memory]] - {previous_memory} - - [[Dialogue Context]] - {dialog_context} - - [[Updated Memory]] - """ - ).strip() - - # Filter dialog (keep only user and assistant sections) - dialog = [entry for entry in dialog if entry.role != "system"] - - # Truncate to max_turns - dialog = dialog[-max_turns:] - - # Prepare dialog context - dialog_context = "\n".join( - [ - f'{e.name or ("User" if e.role == "user" else "Assistant")}: {e.content}' - for e in dialog - ] - ) - - prompt = template.format( - dialog_context=dialog_context, - previous_memory="\n".join(previous_memories), - num_sentences=num_sentences, - example_dialog_context=example_dialog_context, - example_previous_memory=example_previous_memory, - example_updated_memory=example_updated_memory, - ) - - return prompt - - -async def run_prompt( - dialog: list[Entry], - previous_memories: list[str], - model: str = "gpt-4o", - max_tokens: int = 400, - temperature: float = 0.1, - parser: Callable[[str], str] = lambda x: x, - **kwargs, -) -> str: - prompt = make_prompt(dialog, previous_memories, **kwargs) - response = await litellm.acompletion( - model=model, - messages=[ - { - "content": prompt, - "role": "user", - } - ], - max_tokens=max_tokens, - temperature=temperature, - stop=["<", "<|"], - stream=False, - ) - - content = response.choices[0].message.content - - return parser(content.strip() if content is not None else "") - - @activity.defn async def summarization(session_id: str) -> None: - session_id = UUID(session_id) - entries = [] - entities_entry_ids = [] - for _, row in get_toplevel_entries_query(session_id=session_id).iterrows(): - if row["role"] == "system" and row.get("name") == "entities": - entities_entry_ids.append(UUID(row["entry_id"], version=4)) - else: - entries.append(row) - - assert len(entries) > 0, "no need to summarize on empty entries list" - - summarized, entities = await asyncio.gather( - summarize_messages(entries, model=summarization_model_name), - get_entities(entries, model=summarization_model_name), - ) - trimmed_messages = await trim_messages(summarized, model=summarization_model_name) - ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2 - new_entities_entry = Entry( - session_id=session_id, - source="summarizer", - role="system", - name="entities", - content=entities["content"], - timestamp=entries[0]["timestamp"] + ts_delta, - ) - - entries_summarization_query( - session_id=session_id, - new_entry=new_entities_entry, - old_entry_ids=entities_entry_ids, - ) - - trimmed_map = { - m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None - } - - for idx, msg in enumerate(summarized): - new_entry = Entry( - session_id=session_id, - source="summarizer", - role="system", - name="information", - content=trimmed_map.get(idx, msg["content"]), - timestamp=entries[-1]["timestamp"] + 0.01, - ) - - entries_summarization_query( - session_id=session_id, - new_entry=new_entry, - old_entry_ids=[ - UUID(entries[idx - 1]["entry_id"], version=4) - for idx in msg["summarizes"] - ], - ) + raise NotImplementedError() + # session_id = UUID(session_id) + # entries = [] + # entities_entry_ids = [] + # for _, row in get_toplevel_entries_query(session_id=session_id).iterrows(): + # if row["role"] == "system" and row.get("name") == "entities": + # entities_entry_ids.append(UUID(row["entry_id"], version=4)) + # else: + # entries.append(row) + + # assert len(entries) > 0, "no need to summarize on empty entries list" + + # summarized, entities = await asyncio.gather( + # summarize_messages(entries, model=summarization_model_name), + # get_entities(entries, model=summarization_model_name), + # ) + # trimmed_messages = await trim_messages(summarized, model=summarization_model_name) + # ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2 + # new_entities_entry = Entry( + # session_id=session_id, + # source="summarizer", + # role="system", + # name="entities", + # content=entities["content"], + # timestamp=entries[0]["timestamp"] + ts_delta, + # ) + + # entries_summarization_query( + # session_id=session_id, + # new_entry=new_entities_entry, + # old_entry_ids=entities_entry_ids, + # ) + + # trimmed_map = { + # m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None + # } + + # for idx, msg in enumerate(summarized): + # new_entry = Entry( + # session_id=session_id, + # source="summarizer", + # role="system", + # name="information", + # content=trimmed_map.get(idx, msg["content"]), + # timestamp=entries[-1]["timestamp"] + 0.01, + # ) + + # entries_summarization_query( + # session_id=session_id, + # new_entry=new_entry, + # old_entry_ids=[ + # UUID(entries[idx - 1]["entry_id"], version=4) + # for idx in msg["summarizes"] + # ], + # ) diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index 494226a5b..13f1adcfe 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -17,7 +17,6 @@ from ...clients import ( litellm, # We dont directly import `acompletion` so we can mock it ) -from ...clients.worker.types import ChatML from ...common.protocol.tasks import ( StepContext, TransitionInfo, @@ -53,7 +52,9 @@ async def prompt_step(context: StepContext) -> dict: ) messages = [ - ChatML(role="user", content=m) if isinstance(m, str) else ChatML(**m) + InputChatMLMessage(role="user", content=m) + if isinstance(m, str) + else InputChatMLMessage(**m) for m in messages ] diff --git a/agents-api/agents_api/activities/truncation.py b/agents-api/agents_api/activities/truncation.py index 190190a79..353e4b570 100644 --- a/agents-api/agents_api/activities/truncation.py +++ b/agents-api/agents_api/activities/truncation.py @@ -2,10 +2,11 @@ from temporalio import activity -from agents_api.autogen.openapi_model import Role +# from agents_api.autogen.openapi_model import Role from agents_api.common.protocol.entries import Entry from agents_api.models.entry.delete_entries import delete_entries -from agents_api.models.entry.entries_summarization import get_toplevel_entries_query + +# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[UUID]: @@ -14,40 +15,40 @@ def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list result: list[UUID] = [] token_cnt, offset = 0, 0 - if messages[0].role == Role.system: - token_cnt, offset = messages[0].token_count, 1 + # if messages[0].role == Role.system: + # token_cnt, offset = messages[0].token_count, 1 - for m in reversed(messages[offset:]): - token_cnt += m.token_count - if token_cnt < token_count_threshold: - continue - else: - result.append(m.id) + # for m in reversed(messages[offset:]): + # token_cnt += m.token_count + # if token_cnt < token_count_threshold: + # continue + # else: + # result.append(m.id) - return result + # return result @activity.defn async def truncation(session_id: str, token_count_threshold: int) -> None: session_id = UUID(session_id) - delete_entries( - get_extra_entries( - [ - Entry( - entry_id=row["entry_id"], - session_id=session_id, - source=row["source"], - role=Role(row["role"]), - name=row["name"], - content=row["content"], - created_at=row["created_at"], - timestamp=row["timestamp"], - ) - for _, row in get_toplevel_entries_query( - session_id=session_id - ).iterrows() - ], - token_count_threshold, - ), - ) + # delete_entries( + # get_extra_entries( + # [ + # Entry( + # entry_id=row["entry_id"], + # session_id=session_id, + # source=row["source"], + # role=Role(row["role"]), + # name=row["name"], + # content=row["content"], + # created_at=row["created_at"], + # timestamp=row["timestamp"], + # ) + # for _, row in get_toplevel_entries_query( + # session_id=session_id + # ).iterrows() + # ], + # token_count_threshold, + # ), + # ) diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index 7e45b50d7..72a5056c8 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -13,7 +13,11 @@ from ..worker.codec import pydantic_data_converter -async def get_client(): +async def get_client( + worker_url: str = temporal_worker_url, + namespace: str = temporal_namespace, + data_converter=pydantic_data_converter, +): tls_config = False if temporal_private_key and temporal_client_cert: @@ -23,15 +27,17 @@ async def get_client(): ) return await Client.connect( - temporal_worker_url, - namespace=temporal_namespace, + worker_url, + namespace=namespace, tls=tls_config, - data_converter=pydantic_data_converter, + data_converter=data_converter, ) -async def run_summarization_task(session_id: UUID, job_id: UUID): - client = await get_client() +async def run_summarization_task( + session_id: UUID, job_id: UUID, client: Client | None = None +): + client = client or (await get_client()) await client.execute_workflow( "SummarizationWorkflow", @@ -42,9 +48,13 @@ async def run_summarization_task(session_id: UUID, job_id: UUID): async def run_embed_docs_task( - doc_id: UUID, title: str, content: list[str], job_id: UUID + doc_id: UUID, + title: str, + content: list[str], + job_id: UUID, + client: Client | None = None, ): - client = await get_client() + client = client or (await get_client()) await client.execute_workflow( "EmbedDocsWorkflow", @@ -55,9 +65,12 @@ async def run_embed_docs_task( async def run_truncation_task( - token_count_threshold: int, session_id: UUID, job_id: UUID + token_count_threshold: int, + session_id: UUID, + job_id: UUID, + client: Client | None = None, ): - client = await get_client() + client = client or (await get_client()) await client.execute_workflow( "TruncationWorkflow", @@ -72,8 +85,9 @@ async def run_task_execution_workflow( job_id: UUID, start: tuple[str, int] = ("main", 0), previous_inputs: list[dict] = [], + client: Client | None = None, ): - client = await get_client() + client = client or (await get_client()) return await client.start_workflow( "TaskExecutionWorkflow", diff --git a/agents-api/agents_api/clients/worker/types.py b/agents-api/agents_api/clients/worker/types.py index 02b6add6c..3bf063083 100644 --- a/agents-api/agents_api/clients/worker/types.py +++ b/agents-api/agents_api/clients/worker/types.py @@ -1,108 +1,41 @@ -from typing import Callable, Literal, Optional, Protocol +from typing import Literal from uuid import UUID from pydantic import BaseModel from agents_api.autogen.openapi_model import ( - ChatMLImageContentPart, - ChatMLTextContentPart, + InputChatMLMessage, ) -class PromptModule(Protocol): - stop: list[str] - temperature: float - parser: Callable[[str], str] - make_prompt: Callable[..., str] - - -class ChatML(BaseModel): - role: Literal["system", "user", "assistant", "function_call"] - content: str | dict | list[ChatMLTextContentPart] | list[ChatMLImageContentPart] - - name: Optional[str] = None - entry_id: Optional[UUID] = None - - processed: bool = False - parent_id: Optional[UUID] = None - session_id: Optional[UUID] = None - timestamp: Optional[float] = None - token_count: Optional[int] = None - - -class BaseTask(BaseModel): ... - - -class BaseTaskArgs(BaseModel): ... - - -class MemoryManagementTaskArgs(BaseTaskArgs): +class MemoryManagementTaskArgs(BaseModel): session_id: UUID model: str - dialog: list[ChatML] + dialog: list[InputChatMLMessage] previous_memories: list[str] = [] -class MemoryManagementTask(BaseTask): +class MemoryManagementTask(BaseModel): name: Literal["memory_management.v1"] args: MemoryManagementTaskArgs -class MemoryDensityTaskArgs(BaseTaskArgs): +class MemoryDensityTaskArgs(BaseModel): memory: str -class MemoryDensityTask(BaseTask): +class MemoryDensityTask(BaseModel): name: Literal["memory_density.v1"] args: MemoryDensityTaskArgs -class MemoryRatingTaskArgs(BaseTaskArgs): +class MemoryRatingTaskArgs(BaseModel): memory: str -class MemoryRatingTask(BaseTask): +class MemoryRatingTask(BaseModel): name: Literal["memory_rating.v1"] args: MemoryRatingTaskArgs -class DialogInsightsTaskArgs(BaseTaskArgs): - dialog: list[ChatML] - person1: str - person2: str - - -class DialogInsightsTask(BaseTask): - name: Literal["dialog_insights.v1"] - args: DialogInsightsTaskArgs - - -class RelationshipSummaryTaskArgs(BaseTaskArgs): - statements: list[str] - person1: str - person2: str - - -class RelationshipSummaryTask(BaseTask): - name: Literal["relationship_summary.v1"] - args: RelationshipSummaryTaskArgs - - -class SalientQuestionsTaskArgs(BaseTaskArgs): - statements: list[str] - num: int = 3 - - -class SalientQuestionsTask(BaseTask): - name: Literal["salient_questions.v1"] - args: SalientQuestionsTaskArgs - - -CombinedTask = ( - MemoryManagementTask - | MemoryDensityTask - | MemoryRatingTask - | DialogInsightsTask - | RelationshipSummaryTask - | SalientQuestionsTask -) +CombinedTask = MemoryManagementTask | MemoryDensityTask | MemoryRatingTask diff --git a/agents-api/agents_api/common/protocol/entries.py b/agents-api/agents_api/common/protocol/entries.py index 6ef7f70f2..18d63f583 100644 --- a/agents-api/agents_api/common/protocol/entries.py +++ b/agents-api/agents_api/common/protocol/entries.py @@ -27,28 +27,30 @@ class Entry(BaseEntry): token_count: int tokenizer: str = Field(default="character_count") - @computed_field - @property - def token_count(self) -> int: - """Calculates the token count based on the content's character count. The tokenizer 'character_count' divides the length of the content by 3.5 to estimate the token count. Raises NotImplementedError for unknown tokenizers.""" - if self.tokenizer == "character_count": - content_length = 0 - if isinstance(self.content, str): - content_length = len(self.content) - elif isinstance(self.content, dict): - content_length = len(json.dumps(self.content)) - elif isinstance(self.content, list): - for part in self.content: - if isinstance(part, ChatMLTextContentPart): - content_length += len(part.text) - elif isinstance(part, ChatMLImageContentPart): - content_length += ( - LOW_IMAGE_TOKEN_COUNT - if part.image_url.detail == "low" - else HIGH_IMAGE_TOKEN_COUNT - ) - - # Divide the content length by 3.5 to estimate token count based on character count. - return int(content_length // 3.5) - - raise NotImplementedError(f"Unknown tokenizer: {self.tokenizer}") + # TODO: Replace this with a proper implementation. + + # @computed_field + # @property + # def token_count(self) -> int: + # """Calculates the token count based on the content's character count. The tokenizer 'character_count' divides the length of the content by 3.5 to estimate the token count. Raises NotImplementedError for unknown tokenizers.""" + # if self.tokenizer == "character_count": + # content_length = 0 + # if isinstance(self.content, str): + # content_length = len(self.content) + # elif isinstance(self.content, dict): + # content_length = len(json.dumps(self.content)) + # elif isinstance(self.content, list): + # for part in self.content: + # if isinstance(part, ChatMLTextContentPart): + # content_length += len(part.text) + # elif isinstance(part, ChatMLImageContentPart): + # content_length += ( + # LOW_IMAGE_TOKEN_COUNT + # if part.image_url.detail == "low" + # else HIGH_IMAGE_TOKEN_COUNT + # ) + + # # Divide the content length by 3.5 to estimate token count based on character count. + # return int(content_length // 3.5) + + # raise NotImplementedError(f"Unknown tokenizer: {self.tokenizer}") diff --git a/agents-api/agents_api/models/docs/embed_snippets.py b/agents-api/agents_api/models/docs/embed_snippets.py index 5dd4b4457..89750192a 100644 --- a/agents-api/agents_api/models/docs/embed_snippets.py +++ b/agents-api/agents_api/models/docs/embed_snippets.py @@ -37,7 +37,7 @@ def embed_snippets( *, developer_id: UUID, doc_id: UUID, - snippet_indices: list[int] | tuple[int], + snippet_indices: list[int] | tuple[int, ...], embeddings: list[list[float]], embedding_size: int = 1024, ) -> tuple[list[str], dict]: diff --git a/agents-api/agents_api/rec_sum/entities.py b/agents-api/agents_api/rec_sum/entities.py index 11346447c..9992063ff 100644 --- a/agents-api/agents_api/rec_sum/entities.py +++ b/agents-api/agents_api/rec_sum/entities.py @@ -57,7 +57,7 @@ def make_entities_prompt( @retry(stop=stop_after_attempt(2)) async def get_entities( chat_session, - model="gpt-4-turbo", + model="gpt-4o", stop=[" dict: - result = await acompletion( + result = await litellm.acompletion( model=model, messages=messages, **kwargs, diff --git a/agents-api/agents_api/rec_sum/summarize.py b/agents-api/agents_api/rec_sum/summarize.py index 97f39905b..f98f35094 100644 --- a/agents-api/agents_api/rec_sum/summarize.py +++ b/agents-api/agents_api/rec_sum/summarize.py @@ -44,7 +44,7 @@ def make_summarize_prompt(session, user="a user", assistant="gpt-4-turbo", **_): @retry(stop=stop_after_attempt(2)) async def summarize_messages( chat_session, - model="gpt-4-turbo", + model="gpt-4o", stop=[" None: - return await workflow.execute_activity( - co_density, - memory, - schedule_to_close_timeout=timedelta(seconds=600), - ) diff --git a/agents-api/agents_api/workflows/mem_mgmt.py b/agents-api/agents_api/workflows/mem_mgmt.py index 2db9f95da..31c973741 100644 --- a/agents-api/agents_api/workflows/mem_mgmt.py +++ b/agents-api/agents_api/workflows/mem_mgmt.py @@ -7,14 +7,17 @@ with workflow.unsafe.imports_passed_through(): from ..activities.mem_mgmt import mem_mgmt - from ..activities.types import ChatML + from ..autogen.openapi_model import InputChatMLMessage @workflow.defn class MemMgmtWorkflow: @workflow.run async def run( - self, dialog: list[ChatML], session_id: str, previous_memories: list[str] + self, + dialog: list[InputChatMLMessage], + session_id: str, + previous_memories: list[str], ) -> None: return await workflow.execute_activity( mem_mgmt, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 1b3b1000a..d8acec61c 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -6,6 +6,7 @@ from litellm.types.utils import Choices, ModelResponse from pycozo import Client as CozoClient from temporalio.client import WorkflowHandle +from temporalio.testing import ActivityEnvironment, WorkflowEnvironment from ward import fixture from agents_api.autogen.openapi_model import ( @@ -33,6 +34,7 @@ from agents_api.models.user.create_user import create_user from agents_api.models.user.delete_user import delete_user from agents_api.web import app +from agents_api.worker.worker import create_worker EMBEDDING_SIZE: int = 1024 @@ -49,6 +51,31 @@ def cozo_client(migrations_dir: str = "./migrations"): return client +@fixture(scope="test") +def activity_environment(): + return ActivityEnvironment() + + +@fixture(scope="global") +async def workflow_environment(): + wf_env = await WorkflowEnvironment.start_local() + return wf_env + + +@fixture(scope="global") +async def temporal_client(wf_env=workflow_environment): + return wf_env.client + + +@fixture(scope="global") +async def temporal_worker(temporal_client=temporal_client): + worker = await create_worker(client=temporal_client) + + async with worker as running_worker: + yield running_worker + await running_worker.shutdown() + + @fixture(scope="global") def test_developer_id(cozo_client=cozo_client): developer_id = uuid4() diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index 2723911cc..a5e35760b 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -1,13 +1,59 @@ -# import time -# import uuid +from ward import test -# from ward import test +from agents_api.activities.embed_docs import embed_docs + +from .fixtures import ( + cozo_client, + patch_embed_acompletion, + temporal_client, + temporal_worker, + test_developer_id, + test_doc, + workflow_environment, +) # from agents_api.activities.truncation import get_extra_entries # from agents_api.autogen.openapi_model import Role # from agents_api.common.protocol.entries import Entry +@test("activity: embed_docs") +async def _( + cozo_client=cozo_client, + developer_id=test_developer_id, + doc=test_doc, + mocks=patch_embed_acompletion, +): + (embed, _) = mocks + + title = "title" + content = ["content 1"] + include_title = True + + await embed_docs( + developer_id=developer_id, + doc_id=doc.id, + title=title, + content=content, + include_title=include_title, + cozo_client=cozo_client, + ) + + embed.assert_called_once() + + +@test("activity: check that workflow environment and worker are started correctly") +async def _( + workflow_environment=workflow_environment, + worker=temporal_worker, + client=temporal_client, +): + async with workflow_environment as wf_env: + assert wf_env is not None + assert worker is not None + assert worker.is_running + + # @test("get extra entries, do not strip system message") # def _(): # session_ids = [uuid.uuid4()] * 3