diff --git a/agents-api/agents_api/models/docs/search_docs_by_embedding.py b/agents-api/agents_api/models/docs/search_docs_by_embedding.py index 992e12f9d..49ec069df 100644 --- a/agents-api/agents_api/models/docs/search_docs_by_embedding.py +++ b/agents-api/agents_api/models/docs/search_docs_by_embedding.py @@ -91,7 +91,7 @@ def search_docs_by_embedding( snippet_counter[count(item)] := owners[owner_type, owner_id_str], owner_id = to_uuid(owner_id_str), - *docs {{ + *docs:owner_id_metadata_doc_id_idx {{ owner_type, owner_id, doc_id: item, diff --git a/agents-api/agents_api/models/docs/search_docs_by_text.py b/agents-api/agents_api/models/docs/search_docs_by_text.py index ac1a9f54f..ce5319673 100644 --- a/agents-api/agents_api/models/docs/search_docs_by_text.py +++ b/agents-api/agents_api/models/docs/search_docs_by_text.py @@ -90,7 +90,7 @@ def search_docs_by_text( candidate[doc_id] := input[owner_type, owner_id], - *docs {{ + *docs:owner_id_metadata_doc_id_idx {{ owner_type, owner_id, doc_id, diff --git a/agents-api/agents_api/models/execution/count_executions.py b/agents-api/agents_api/models/execution/count_executions.py index 7f10e5bfa..d130f0359 100644 --- a/agents-api/agents_api/models/execution/count_executions.py +++ b/agents-api/agents_api/models/execution/count_executions.py @@ -39,7 +39,7 @@ def count_executions( counter[count(id)] := input[task_id], - *executions { + *executions:task_id_execution_id_idx { task_id, execution_id: id, } diff --git a/agents-api/agents_api/models/execution/create_execution_transition.py b/agents-api/agents_api/models/execution/create_execution_transition.py index d885379fb..59a63ed09 100644 --- a/agents-api/agents_api/models/execution/create_execution_transition.py +++ b/agents-api/agents_api/models/execution/create_execution_transition.py @@ -25,62 +25,8 @@ from .update_execution import update_execution -def validate_transition_targets(data: CreateTransitionRequest) -> None: - # Make sure the current/next targets are valid - match data.type: - case "finish_branch": - pass # TODO: Implement - case "finish" | "error" | "cancelled": - pass - - ### FIXME: HACK: Fix this and uncomment - - ### assert ( - ### data.next is None - ### ), "Next target must be None for finish/finish_branch/error/cancelled" - - case "init_branch" | "init": - assert ( - data.next and data.current.step == data.next.step == 0 - ), "Next target must be same as current for init_branch/init and step 0" - - case "wait": - assert data.next is None, "Next target must be None for wait" - - case "resume" | "step": - assert data.next is not None, "Next target must be provided for resume/step" - - if data.next.workflow == data.current.workflow: - assert ( - data.next.step > data.current.step - ), "Next step must be greater than current" - - case _: - raise ValueError(f"Invalid transition type: {data.type}") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Transition, - transform=lambda d: { - **d, - "id": d["transition_id"], - "current": {"workflow": d["current"][0], "step": d["current"][1]}, - "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]}, - }, - one=True, - _kind="inserted", -) -@cozo_query -@increase_counter("create_execution_transition") @beartype -def create_execution_transition( +def _create_execution_transition( *, developer_id: UUID, execution_id: UUID, @@ -140,7 +86,7 @@ def create_execution_transition( ] last_transition_type[min_cost(type_created_at)] := - *transitions {{ + *transitions:execution_id_type_created_at_idx {{ execution_id: to_uuid("{str(execution_id)}"), type, created_at, @@ -225,167 +171,88 @@ def create_execution_transition( ) -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Transition, - transform=lambda d: { - **d, - "id": d["transition_id"], - "current": {"workflow": d["current"][0], "step": d["current"][1]}, - "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]}, - }, - one=True, - _kind="inserted", -) -@cozo_query_async -@increase_counter("create_execution_transition_async") -@beartype -async def create_execution_transition_async( - *, - developer_id: UUID, - execution_id: UUID, - data: CreateTransitionRequest, - # Only one of these needed - transition_id: UUID | None = None, - task_token: str | None = None, - # Only required for updating the execution status as well - update_execution_status: bool = False, - task_id: UUID | None = None, -) -> tuple[list[str | None], dict]: - transition_id = transition_id or uuid4() - data.metadata = data.metadata or {} - data.execution_id = execution_id - - # Dump to json - if isinstance(data.output, list): - data.output = [ - item.model_dump(mode="json") if hasattr(item, "model_dump") else item - for item in data.output - ] - - elif hasattr(data.output, "model_dump"): - data.output = data.output.model_dump(mode="json") - - # TODO: This is a hack to make sure the transition is valid - # (parallel transitions are whack, we should do something better) - is_parallel = data.current.workflow.startswith("PAR:") - - # Prepare the transition data - transition_data = data.model_dump(exclude_unset=True, exclude={"id"}) - - # Parse the current and next targets - validate_transition_targets(data) - current_target = transition_data.pop("current") - next_target = transition_data.pop("next") - - transition_data["current"] = (current_target["workflow"], current_target["step"]) - transition_data["next"] = next_target and ( - next_target["workflow"], - next_target["step"], - ) - - columns, transition_values = cozo_process_mutate_data( - { - **transition_data, - "task_token": str(task_token), # Converting to str for JSON serialisation - "transition_id": str(transition_id), - "execution_id": str(execution_id), - } - ) - - # Make sure the transition is valid - check_last_transition_query = f""" - valid_transition[start, end] <- [ - {", ".join(f'["{start}", "{end}"]' for start, ends in valid_transitions.items() for end in ends)} - ] +def validate_transition_targets(data: CreateTransitionRequest) -> None: + # Make sure the current/next targets are valid + match data.type: + case "finish_branch": + pass # TODO: Implement + case "finish" | "error" | "cancelled": + pass - last_transition_type[min_cost(type_created_at)] := - *transitions {{ - execution_id: to_uuid("{str(execution_id)}"), - type, - created_at, - }}, - type_created_at = [type, -created_at] + ### FIXME: HACK: Fix this and uncomment - matched[collect(last_type)] := - last_transition_type[data], - last_type_data = first(data), - last_type = if(is_null(last_type_data), "init", last_type_data), - valid_transition[last_type, $next_type] + ### assert ( + ### data.next is None + ### ), "Next target must be None for finish/finish_branch/error/cancelled" - ?[valid] := - matched[prev_transitions], - found = length(prev_transitions), - valid = if($next_type == "init", found == 0, found > 0), - assert(valid, "Invalid transition"), + case "init_branch" | "init": + assert ( + data.next and data.current.step == data.next.step == 0 + ), "Next target must be same as current for init_branch/init and step 0" - :limit 1 - """ + case "wait": + assert data.next is None, "Next target must be None for wait" - # Prepare the insert query - insert_query = f""" - ?[{columns}] <- $transition_values + case "resume" | "step": + assert data.next is not None, "Next target must be provided for resume/step" - :insert transitions {{ - {columns} - }} - - :returning - """ + if data.next.workflow == data.current.workflow: + assert ( + data.next.step > data.current.step + ), "Next step must be greater than current" - validate_status_query, update_execution_query, update_execution_params = ( - "", - "", - {}, - ) + case _: + raise ValueError(f"Invalid transition type: {data.type}") - if update_execution_status: - assert ( - task_id is not None - ), "task_id is required for updating the execution status" - # Prepare the execution update query - [*_, validate_status_query, update_execution_query], update_execution_params = ( - update_execution.__wrapped__( - developer_id=developer_id, - task_id=task_id, - execution_id=execution_id, - data=UpdateExecutionRequest( - status=transition_to_execution_status[data.type] - ), - output=data.output if data.type != "error" else None, - error=str(data.output) - if data.type == "error" and data.output - else None, +create_execution_transition = rewrap_exceptions( + { + QueryException: partialclass(HTTPException, status_code=400), + ValidationError: partialclass(HTTPException, status_code=400), + TypeError: partialclass(HTTPException, status_code=400), + } +)( + wrap_in_class( + Transition, + transform=lambda d: { + **d, + "id": d["transition_id"], + "current": {"workflow": d["current"][0], "step": d["current"][1]}, + "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]}, + }, + one=True, + _kind="inserted", + )( + cozo_query( + increase_counter("create_execution_transition")( + _create_execution_transition ) ) + ) +) - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, - "executions", - execution_id=execution_id, - parents=[("agents", "agent_id"), ("tasks", "task_id")], - ), - validate_status_query if not is_parallel else None, - update_execution_query if not is_parallel else None, - check_last_transition_query if not is_parallel else None, - insert_query, - ] - - return ( - queries, - { - "transition_values": transition_values, - "next_type": data.type, - "valid_transitions": valid_transitions, - **update_execution_params, +create_execution_transition_async = rewrap_exceptions( + { + QueryException: partialclass(HTTPException, status_code=400), + ValidationError: partialclass(HTTPException, status_code=400), + TypeError: partialclass(HTTPException, status_code=400), + } +)( + wrap_in_class( + Transition, + transform=lambda d: { + **d, + "id": d["transition_id"], + "current": {"workflow": d["current"][0], "step": d["current"][1]}, + "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]}, }, + one=True, + _kind="inserted", + )( + cozo_query_async( + increase_counter("create_execution_transition_async")( + _create_execution_transition + ) + ) ) +) diff --git a/agents-api/agents_api/models/execution/get_execution.py b/agents-api/agents_api/models/execution/get_execution.py index db0279b1f..c4e30ab64 100644 --- a/agents-api/agents_api/models/execution/get_execution.py +++ b/agents-api/agents_api/models/execution/get_execution.py @@ -52,7 +52,7 @@ def get_execution( ?[id, task_id, status, input, output, error, session_id, metadata, created_at, updated_at] := input[execution_id], - *executions { + *executions:execution_id_status_idx { task_id, execution_id, status, diff --git a/agents-api/agents_api/models/execution/get_paused_execution_token.py b/agents-api/agents_api/models/execution/get_paused_execution_token.py index 44eb8a4da..526a6012c 100644 --- a/agents-api/agents_api/models/execution/get_paused_execution_token.py +++ b/agents-api/agents_api/models/execution/get_paused_execution_token.py @@ -38,7 +38,7 @@ def get_paused_execution_token( check_status_query = """ ?[execution_id, status] := - *executions { + *executions:execution_id_status_idx { execution_id, status, }, @@ -55,7 +55,7 @@ def get_paused_execution_token( *executions { execution_id, }, - *transitions { + *transitions:execution_id_type_created_at_idx { execution_id, created_at, task_token, diff --git a/agents-api/agents_api/models/execution/list_execution_transitions.py b/agents-api/agents_api/models/execution/list_execution_transitions.py index 8931676f6..d30e8595e 100644 --- a/agents-api/agents_api/models/execution/list_execution_transitions.py +++ b/agents-api/agents_api/models/execution/list_execution_transitions.py @@ -35,7 +35,7 @@ def list_execution_transitions( query = f""" ?[id, execution_id, type, current, next, output, metadata, updated_at, created_at] := - *transitions {{ + *transitions:execution_id_type_created_at_idx {{ execution_id, transition_id: id, type, diff --git a/agents-api/agents_api/models/execution/update_execution.py b/agents-api/agents_api/models/execution/update_execution.py index 35deab259..b485e34f8 100644 --- a/agents-api/agents_api/models/execution/update_execution.py +++ b/agents-api/agents_api/models/execution/update_execution.py @@ -79,9 +79,10 @@ def update_execution( validate_status_query = """ valid_status[count(status)] := - *executions { + *executions:execution_id_status_idx { status, execution_id: to_uuid($execution_id), + task_id: to_uuid($task_id), }, status in $valid_previous_statuses @@ -124,5 +125,6 @@ def update_execution( "values": values, "valid_previous_statuses": valid_previous_statuses, "execution_id": str(execution_id), + "task_id": task_id, }, ) diff --git a/agents-api/agents_api/models/task/get_task.py b/agents-api/agents_api/models/task/get_task.py index 460fdc38b..ab6edb2c3 100644 --- a/agents-api/agents_api/models/task/get_task.py +++ b/agents-api/agents_api/models/task/get_task.py @@ -63,7 +63,7 @@ def get_task( metadata, ] := input[task_id], - *tasks { + *tasks:task_id_agent_id_idx { agent_id, task_id, updated_at_ms, diff --git a/agents-api/agents_api/models/user/get_user.py b/agents-api/agents_api/models/user/get_user.py index 69b3da883..89f49dae1 100644 --- a/agents-api/agents_api/models/user/get_user.py +++ b/agents-api/agents_api/models/user/get_user.py @@ -85,7 +85,7 @@ def get_user( updated_at, metadata, ] := input[developer_id, id], - *users { + *users:developer_id_metadata_user_id_idx { user_id: id, developer_id, name, diff --git a/agents-api/agents_api/models/user/list_users.py b/agents-api/agents_api/models/user/list_users.py index f1e06adf4..cc857b1a1 100644 --- a/agents-api/agents_api/models/user/list_users.py +++ b/agents-api/agents_api/models/user/list_users.py @@ -88,7 +88,7 @@ def list_users( metadata, ] := input[developer_id], - *users {{ + *users:developer_id_metadata_user_id_idx {{ user_id: id, developer_id, name, diff --git a/agents-api/agents_api/models/user/patch_user.py b/agents-api/agents_api/models/user/patch_user.py index e091edc63..bd3fc0246 100644 --- a/agents-api/agents_api/models/user/patch_user.py +++ b/agents-api/agents_api/models/user/patch_user.py @@ -91,7 +91,7 @@ def patch_user( ?[{user_update_cols}, metadata] := input[{user_update_cols}], - *users {{ + *users:developer_id_metadata_user_id_idx {{ developer_id: to_uuid($developer_id), user_id: to_uuid($user_id), metadata: md, diff --git a/agents-api/migrations/migrate_1733755642_transition_indices.py b/agents-api/migrations/migrate_1733755642_transition_indices.py new file mode 100644 index 000000000..1b33f4646 --- /dev/null +++ b/agents-api/migrations/migrate_1733755642_transition_indices.py @@ -0,0 +1,42 @@ +# /usr/bin/env python3 + +MIGRATION_ID = "transition_indices" +CREATED_AT = 1733755642.881131 + + +create_transition_indices = dict( + up=[ + "::index create executions:execution_id_status_idx { execution_id, status }", + "::index create executions:execution_id_task_id_idx { execution_id, task_id }", + "::index create executions:task_id_execution_id_idx { task_id, execution_id }", + "::index create tasks:task_id_agent_id_idx { task_id, agent_id }", + "::index create agents:agent_id_developer_id_idx { agent_id, developer_id }", + "::index create sessions:session_id_developer_id_idx { session_id, developer_id }", + "::index create docs:owner_id_metadata_doc_id_idx { owner_id, metadata, doc_id }", + "::index create agents:developer_id_metadata_agent_id_idx { developer_id, metadata, agent_id }", + "::index create users:developer_id_metadata_user_id_idx { developer_id, metadata, user_id }", + "::index create transitions:execution_id_type_created_at_idx { execution_id, type, created_at }", + ], + down=[ + "::index drop executions:execution_id_status_idx", + "::index drop executions:execution_id_task_id_idx", + "::index drop executions:task_id_execution_id_idx", + "::index drop tasks:task_id_agent_id_idx", + "::index drop agents:agent_id_developer_id_idx", + "::index drop sessions:session_id_developer_id_idx", + "::index drop docs:owner_id_metadata_doc_id_idx", + "::index drop agents:developer_id_metadata_agent_id_idx", + "::index drop users:developer_id_metadata_user_id_idx", + "::index drop transitions:execution_id_type_created_at_idx", + ], +) + + +def up(client): + for q in create_transition_indices["up"]: + client.run(q) + + +def down(client): + for q in create_transition_indices["down"]: + client.run(q)