Skip to content

Commit

Permalink
chore: Refactor create execution transition queries
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Dec 10, 2024
1 parent 58b92a2 commit 1335a46
Showing 1 changed file with 73 additions and 206 deletions.
279 changes: 73 additions & 206 deletions agents-api/agents_api/models/execution/create_execution_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,62 +25,8 @@
from .update_execution import update_execution


def validate_transition_targets(data: CreateTransitionRequest) -> None:
# Make sure the current/next targets are valid
match data.type:
case "finish_branch":
pass # TODO: Implement
case "finish" | "error" | "cancelled":
pass

### FIXME: HACK: Fix this and uncomment

### assert (
### data.next is None
### ), "Next target must be None for finish/finish_branch/error/cancelled"

case "init_branch" | "init":
assert (
data.next and data.current.step == data.next.step == 0
), "Next target must be same as current for init_branch/init and step 0"

case "wait":
assert data.next is None, "Next target must be None for wait"

case "resume" | "step":
assert data.next is not None, "Next target must be provided for resume/step"

if data.next.workflow == data.current.workflow:
assert (
data.next.step > data.current.step
), "Next step must be greater than current"

case _:
raise ValueError(f"Invalid transition type: {data.type}")


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
Transition,
transform=lambda d: {
**d,
"id": d["transition_id"],
"current": {"workflow": d["current"][0], "step": d["current"][1]},
"next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
},
one=True,
_kind="inserted",
)
@cozo_query
@increase_counter("create_execution_transition")
@beartype
def create_execution_transition(
def _create_execution_transition(
*,
developer_id: UUID,
execution_id: UUID,
Expand Down Expand Up @@ -140,7 +86,7 @@ def create_execution_transition(
]
last_transition_type[min_cost(type_created_at)] :=
*transitions:execution_id_type_idx {{
*transitions:execution_id_type_created_at_idx {{
execution_id: to_uuid("{str(execution_id)}"),
type,
created_at,
Expand Down Expand Up @@ -225,167 +171,88 @@ def create_execution_transition(
)


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
Transition,
transform=lambda d: {
**d,
"id": d["transition_id"],
"current": {"workflow": d["current"][0], "step": d["current"][1]},
"next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
},
one=True,
_kind="inserted",
)
@cozo_query_async
@increase_counter("create_execution_transition_async")
@beartype
async def create_execution_transition_async(
*,
developer_id: UUID,
execution_id: UUID,
data: CreateTransitionRequest,
# Only one of these needed
transition_id: UUID | None = None,
task_token: str | None = None,
# Only required for updating the execution status as well
update_execution_status: bool = False,
task_id: UUID | None = None,
) -> tuple[list[str | None], dict]:
transition_id = transition_id or uuid4()
data.metadata = data.metadata or {}
data.execution_id = execution_id

# Dump to json
if isinstance(data.output, list):
data.output = [
item.model_dump(mode="json") if hasattr(item, "model_dump") else item
for item in data.output
]

elif hasattr(data.output, "model_dump"):
data.output = data.output.model_dump(mode="json")

# TODO: This is a hack to make sure the transition is valid
# (parallel transitions are whack, we should do something better)
is_parallel = data.current.workflow.startswith("PAR:")

# Prepare the transition data
transition_data = data.model_dump(exclude_unset=True, exclude={"id"})

# Parse the current and next targets
validate_transition_targets(data)
current_target = transition_data.pop("current")
next_target = transition_data.pop("next")

transition_data["current"] = (current_target["workflow"], current_target["step"])
transition_data["next"] = next_target and (
next_target["workflow"],
next_target["step"],
)

columns, transition_values = cozo_process_mutate_data(
{
**transition_data,
"task_token": str(task_token), # Converting to str for JSON serialisation
"transition_id": str(transition_id),
"execution_id": str(execution_id),
}
)

# Make sure the transition is valid
check_last_transition_query = f"""
valid_transition[start, end] <- [
{", ".join(f'["{start}", "{end}"]' for start, ends in valid_transitions.items() for end in ends)}
]
def validate_transition_targets(data: CreateTransitionRequest) -> None:
# Make sure the current/next targets are valid
match data.type:
case "finish_branch":
pass # TODO: Implement
case "finish" | "error" | "cancelled":
pass

last_transition_type[min_cost(type_created_at)] :=
*transitions {{
execution_id: to_uuid("{str(execution_id)}"),
type,
created_at,
}},
type_created_at = [type, -created_at]
### FIXME: HACK: Fix this and uncomment

matched[collect(last_type)] :=
last_transition_type[data],
last_type_data = first(data),
last_type = if(is_null(last_type_data), "init", last_type_data),
valid_transition[last_type, $next_type]
### assert (
### data.next is None
### ), "Next target must be None for finish/finish_branch/error/cancelled"

?[valid] :=
matched[prev_transitions],
found = length(prev_transitions),
valid = if($next_type == "init", found == 0, found > 0),
assert(valid, "Invalid transition"),
case "init_branch" | "init":
assert (
data.next and data.current.step == data.next.step == 0
), "Next target must be same as current for init_branch/init and step 0"

:limit 1
"""
case "wait":
assert data.next is None, "Next target must be None for wait"

# Prepare the insert query
insert_query = f"""
?[{columns}] <- $transition_values
case "resume" | "step":
assert data.next is not None, "Next target must be provided for resume/step"

:insert transitions {{
{columns}
}}
:returning
"""
if data.next.workflow == data.current.workflow:
assert (
data.next.step > data.current.step
), "Next step must be greater than current"

validate_status_query, update_execution_query, update_execution_params = (
"",
"",
{},
)
case _:
raise ValueError(f"Invalid transition type: {data.type}")

if update_execution_status:
assert (
task_id is not None
), "task_id is required for updating the execution status"

# Prepare the execution update query
[*_, validate_status_query, update_execution_query], update_execution_params = (
update_execution.__wrapped__(
developer_id=developer_id,
task_id=task_id,
execution_id=execution_id,
data=UpdateExecutionRequest(
status=transition_to_execution_status[data.type]
),
output=data.output if data.type != "error" else None,
error=str(data.output)
if data.type == "error" and data.output
else None,
create_execution_transition = rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)(
wrap_in_class(
Transition,
transform=lambda d: {
**d,
"id": d["transition_id"],
"current": {"workflow": d["current"][0], "step": d["current"][1]},
"next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
},
one=True,
_kind="inserted",
)(
cozo_query(
increase_counter("create_execution_transition")(
_create_execution_transition
)
)
)
)

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"executions",
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
validate_status_query if not is_parallel else None,
update_execution_query if not is_parallel else None,
check_last_transition_query if not is_parallel else None,
insert_query,
]

return (
queries,
{
"transition_values": transition_values,
"next_type": data.type,
"valid_transitions": valid_transitions,
**update_execution_params,
create_execution_transition_async = rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)(
wrap_in_class(
Transition,
transform=lambda d: {
**d,
"id": d["transition_id"],
"current": {"workflow": d["current"][0], "step": d["current"][1]},
"next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
},
one=True,
_kind="inserted",
)(
cozo_query_async(
increase_counter("create_execution_transition")(
_create_execution_transition
)
)
)
)

0 comments on commit 1335a46

Please sign in to comment.