Skip to content

Commit

Permalink
feat(agents-api): Add integration support for tool-call step (dummy p…
Browse files Browse the repository at this point in the history
…rovider)

Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 25, 2024
1 parent 5b32645 commit 993ec2c
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 3 deletions.
55 changes: 55 additions & 0 deletions agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ..autogen.openapi_model import IntegrationDef
from ..common.protocol.tasks import StepContext
from ..env import testing
from ..models.tools import get_tool_args_from_metadata


@beartype
async def execute_integration(
context: StepContext,
tool_name: str,
integration: IntegrationDef,
arguments: dict[str, Any],
) -> Any:
developer_id = context.execution_input.developer_id
agent_id = context.execution_input.agent.id
task_id = context.execution_input.task.id

merged_tool_args = get_tool_args_from_metadata(
developer_id=developer_id, agent_id=agent_id, task_id=task_id
)

arguments = merged_tool_args.get(tool_name, {}) | arguments

try:
if integration.provider == "dummy":
return arguments

else:
raise NotImplementedError(
f"Unknown integration provider: {integration.provider}"
)
except BaseException as e:
if activity.in_activity():
activity.logger.error(f"Error in execute_integration: {e}")

raise


async def mock_execute_integration(
context: StepContext,
tool_name: str,
integration: IntegrationDef,
arguments: dict[str, Any],
) -> Any:
return arguments


execute_integration = activity.defn(name="execute_integration")(
execute_integration if not testing else mock_execute_integration
)
6 changes: 4 additions & 2 deletions agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def create_worker(client: Client) -> Any:
from ..activities import task_steps
from ..activities.demo import demo_activity
from ..activities.embed_docs import embed_docs
from ..activities.execute_integration import execute_integration
from ..activities.mem_mgmt import mem_mgmt
from ..activities.mem_rating import mem_rating
from ..activities.summarization import summarization
Expand Down Expand Up @@ -49,10 +50,11 @@ def create_worker(client: Client) -> Any:
activities=[
*task_activities,
demo_activity,
summarization,
embed_docs,
execute_integration,
mem_mgmt,
mem_rating,
embed_docs,
summarization,
truncation,
],
)
Expand Down
32 changes: 31 additions & 1 deletion agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# Import necessary modules and types
with workflow.unsafe.imports_passed_through():
from ...activities import task_steps
from ...activities.execute_integration import execute_integration
from ...autogen.openapi_model import (
EmbedStep,
ErrorWorkflowStep,
Expand Down Expand Up @@ -460,7 +461,9 @@ async def run(
workflow.logger.error("ParallelStep not yet implemented")
raise ApplicationError("Not implemented")

case ToolCallStep(), StepOutcome(output=tool_call):
case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[
"type"
] == "function":
# Enter a wait-for-input step to ask the developer to run the tool calls
tool_call_response = await workflow.execute_activity(
task_steps.raise_complete_async,
Expand All @@ -470,6 +473,33 @@ async def run(

state = PartialTransition(output=tool_call_response, type="resume")

case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[
"type"
] == "integration":
call = tool_call["integration"]
tool_name = call["name"]
arguments = call["arguments"]
integration = next(
(t for t in context.tools if t.name == tool_name), None
)

if integration is None:
raise ApplicationError(f"Integration {tool_name} not found")

tool_call_response = await workflow.execute_activity(
execute_integration,
args=[context, tool_name, integration, arguments],
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
),
)

state = PartialTransition(output=tool_call_response, type="step")

case ToolCallStep(), StepOutcome(output=_):
# FIXME: Handle system/api_call tool_calls
raise ApplicationError("Not implemented")

case _:
workflow.logger.error(
f"Unhandled step type: {type(context.current_step).__name__}"
Expand Down
53 changes: 53 additions & 0 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,59 @@ async def _(
assert result["hello"] == data.input["test"]


@test("workflow: tool call integration type step")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"tools": [
{
"type": "integration",
"name": "hello",
"integration": {
"provider": "dummy",
},
}
],
"main": [
{
"tool": "hello",
"arguments": {"test": "_.test"},
},
],
}
),
client=client,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input
mock_run_task_execution_workflow.assert_called_once()

result = await handle.result()
assert result["test"] == data.input["test"]


# FIXME: This test is not working. It gets stuck
# @test("workflow: wait for input step start")
async def _(
Expand Down

0 comments on commit 993ec2c

Please sign in to comment.