Skip to content

Commit

Permalink
fix(agents-api): Fix session.create & Add session.update system t…
Browse files Browse the repository at this point in the history
…ools
  • Loading branch information
HamadaSalhab committed Jan 3, 2025
1 parent f8514a2 commit de88a33
Showing 1 changed file with 39 additions and 4 deletions.
43 changes: 39 additions & 4 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
HybridDocSearchRequest,
SystemDef,
TextOnlyDocSearchRequest,
UpdateSessionRequest,
VectorDocSearchRequest,
)
from ..common.protocol.tasks import StepContext
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote
from ..env import testing
from ..models.developer import get_developer
Expand All @@ -40,6 +41,10 @@ async def execute_system(
if set(arguments.keys()) == {"bucket", "key"}:
arguments = await load_from_blob_store_if_remote(arguments)

if not isinstance(context.execution_input, ExecutionInput):
raise TypeError(
"Expected ExecutionInput type for context.execution_input")

arguments["developer_id"] = context.execution_input.developer_id

# Unbox all the arguments
Expand Down Expand Up @@ -91,7 +96,8 @@ async def execute_system(

# Handle chat operations
if system.operation == "chat" and system.resource == "session":
developer = get_developer(developer_id=arguments.get("developer_id"))
developer = get_developer(
developer_id=arguments.get("developer_id"))
session_id = arguments.get("session_id")
x_custom_api_key = arguments.get("x_custom_api_key", None)
chat_input = ChatInput(**arguments)
Expand All @@ -106,10 +112,11 @@ async def execute_system(
await bg_runner()
return res

# Handle create operations
if system.operation == "create" and system.resource == "session":
developer_id = arguments.pop("developer_id")
session_id = arguments.pop("session_id", None)
data = CreateSessionRequest(**arguments)
create_session_request = CreateSessionRequest(**arguments)

# In case sessions.create becomes asynchronous in the future
if asyncio.iscoroutinefunction(handler):
Expand All @@ -118,7 +125,35 @@ async def execute_system(
# Run the synchronous function in another process
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
process_pool_executor, partial(handler, developer_id, session_id, data)
process_pool_executor,
partial(
handler,
developer_id=developer_id,
session_id=session_id,
data=create_session_request,
),
)

# Handle update operations
if system.operation == "update" and system.resource == "session":
developer_id = arguments.pop("developer_id")
session_id = arguments.pop("session_id")
update_session_request = UpdateSessionRequest(**arguments)

# In case sessions.update becomes asynchronous in the future
if asyncio.iscoroutinefunction(handler):
return await handler()

# Run the synchronous function in another process
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
process_pool_executor,
partial(
handler,
developer_id=developer_id,
session_id=session_id,
data=update_session_request,
),
)

# Handle regular operations
Expand Down

0 comments on commit de88a33

Please sign in to comment.