Skip to content

Commit

Permalink
feat(agents-api): Add system tool type
Browse files Browse the repository at this point in the history
  • Loading branch information
HamadaSalhab committed Oct 4, 2024
1 parent 9022023 commit 3822af9
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 22 deletions.
163 changes: 163 additions & 0 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from typing import Any
from uuid import UUID

from beartype import beartype
from temporalio import activity

from ..autogen.Tools import SystemDef
from ..common.protocol.tasks import StepContext
from ..env import testing
from ..models.agent.create_agent import create_agent as create_agent_query
from ..models.agent.delete_agent import delete_agent as delete_agent_query
from ..models.agent.get_agent import get_agent as get_agent_query
from ..models.agent.list_agents import list_agents as list_agents_query
from ..models.agent.update_agent import update_agent as update_agent_query
from ..models.docs.create_doc import create_doc as create_doc_query
from ..models.docs.delete_doc import delete_doc as delete_doc_query
from ..models.docs.get_doc import get_doc as get_doc_query
from ..models.docs.list_docs import list_docs as list_docs_query
from ..models.session.create_session import create_session as create_session_query
from ..models.session.delete_session import delete_session as delete_session_query
from ..models.session.get_session import get_session as get_session_query
from ..models.session.list_sessions import list_sessions as list_sessions_query
from ..models.session.update_session import update_session as update_session_query
from ..models.task.create_task import create_task as create_task_query
from ..models.task.delete_task import delete_task as delete_task_query
from ..models.task.get_task import get_task as get_task_query
from ..models.task.list_tasks import list_tasks as list_tasks_query
from ..models.task.update_task import update_task as update_task_query
from ..models.user.create_user import create_user as create_user_query
from ..models.user.delete_user import delete_user as delete_user_query
from ..models.user.get_user import get_user as get_user_query
from ..models.user.list_users import list_users as list_users_query
from ..models.user.update_user import update_user as update_user_query


@beartype
async def execute_system(
context: StepContext,
system: SystemDef,
) -> Any:
arguments = system.arguments
arguments["developer_id"] = context.execution_input.developer_id

# Convert all UUIDs to UUID objects
if "agent_id" in arguments:
arguments["agent_id"] = UUID(arguments["agent_id"])
if "user_id" in arguments:
arguments["user_id"] = UUID(arguments["user_id"])
if "task_id" in arguments:
arguments["task_id"] = UUID(arguments["task_id"])
if "session_id" in arguments:
arguments["session_id"] = UUID(arguments["session_id"])
if "doc_id" in arguments:
arguments["doc_id"] = UUID(arguments["doc_id"])

# FIXME: This is a total mess. Should be refactored.
try:
# AGENTS
if system.resource == "agent":
# DOCS SUBRESOURCE
if system.subresource == "doc":
# Define the arguments for the agent doc queries
agent_doc_args = {
**{
"owner_type": "agent",
"owner_id": arguments.pop("agent_id"),
},
**arguments,
}
if system.operation == "list":
return list_docs_query(**agent_doc_args)
elif system.operation == "create":
return create_doc_query(**agent_doc_args)
elif system.operation == "delete":
return delete_doc_query(**agent_doc_args)

# NO SUBRESOURCE
elif system.subresource == None:
if system.operation == "list":
return list_agents_query(**arguments)
elif system.operation == "get":
return get_agent_query(**arguments)
elif system.operation == "create":
return create_agent_query(**arguments)
elif system.operation == "update":
return update_agent_query(**arguments)
elif system.operation == "delete":
return delete_agent_query(**arguments)

# USERS
elif system.resource == "user":
# DOCS SUBRESOURCE
if system.subresource == "doc":
# Define the arguments for the user doc queries
user_doc_args = {
**{
"owner_type": "user",
"owner_id": arguments.pop("user_id"),
},
**arguments,
}
if system.operation == "list":
return list_docs_query(**user_doc_args)
elif system.operation == "create":
return create_doc_query(**user_doc_args)
elif system.operation == "delete":
return delete_doc_query(**user_doc_args)

# NO SUBRESOURCE
elif system.subresource == None:
if system.operation == "list":
return list_users_query(**arguments)
elif system.operation == "get":
return get_user_query(**arguments)
elif system.operation == "create":
return create_user_query(**arguments)
elif system.operation == "update":
return update_user_query(**arguments)
elif system.operation == "delete":
return delete_user_query(**arguments)

# SESSIONS
elif system.resource == "session":
if system.operation == "list":
return list_sessions_query(**arguments)
elif system.operation == "get":
return get_session_query(**arguments)
elif system.operation == "create":
return create_session_query(**arguments)
elif system.operation == "update":
return update_session_query(**arguments)
elif system.operation == "delete":
return update_session_query(**arguments)
elif system.operation == "delete":
return delete_session_query(**arguments)
# TASKS
elif system.resource == "task":
if system.operation == "list":
return list_tasks_query(**arguments)
elif system.operation == "get":
return get_task_query(**arguments)
elif system.operation == "create":
return create_task_query(**arguments)
elif system.operation == "update":
return update_task_query(**arguments)
elif system.operation == "delete":
return delete_task_query(**arguments)

raise NotImplementedError(f"System call not implemented for {
system.resource}.{system.operation}")

except BaseException as e:
if activity.in_activity():
activity.logger.error(f"Error in execute_system_call: {e}")
raise


# Mock and activity definition
mock_execute_system = execute_system

execute_system = activity.defn(name="execute_system")(
execute_system if not testing else mock_execute_system
)
27 changes: 18 additions & 9 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from temporalio import activity
from temporalio.exceptions import ApplicationError

from ...autogen.Tools import Tool
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
)
Expand All @@ -10,6 +11,22 @@
from ...models.tools.list_tools import list_tools


# FIXME: This shouldn't be here.
def format_agent_tool(tool: Tool) -> dict:
if tool.function:
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.function.parameters,
},
}
# TODO: Add integration | system | api_call tool types
else:
return {}


@activity.defn
@beartype
async def prompt_step(context: StepContext) -> StepOutcome:
Expand Down Expand Up @@ -46,15 +63,7 @@ async def prompt_step(context: StepContext) -> StepOutcome:

# Format agent_tools for litellm
formatted_agent_tools = [
{
"type": tool.type,
"function": {
"name": tool.function.name,
"description": tool.function.description,
"parameters": tool.function.parameters,
},
}
for tool in agent_tools
format_agent_tool(tool) for tool in agent_tools if format_agent_tool(tool)
]

if context.current_step.settings:
Expand Down
33 changes: 23 additions & 10 deletions agents-api/agents_api/activities/task_steps/tool_call_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from temporalio.exceptions import ApplicationError

from ...activities.task_steps.base_evaluate import base_evaluate
from ...autogen.openapi_model import Tool, ToolCallStep
from ...autogen.openapi_model import TaskToolDef, Tool, ToolCallStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)


# FIXME: This shouldn't be here.
def generate_call_id():
# Generate 18 random bytes (which will result in 24 base64 characters)
random_bytes = secrets.token_bytes(18)
Expand All @@ -22,6 +23,26 @@ def generate_call_id():
return f"call_{base64_string}"


# FIXME: This shouldn't be here, and shouldn't be done this way. Should be refactored.
def construct_tool_call(tool: TaskToolDef, arguments: dict, call_id: str) -> dict:
return {
tool.type: {
"arguments": arguments,
"name": tool.name,
}
if tool.type != "system"
else {
"resource": tool.spec["resource"],
"operation": tool.spec["operation"],
"resource_id": tool.spec["resource_id"],
"subresource": tool.spec["subresource"],
"arguments": arguments,
},
"id": call_id,
"type": tool.type,
}


@activity.defn
@beartype
async def tool_call_step(context: StepContext) -> StepOutcome:
Expand All @@ -40,14 +61,6 @@ async def tool_call_step(context: StepContext) -> StepOutcome:
)

call_id = generate_call_id()

tool_call = {
tool.type: {
"arguments": arguments,
"name": tool_name,
},
"id": call_id,
"type": tool.type,
}
tool_call = construct_tool_call(tool, arguments, call_id)

return StepOutcome(output=tool_call)
2 changes: 2 additions & 0 deletions agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def create_worker(client: Client) -> Any:
from ..activities.embed_docs import embed_docs
from ..activities.excecute_api_call import execute_api_call
from ..activities.execute_integration import execute_integration
from ..activities.execute_system import execute_system
from ..activities.mem_mgmt import mem_mgmt
from ..activities.mem_rating import mem_rating
from ..activities.summarization import summarization
Expand Down Expand Up @@ -53,6 +54,7 @@ def create_worker(client: Client) -> Any:
demo_activity,
embed_docs,
execute_integration,
execute_system,
execute_api_call,
mem_mgmt,
mem_rating,
Expand Down
27 changes: 24 additions & 3 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ...activities import task_steps
from ...activities.excecute_api_call import execute_api_call
from ...activities.execute_integration import execute_integration
from ...activities.execute_system import execute_system
from ...autogen.openapi_model import (
ApiCallDef,
EmbedStep,
Expand All @@ -39,6 +40,7 @@
WorkflowStep,
YieldStep,
)
from ...autogen.Tools import SystemDef
from ...common.protocol.tasks import (
ExecutionInput,
PartialTransition,
Expand Down Expand Up @@ -545,9 +547,28 @@ async def run(

state = PartialTransition(output=tool_call_response)

case ToolCallStep(), StepOutcome(output=_):
# FIXME: Handle system/api_call tool_calls
raise ApplicationError("Not implemented")
case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[
"type"
] == "system":
call = tool_call.get("system")

system_call = SystemDef(**call)
tool_call_response = await workflow.execute_activity(
execute_system,
args=[context, system_call],
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
)

# FIXME: This is a hack to make the output of the system call match
# the expected output format (convert uuid/datetime to strings)
def model_dump(obj):
if isinstance(obj, list):
return [model_dump(item) for item in obj]
return obj.model_dump(mode="json")

state = PartialTransition(output=model_dump(tool_call_response))

case _:
workflow.logger.error(
Expand Down
Loading

0 comments on commit 3822af9

Please sign in to comment.