Skip to content

Commit

Permalink
feat: Choose the activity based on tool type
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Oct 17, 2024
1 parent 000fc19 commit 44e4c13
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 122 deletions.
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def execute_system(
context: StepContext,
system: SystemDef,
) -> Any:
arguments = system.arguments
arguments = system.arguments or {}
arguments["developer_id"] = context.execution_input.developer_id

# Unbox all the arguments
Expand Down
36 changes: 4 additions & 32 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from beartype import beartype
from litellm.types.utils import Function
from temporalio import activity
from temporalio.exceptions import ApplicationError

Expand Down Expand Up @@ -111,13 +110,6 @@ async def prompt_step(context: StepContext) -> StepOutcome:
direction="desc",
)

### [Function(...), ApiCall(...), Integration(...)]
### -> [{"type": "function", "function": {...}}, {"type": "api_call", "api_call": {...}}, {"type": "integration", "integration": {...}}]
### -> [{"type": "function", "function": {...}}]
### -> openai

# Format agent_tools for litellm
# COMMENT(oct-16): Format the tools for openai api here (api_call | integration | system) -> function
formatted_agent_tools = [
func_call for tool in agent_tools if (func_call := make_function_call(tool))
]
Expand Down Expand Up @@ -157,30 +149,10 @@ async def prompt_step(context: StepContext) -> StepOutcome:
response = choice.message.content

if choice.finish_reason == "tool_calls":
choice.message.tool_calls = [
call if isinstance(tc.function, dict) else tc.function.name
for tc in choice.message.tool_calls
if (
call := (
tools_mapping.get(
tc.function["name"]
if isinstance(tc.function, dict)
else tc.function.name
)
)
)
]

### response.choices[0].finish_reason == "tool_calls"
### -> response.choices[0].message.tool_calls
### -> [{"id": "call_abc", "name": "my_function", "arguments": "..."}, ...]
### (cross-reference with original agent_tools list to get the original tool)
###
### -> FunctionCall(...) | ApiCall(...) | IntegrationCall(...) | SystemCall(...)
### -> set this on response.choices[0].tool_calls

# COMMENT(oct-16): Reference the original tool from tools passed to the activity
# if openai chooses to use a tool (finish_reason == "tool_calls")
tc = choice.message.tool_calls[0]
choice.message.tool_calls = tools_mapping.get(
tc.function["name"] if isinstance(tc.function, dict) else tc.function.name
)

return StepOutcome(
output=response.model_dump() if hasattr(response, "model_dump") else response,
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/common/exceptions/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import asyncio
from typing import cast

import beartype
import beartype.roar
Expand Down Expand Up @@ -121,6 +122,7 @@ def is_non_retryable_error(error: BaseException) -> bool:

# Check for specific HTTP errors (status code == 429)
if isinstance(error, httpx.HTTPStatusError):
error = cast(httpx.HTTPStatusError, error)
if error.response.status_code in (408, 429, 503, 504):
return False

Expand Down
174 changes: 85 additions & 89 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,71 @@
# Probably can be implemented much more efficiently


async def run_api_call(tool_call: dict, context: StepContext):
call = tool_call["api_call"]
tool_name = call["name"]
arguments = call["arguments"]
apicall_spec = next((t for t in context.tools if t.name == tool_name), None)

if apicall_spec is None:
raise ApplicationError(f"Integration {tool_name} not found")

api_call = ApiCallDef(
method=apicall_spec.spec["method"],
url=apicall_spec.spec["url"],
headers=apicall_spec.spec["headers"],
follow_redirects=apicall_spec.spec["follow_redirects"],
)

if "json_" in arguments:
arguments["json"] = arguments["json_"]
del arguments["json_"]

return await workflow.execute_activity(
execute_api_call,
args=[
api_call,
arguments,
],
schedule_to_close_timeout=timedelta(seconds=30 if debug or testing else 600),
)


async def run_integration_call(tool_call: dict, context: StepContext):
call = tool_call["integration"]
tool_name = call["name"]
arguments = call["arguments"]
integration_spec = next((t for t in context.tools if t.name == tool_name), None)

if integration_spec is None:
raise ApplicationError(f"Integration {tool_name} not found")

integration = IntegrationDef(
provider=integration_spec.spec["provider"],
setup=integration_spec.spec["setup"],
method=integration_spec.spec["method"],
arguments=arguments,
)

return await workflow.execute_activity(
execute_integration,
args=[context, tool_name, integration, arguments],
schedule_to_close_timeout=timedelta(seconds=30 if debug or testing else 600),
retry_policy=DEFAULT_RETRY_POLICY,
)


async def run_system_call(tool_call: dict, context: StepContext):
call = tool_call.get("system")

system_call = SystemDef(**call)
return await workflow.execute_activity(
execute_system,
args=[context, system_call],
schedule_to_close_timeout=timedelta(seconds=30 if debug or testing else 600),
)


# Main workflow definition
@workflow.defn
class TaskExecutionWorkflow:
Expand Down Expand Up @@ -382,26 +447,23 @@ async def run(
message = response["choices"][0]["message"]
tool_calls_input = message["tool_calls"]

### COMMENT(oct-16): do a match-case on tool_calls_input.type
### -> FunctionCall(...), ApiCall(...), IntegrationCall(...), SystemCall(...)
### -> if api_call:
### => execute_api_call(api_call)
### -> if integration_call:
### => execute_integration(integration_call)
### -> if system_call:
### => execute_system(system_call)
### -> else:
### => wait for input

# Enter a wait-for-input step to ask the developer to run the tool calls
tool_calls_results = await workflow.execute_activity(
task_steps.raise_complete_async,
args=[context, tool_calls_input],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

### COMMENT(oct-16): Continue as usual. Feed the tool call results back to the model
if tool_calls_input.get("api_call"):
tool_calls_results = await run_api_call(tool_calls_input, context)
elif tool_calls_input.get("integration"):
tool_calls_results = await run_integration_call(
tool_calls_input, context
)
elif tool_calls_input.get("system"):
tool_calls_results = await run_system_call(
tool_calls_input, context
)
else:
tool_calls_results = await workflow.execute_activity(
task_steps.raise_complete_async,
args=[context, tool_calls_input],
schedule_to_close_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
)

# Feed the tool call results back to the model
context.current_step.prompt.append(message)
Expand Down Expand Up @@ -453,85 +515,19 @@ async def run(
case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[
"type"
] == "integration":
call = tool_call["integration"]
tool_name = call["name"]
arguments = call["arguments"]
integration_spec = next(
(t for t in context.tools if t.name == tool_name), None
)

if integration_spec is None:
raise ApplicationError(f"Integration {tool_name} not found")

integration = IntegrationDef(
provider=integration_spec.spec["provider"],
setup=integration_spec.spec["setup"],
method=integration_spec.spec["method"],
arguments=arguments,
)

tool_call_response = await workflow.execute_activity(
execute_integration,
args=[context, tool_name, integration, arguments],
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
retry_policy=DEFAULT_RETRY_POLICY,
)

tool_call_response = await run_integration_call(tool_call, context)
state = PartialTransition(output=tool_call_response)

case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[
"type"
] == "api_call":
call = tool_call["api_call"]
tool_name = call["name"]
arguments = call["arguments"]
apicall_spec = next(
(t for t in context.tools if t.name == tool_name), None
)

if apicall_spec is None:
raise ApplicationError(f"Integration {tool_name} not found")

api_call = ApiCallDef(
method=apicall_spec.spec["method"],
url=apicall_spec.spec["url"],
headers=apicall_spec.spec["headers"],
follow_redirects=apicall_spec.spec["follow_redirects"],
)

if "json_" in arguments:
arguments["json"] = arguments["json_"]
del arguments["json_"]

# Execute the API call using the `execute_api_call` function
tool_call_response = await workflow.execute_activity(
execute_api_call,
args=[
api_call,
arguments,
],
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
)

tool_call_response = await run_api_call(tool_call, context)
state = PartialTransition(output=tool_call_response)

case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[
"type"
] == "system":
call = tool_call.get("system")

system_call = SystemDef(**call)
tool_call_response = await workflow.execute_activity(
execute_system,
args=[context, system_call],
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
)
tool_call_response = await run_system_call(tool_call, context)

# FIXME: This is a hack to make the output of the system call match
# the expected output format (convert uuid/datetime to strings)
Expand Down

0 comments on commit 44e4c13

Please sign in to comment.