Skip to content

Commit

Permalink
feat(agents-api): Add some workflow tests
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 14, 2024
1 parent be18ec6 commit 3a38e70
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 28 deletions.
18 changes: 3 additions & 15 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
InputChatMLMessage,
PromptStep,
ToolCallStep,
UpdateExecutionRequest,
YieldStep,
)
from ...clients import (
Expand All @@ -25,9 +24,6 @@
from ...models.execution.create_execution_transition import (
create_execution_transition as create_execution_transition_query,
)
from ...models.execution.update_execution import (
update_execution as update_execution_query,
)


@activity.defn
Expand Down Expand Up @@ -134,8 +130,7 @@ async def transition_step(
"cancelled",
] = "awaiting_input",
):
print("Running transition step")
# raise NotImplementedError()
activity.heartbeat("Running transition step")

# Get transition info
transition_data = transition_info.model_dump(by_alias=False)
Expand All @@ -150,16 +145,9 @@ async def transition_step(
developer_id=context.developer_id,
execution_id=context.execution.id,
transition_id=uuid4(),
**transition_data,
)

update_execution_query(
developer_id=context.developer_id,
update_execution_status=True,
task_id=context.task.id,
execution_id=context.execution.id,
data=UpdateExecutionRequest(
status=execution_status,
),
**transition_data,
)

# Raise if it's a waiting step
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
class ExecutionInput(BaseModel):
developer_id: UUID
execution: Execution
task: TaskSpec
task: TaskSpecDef
agent: Agent
tools: list[Tool]
arguments: dict[str, Any]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import CreateTransitionRequest, Transition
from ...autogen.openapi_model import (
CreateTransitionRequest,
Transition,
UpdateExecutionRequest,
)
from ...common.utils.cozo import cozo_process_mutate_data
from ..utils import (
cozo_query,
Expand All @@ -15,6 +19,7 @@
verify_developer_owns_resource_query,
wrap_in_class,
)
from .update_execution import update_execution

valid_transitions = {
# Start state
Expand All @@ -29,6 +34,16 @@
"step": ["wait", "error", "step", "finish", "cancelled"],
}

transition_to_execution_status = {
"init": "queued",
"wait": "awaiting_input",
"resume": "running",
"step": "running",
"finish": "succeeded",
"error": "failed",
"cancelled": "cancelled",
}


@rewrap_exceptions(
{
Expand All @@ -46,17 +61,22 @@ def create_execution_transition(
*,
developer_id: UUID,
execution_id: UUID,
transition_id: UUID | None = None,
data: CreateTransitionRequest,
# Only one of these needed
transition_id: UUID | None = None,
task_token: str | None = None,
# Only required for updating the execution status as well
update_execution_status: bool = False,
task_id: UUID | None = None,
) -> tuple[list[str], dict]:
transition_id = transition_id or uuid4()

data.metadata = data.metadata or {}
data.execution_id = execution_id

# Prepare the transition data
transition_data = data.model_dump(exclude_unset=True)
columns, values = cozo_process_mutate_data(
columns, transition_values = cozo_process_mutate_data(
{
**transition_data,
"task_token": task_token,
Expand Down Expand Up @@ -87,8 +107,9 @@ def create_execution_transition(
:assert some
"""

# Prepare the insert query
insert_query = f"""
?[{columns}] <- $values
?[{columns}] <- $transition_values
:insert transitions {{
{columns}
Expand All @@ -97,6 +118,29 @@ def create_execution_transition(
:returning
"""

validate_status_query, update_execution_query, update_execution_params = (
"",
"",
{},
)

if update_execution_status:
assert (
task_id is not None
), "task_id is required for updating the execution status"

# Prepare the execution update query
[*_, validate_status_query, update_execution_query], update_execution_params = (
update_execution.__wrapped__(
developer_id=developer_id,
task_id=task_id,
execution_id=execution_id,
data=UpdateExecutionRequest(
status=transition_to_execution_status[data.type]
),
)
)

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
Expand All @@ -105,15 +149,18 @@ def create_execution_transition(
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
validate_status_query,
update_execution_query,
check_last_transition_query,
insert_query,
]

return (
queries,
{
"values": values,
"transition_values": transition_values,
"next_type": data.type,
"valid_transitions": valid_transitions,
**update_execution_params,
},
)
28 changes: 28 additions & 0 deletions agents-api/tests/test_execution_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,31 @@ def _(client=cozo_client, developer_id=test_developer_id, execution=test_executi
assert result is not None
assert result.type == "step"
assert result.output == {"result": "test"}


@test("model: create execution transition with execution update")
def _(
client=cozo_client,
developer_id=test_developer_id,
task=test_task,
execution=test_execution,
):
result = create_execution_transition(
developer_id=developer_id,
execution_id=execution.id,
data=CreateTransitionRequest(
**{
"type": "step",
"output": {"result": "test"},
"current": ["main", 0],
"next": None,
}
),
task_id=task.id,
update_execution_status=True,
client=client,
)

assert result is not None
assert result.type == "step"
assert result.output == {"result": "test"}
14 changes: 7 additions & 7 deletions agents-api/tests/test_task_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,45 @@

from ward import test

from tests.fixtures import client, make_request, test_execution, test_task
from tests.fixtures import client, make_request, test_agent


@test("route: unauthorized should fail")
def _(client=client):
def _(client=client, agent=test_agent):
data = dict(
name="test user",
main={
"kind_": "evaluate",
"evaluate": {
"additionalProp1": "value1",
}
},
},
)

response = client.request(
method="POST",
url="/tasks",
url=f"/agents/{str(agent.id)}/tasks",
data=data,
)

assert response.status_code == 403


@test("route: create task")
def _(make_request=make_request):
def _(make_request=make_request, agent=test_agent):
data = dict(
name="test user",
main={
"kind_": "evaluate",
"evaluate": {
"additionalProp1": "value1",
}
},
},
)

response = make_request(
method="POST",
url="/tasks",
url=f"/agents/{str(agent.id)}/tasks",
json=data,
)

Expand Down

0 comments on commit 3a38e70

Please sign in to comment.