From de88a33c64274c14aac62186aa0faafc6c54aeb9 Mon Sep 17 00:00:00 2001 From: HamadaSalhab Date: Fri, 3 Jan 2025 14:19:43 +0300 Subject: [PATCH] fix(agents-api): Fix `session.create` & Add `session.update` system tools --- .../agents_api/activities/execute_system.py | 43 +++++++++++++++++-- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/agents-api/agents_api/activities/execute_system.py b/agents-api/agents_api/activities/execute_system.py index b91c52e76..b1aabaea4 100644 --- a/agents-api/agents_api/activities/execute_system.py +++ b/agents-api/agents_api/activities/execute_system.py @@ -16,9 +16,10 @@ HybridDocSearchRequest, SystemDef, TextOnlyDocSearchRequest, + UpdateSessionRequest, VectorDocSearchRequest, ) -from ..common.protocol.tasks import StepContext +from ..common.protocol.tasks import ExecutionInput, StepContext from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote from ..env import testing from ..models.developer import get_developer @@ -40,6 +41,10 @@ async def execute_system( if set(arguments.keys()) == {"bucket", "key"}: arguments = await load_from_blob_store_if_remote(arguments) + if not isinstance(context.execution_input, ExecutionInput): + raise TypeError( + "Expected ExecutionInput type for context.execution_input") + arguments["developer_id"] = context.execution_input.developer_id # Unbox all the arguments @@ -91,7 +96,8 @@ async def execute_system( # Handle chat operations if system.operation == "chat" and system.resource == "session": - developer = get_developer(developer_id=arguments.get("developer_id")) + developer = get_developer( + developer_id=arguments.get("developer_id")) session_id = arguments.get("session_id") x_custom_api_key = arguments.get("x_custom_api_key", None) chat_input = ChatInput(**arguments) @@ -106,10 +112,11 @@ async def execute_system( await bg_runner() return res + # Handle create operations if system.operation == "create" and system.resource == "session": developer_id = arguments.pop("developer_id") session_id = arguments.pop("session_id", None) - data = CreateSessionRequest(**arguments) + create_session_request = CreateSessionRequest(**arguments) # In case sessions.create becomes asynchronous in the future if asyncio.iscoroutinefunction(handler): @@ -118,7 +125,35 @@ async def execute_system( # Run the synchronous function in another process loop = asyncio.get_running_loop() return await loop.run_in_executor( - process_pool_executor, partial(handler, developer_id, session_id, data) + process_pool_executor, + partial( + handler, + developer_id=developer_id, + session_id=session_id, + data=create_session_request, + ), + ) + + # Handle update operations + if system.operation == "update" and system.resource == "session": + developer_id = arguments.pop("developer_id") + session_id = arguments.pop("session_id") + update_session_request = UpdateSessionRequest(**arguments) + + # In case sessions.update becomes asynchronous in the future + if asyncio.iscoroutinefunction(handler): + return await handler() + + # Run the synchronous function in another process + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + process_pool_executor, + partial( + handler, + developer_id=developer_id, + session_id=session_id, + data=update_session_request, + ), ) # Handle regular operations