From 1335a46c22e48c7f04db233d810133a8fb60fdb2 Mon Sep 17 00:00:00 2001
From: Dmitry Paramonov
Date: Tue, 10 Dec 2024 11:39:24 +0300
Subject: [PATCH] chore: Refactor create execution transition queries

---
 .../execution/create_execution_transition.py | 279 +++++-------------
 1 file changed, 73 insertions(+), 206 deletions(-)

diff --git a/agents-api/agents_api/models/execution/create_execution_transition.py b/agents-api/agents_api/models/execution/create_execution_transition.py
index 86a13312a..cd2e2a8e1 100644
--- a/agents-api/agents_api/models/execution/create_execution_transition.py
+++ b/agents-api/agents_api/models/execution/create_execution_transition.py
@@ -25,62 +25,8 @@ from .update_execution import update_execution
-def validate_transition_targets(data: CreateTransitionRequest) -> None:
-    # Make sure the current/next targets are valid
-    match data.type:
-        case "finish_branch":
-            pass  # TODO: Implement
-        case "finish" | "error" | "cancelled":
-            pass
-
-            ### FIXME: HACK: Fix this and uncomment
-
-            ### assert (
-            ###     data.next is None
-            ### ), "Next target must be None for finish/finish_branch/error/cancelled"
-
-        case "init_branch" | "init":
-            assert (
-                data.next and data.current.step == data.next.step == 0
-            ), "Next target must be same as current for init_branch/init and step 0"
-
-        case "wait":
-            assert data.next is None, "Next target must be None for wait"
-
-        case "resume" | "step":
-            assert data.next is not None, "Next target must be provided for resume/step"
-
-            if data.next.workflow == data.current.workflow:
-                assert (
-                    data.next.step > data.current.step
-                ), "Next step must be greater than current"
-
-        case _:
-            raise ValueError(f"Invalid transition type: {data.type}")
-
-
-@rewrap_exceptions(
-    {
-        QueryException: partialclass(HTTPException, status_code=400),
-        ValidationError: partialclass(HTTPException, status_code=400),
-        TypeError: partialclass(HTTPException, status_code=400),
-    }
-)
-@wrap_in_class(
-    Transition,
-    transform=lambda d: {
-        **d,
-        "id": d["transition_id"],
-        "current": {"workflow": d["current"][0], "step": d["current"][1]},
-        "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
-    },
-    one=True,
-    _kind="inserted",
-)
-@cozo_query
-@increase_counter("create_execution_transition")
 @beartype
-def create_execution_transition(
+def _create_execution_transition(
     *,
     developer_id: UUID,
     execution_id: UUID,
@@ -140,7 +86,7 @@ def create_execution_transition(
     ]

     last_transition_type[min_cost(type_created_at)] :=
-        *transitions:execution_id_type_idx {{
+        *transitions:execution_id_type_created_at_idx {{
             execution_id: to_uuid("{str(execution_id)}"),
             type,
             created_at,
@@ -225,167 +171,88 @@ def create_execution_transition(
     )


-@rewrap_exceptions(
-    {
-        QueryException: partialclass(HTTPException, status_code=400),
-        ValidationError: partialclass(HTTPException, status_code=400),
-        TypeError: partialclass(HTTPException, status_code=400),
-    }
-)
-@wrap_in_class(
-    Transition,
-    transform=lambda d: {
-        **d,
-        "id": d["transition_id"],
-        "current": {"workflow": d["current"][0], "step": d["current"][1]},
-        "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
-    },
-    one=True,
-    _kind="inserted",
-)
-@cozo_query_async
-@increase_counter("create_execution_transition_async")
-@beartype
-async def create_execution_transition_async(
-    *,
-    developer_id: UUID,
-    execution_id: UUID,
-    data: CreateTransitionRequest,
-    # Only one of these needed
-    transition_id: UUID | None = None,
-    task_token: str | None = None,
-    # Only required for updating the execution status as well
-    update_execution_status: bool = False,
-    task_id: UUID | None = None,
-) -> tuple[list[str | None], dict]:
-    transition_id = transition_id or uuid4()
-    data.metadata = data.metadata or {}
-    data.execution_id = execution_id
-
-    # Dump to json
-    if isinstance(data.output, list):
-        data.output = [
-            item.model_dump(mode="json") if hasattr(item, "model_dump") else item
-            for item in data.output
-        ]
-
-    elif hasattr(data.output, "model_dump"):
-        data.output = data.output.model_dump(mode="json")
-
-    # TODO: This is a hack to make sure the transition is valid
-    # (parallel transitions are whack, we should do something better)
-    is_parallel = data.current.workflow.startswith("PAR:")
-
-    # Prepare the transition data
-    transition_data = data.model_dump(exclude_unset=True, exclude={"id"})
-
-    # Parse the current and next targets
-    validate_transition_targets(data)
-    current_target = transition_data.pop("current")
-    next_target = transition_data.pop("next")
-
-    transition_data["current"] = (current_target["workflow"], current_target["step"])
-    transition_data["next"] = next_target and (
-        next_target["workflow"],
-        next_target["step"],
-    )
-
-    columns, transition_values = cozo_process_mutate_data(
-        {
-            **transition_data,
-            "task_token": str(task_token),  # Converting to str for JSON serialisation
-            "transition_id": str(transition_id),
-            "execution_id": str(execution_id),
-        }
-    )
-
-    # Make sure the transition is valid
-    check_last_transition_query = f"""
-    valid_transition[start, end] <- [
-        {", ".join(f'["{start}", "{end}"]' for start, ends in valid_transitions.items() for end in ends)}
-    ]
+def validate_transition_targets(data: CreateTransitionRequest) -> None:
+    # Make sure the current/next targets are valid
+    match data.type:
+        case "finish_branch":
+            pass  # TODO: Implement
+        case "finish" | "error" | "cancelled":
+            pass

-    last_transition_type[min_cost(type_created_at)] :=
-        *transitions {{
-            execution_id: to_uuid("{str(execution_id)}"),
-            type,
-            created_at,
-        }},
-        type_created_at = [type, -created_at]
+            ### FIXME: HACK: Fix this and uncomment

-    matched[collect(last_type)] :=
-        last_transition_type[data],
-        last_type_data = first(data),
-        last_type = if(is_null(last_type_data), "init", last_type_data),
-        valid_transition[last_type, $next_type]
+            ### assert (
+            ###     data.next is None
+            ### ), "Next target must be None for finish/finish_branch/error/cancelled"

-    ?[valid] :=
-        matched[prev_transitions],
-        found = length(prev_transitions),
-        valid = if($next_type == "init", found == 0, found > 0),
-        assert(valid, "Invalid transition"),
+        case "init_branch" | "init":
+            assert (
+                data.next and data.current.step == data.next.step == 0
+            ), "Next target must be same as current for init_branch/init and step 0"

-    :limit 1
-    """
+        case "wait":
+            assert data.next is None, "Next target must be None for wait"

-    # Prepare the insert query
-    insert_query = f"""
-    ?[{columns}] <- $transition_values
+        case "resume" | "step":
+            assert data.next is not None, "Next target must be provided for resume/step"

-    :insert transitions {{
-        {columns}
-    }}
-
-    :returning
-    """
+            if data.next.workflow == data.current.workflow:
+                assert (
+                    data.next.step > data.current.step
+                ), "Next step must be greater than current"

-    validate_status_query, update_execution_query, update_execution_params = (
-        "",
-        "",
-        {},
-    )
+        case _:
+            raise ValueError(f"Invalid transition type: {data.type}")

-    if update_execution_status:
-        assert (
-            task_id is not None
-        ), "task_id is required for updating the execution status"
-        # Prepare the execution update query
-        [*_, validate_status_query, update_execution_query], update_execution_params = (
-            update_execution.__wrapped__(
-                developer_id=developer_id,
-                task_id=task_id,
-                execution_id=execution_id,
-                data=UpdateExecutionRequest(
-                    status=transition_to_execution_status[data.type]
-                ),
-                output=data.output if data.type != "error" else None,
-                error=str(data.output)
-                if data.type == "error" and data.output
-                else None,
+create_execution_transition = rewrap_exceptions(
+    {
+        QueryException: partialclass(HTTPException, status_code=400),
+        ValidationError: partialclass(HTTPException, status_code=400),
+        TypeError: partialclass(HTTPException, status_code=400),
+    }
+)(
+    wrap_in_class(
+        Transition,
+        transform=lambda d: {
+            **d,
+            "id": d["transition_id"],
+            "current": {"workflow": d["current"][0], "step": d["current"][1]},
+            "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
+        },
+        one=True,
+        _kind="inserted",
+    )(
+        cozo_query(
+            increase_counter("create_execution_transition")(
+                _create_execution_transition
             )
         )
+    )
+)

-    queries = [
-        verify_developer_id_query(developer_id),
-        verify_developer_owns_resource_query(
-            developer_id,
-            "executions",
-            execution_id=execution_id,
-            parents=[("agents", "agent_id"), ("tasks", "task_id")],
-        ),
-        validate_status_query if not is_parallel else None,
-        update_execution_query if not is_parallel else None,
-        check_last_transition_query if not is_parallel else None,
-        insert_query,
-    ]
-
-    return (
-        queries,
-        {
-            "transition_values": transition_values,
-            "next_type": data.type,
-            "valid_transitions": valid_transitions,
-            **update_execution_params,
+create_execution_transition_async = rewrap_exceptions(
+    {
+        QueryException: partialclass(HTTPException, status_code=400),
+        ValidationError: partialclass(HTTPException, status_code=400),
+        TypeError: partialclass(HTTPException, status_code=400),
+    }
+)(
+    wrap_in_class(
+        Transition,
+        transform=lambda d: {
+            **d,
+            "id": d["transition_id"],
+            "current": {"workflow": d["current"][0], "step": d["current"][1]},
+            "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
         },
+        one=True,
+        _kind="inserted",
+    )(
+        cozo_query_async(
+            increase_counter("create_execution_transition")(
+                _create_execution_transition
+            )
+        )
     )
+)
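
Note on the pattern above: rather than stacking decorators on two nearly identical function bodies, the patch keeps a single undecorated implementation (_create_execution_transition) and builds the public sync and async entry points by applying the same wrappers explicitly, swapping only cozo_query for cozo_query_async. The minimal sketch below illustrates that composition style under stated assumptions; the wrapper names (log_calls, as_query, as_query_async) and the _build_transition body are illustrative stand-ins, not part of the agents-api codebase.

import asyncio
from functools import wraps


def log_calls(label):
    # Hypothetical stand-in for increase_counter(label): tags each invocation.
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            print(f"[{label}] calling {fn.__name__}")
            return fn(*args, **kwargs)

        return wrapper

    return decorator


def as_query(fn):
    # Hypothetical stand-in for cozo_query: runs the shared body synchronously.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        return {"mode": "sync", "result": fn(*args, **kwargs)}

    return wrapper


def as_query_async(fn):
    # Hypothetical stand-in for cozo_query_async: exposes the same body as a coroutine.
    @wraps(fn)
    async def wrapper(*args, **kwargs):
        return {"mode": "async", "result": fn(*args, **kwargs)}

    return wrapper


def _build_transition(*, execution_id: str, transition_type: str) -> dict:
    # Single shared implementation, analogous to _create_execution_transition.
    return {"execution_id": execution_id, "type": transition_type}


# Explicit composition instead of decorator stacks: one body, two public variants.
build_transition = as_query(log_calls("create_transition")(_build_transition))
build_transition_async = as_query_async(log_calls("create_transition")(_build_transition))


if __name__ == "__main__":
    print(build_transition(execution_id="exec-1", transition_type="init"))
    print(asyncio.run(build_transition_async(execution_id="exec-1", transition_type="init")))

The payoff mirrors the patch: validation and query construction live in one place, and only the outermost execution wrapper differs between the sync and async code paths.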