Skip to content

Commit

Permalink
Merge pull request #942 from julep-ai/f/cozo-indices
Browse files Browse the repository at this point in the history
feat: Add indices migration
  • Loading branch information
whiterabbit1983 authored Dec 10, 2024
2 parents 10b2d86 + 2f88b8c commit 875943c
Show file tree
Hide file tree
Showing 13 changed files with 129 additions and 218 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def search_docs_by_embedding(
snippet_counter[count(item)] :=
owners[owner_type, owner_id_str],
owner_id = to_uuid(owner_id_str),
*docs {{
*docs:owner_id_metadata_doc_id_idx {{
owner_type,
owner_id,
doc_id: item,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/docs/search_docs_by_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def search_docs_by_text(
candidate[doc_id] :=
input[owner_type, owner_id],
*docs {{
*docs:owner_id_metadata_doc_id_idx {{
owner_type,
owner_id,
doc_id,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/execution/count_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def count_executions(
counter[count(id)] :=
input[task_id],
*executions {
*executions:task_id_execution_id_idx {
task_id,
execution_id: id,
}
Expand Down
279 changes: 73 additions & 206 deletions agents-api/agents_api/models/execution/create_execution_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,62 +25,8 @@
from .update_execution import update_execution


def validate_transition_targets(data: CreateTransitionRequest) -> None:
# Make sure the current/next targets are valid
match data.type:
case "finish_branch":
pass # TODO: Implement
case "finish" | "error" | "cancelled":
pass

### FIXME: HACK: Fix this and uncomment

### assert (
### data.next is None
### ), "Next target must be None for finish/finish_branch/error/cancelled"

case "init_branch" | "init":
assert (
data.next and data.current.step == data.next.step == 0
), "Next target must be same as current for init_branch/init and step 0"

case "wait":
assert data.next is None, "Next target must be None for wait"

case "resume" | "step":
assert data.next is not None, "Next target must be provided for resume/step"

if data.next.workflow == data.current.workflow:
assert (
data.next.step > data.current.step
), "Next step must be greater than current"

case _:
raise ValueError(f"Invalid transition type: {data.type}")


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
Transition,
transform=lambda d: {
**d,
"id": d["transition_id"],
"current": {"workflow": d["current"][0], "step": d["current"][1]},
"next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
},
one=True,
_kind="inserted",
)
@cozo_query
@increase_counter("create_execution_transition")
@beartype
def create_execution_transition(
def _create_execution_transition(
*,
developer_id: UUID,
execution_id: UUID,
Expand Down Expand Up @@ -140,7 +86,7 @@ def create_execution_transition(
]
last_transition_type[min_cost(type_created_at)] :=
*transitions {{
*transitions:execution_id_type_created_at_idx {{
execution_id: to_uuid("{str(execution_id)}"),
type,
created_at,
Expand Down Expand Up @@ -225,167 +171,88 @@ def create_execution_transition(
)


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
Transition,
transform=lambda d: {
**d,
"id": d["transition_id"],
"current": {"workflow": d["current"][0], "step": d["current"][1]},
"next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
},
one=True,
_kind="inserted",
)
@cozo_query_async
@increase_counter("create_execution_transition_async")
@beartype
async def create_execution_transition_async(
*,
developer_id: UUID,
execution_id: UUID,
data: CreateTransitionRequest,
# Only one of these needed
transition_id: UUID | None = None,
task_token: str | None = None,
# Only required for updating the execution status as well
update_execution_status: bool = False,
task_id: UUID | None = None,
) -> tuple[list[str | None], dict]:
transition_id = transition_id or uuid4()
data.metadata = data.metadata or {}
data.execution_id = execution_id

# Dump to json
if isinstance(data.output, list):
data.output = [
item.model_dump(mode="json") if hasattr(item, "model_dump") else item
for item in data.output
]

elif hasattr(data.output, "model_dump"):
data.output = data.output.model_dump(mode="json")

# TODO: This is a hack to make sure the transition is valid
# (parallel transitions are whack, we should do something better)
is_parallel = data.current.workflow.startswith("PAR:")

# Prepare the transition data
transition_data = data.model_dump(exclude_unset=True, exclude={"id"})

# Parse the current and next targets
validate_transition_targets(data)
current_target = transition_data.pop("current")
next_target = transition_data.pop("next")

transition_data["current"] = (current_target["workflow"], current_target["step"])
transition_data["next"] = next_target and (
next_target["workflow"],
next_target["step"],
)

columns, transition_values = cozo_process_mutate_data(
{
**transition_data,
"task_token": str(task_token), # Converting to str for JSON serialisation
"transition_id": str(transition_id),
"execution_id": str(execution_id),
}
)

# Make sure the transition is valid
check_last_transition_query = f"""
valid_transition[start, end] <- [
{", ".join(f'["{start}", "{end}"]' for start, ends in valid_transitions.items() for end in ends)}
]
def validate_transition_targets(data: CreateTransitionRequest) -> None:
# Make sure the current/next targets are valid
match data.type:
case "finish_branch":
pass # TODO: Implement
case "finish" | "error" | "cancelled":
pass

last_transition_type[min_cost(type_created_at)] :=
*transitions {{
execution_id: to_uuid("{str(execution_id)}"),
type,
created_at,
}},
type_created_at = [type, -created_at]
### FIXME: HACK: Fix this and uncomment

matched[collect(last_type)] :=
last_transition_type[data],
last_type_data = first(data),
last_type = if(is_null(last_type_data), "init", last_type_data),
valid_transition[last_type, $next_type]
### assert (
### data.next is None
### ), "Next target must be None for finish/finish_branch/error/cancelled"

?[valid] :=
matched[prev_transitions],
found = length(prev_transitions),
valid = if($next_type == "init", found == 0, found > 0),
assert(valid, "Invalid transition"),
case "init_branch" | "init":
assert (
data.next and data.current.step == data.next.step == 0
), "Next target must be same as current for init_branch/init and step 0"

:limit 1
"""
case "wait":
assert data.next is None, "Next target must be None for wait"

# Prepare the insert query
insert_query = f"""
?[{columns}] <- $transition_values
case "resume" | "step":
assert data.next is not None, "Next target must be provided for resume/step"

:insert transitions {{
{columns}
}}
:returning
"""
if data.next.workflow == data.current.workflow:
assert (
data.next.step > data.current.step
), "Next step must be greater than current"

validate_status_query, update_execution_query, update_execution_params = (
"",
"",
{},
)
case _:
raise ValueError(f"Invalid transition type: {data.type}")

if update_execution_status:
assert (
task_id is not None
), "task_id is required for updating the execution status"

# Prepare the execution update query
[*_, validate_status_query, update_execution_query], update_execution_params = (
update_execution.__wrapped__(
developer_id=developer_id,
task_id=task_id,
execution_id=execution_id,
data=UpdateExecutionRequest(
status=transition_to_execution_status[data.type]
),
output=data.output if data.type != "error" else None,
error=str(data.output)
if data.type == "error" and data.output
else None,
# Public synchronous entry point, built by wrapping the shared
# `_create_execution_transition` implementation by hand instead of with
# stacked `@decorator` syntax. Wrappers are applied innermost first:
#   increase_counter -> cozo_query -> wrap_in_class -> rewrap_exceptions
create_execution_transition = rewrap_exceptions(
    # QueryException, ValidationError and TypeError are each mapped to an
    # HTTPException constructed with status_code=400.
    {
        QueryException: partialclass(HTTPException, status_code=400),
        ValidationError: partialclass(HTTPException, status_code=400),
        TypeError: partialclass(HTTPException, status_code=400),
    }
)(
    wrap_in_class(
        Transition,
        # Reshape a result row `d` before it is handed to `Transition`:
        # expose `transition_id` as `id` and turn the `current` / `next`
        # pairs into {"workflow", "step"} dicts. A falsy `next` is passed
        # through unchanged by the `and`.
        transform=lambda d: {
            **d,
            "id": d["transition_id"],
            "current": {"workflow": d["current"][0], "step": d["current"][1]},
            "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
        },
        one=True,
        _kind="inserted",
    )(
        cozo_query(
            increase_counter("create_execution_transition")(
                _create_execution_transition
            )
        )
    )
)

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"executions",
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
validate_status_query if not is_parallel else None,
update_execution_query if not is_parallel else None,
check_last_transition_query if not is_parallel else None,
insert_query,
]

return (
queries,
{
"transition_values": transition_values,
"next_type": data.type,
"valid_transitions": valid_transitions,
**update_execution_params,
# Public asynchronous entry point. It wraps the same
# `_create_execution_transition` implementation as the synchronous variant;
# the differences are `cozo_query_async` in place of `cozo_query` and a
# separate counter name. Wrappers are applied innermost first:
#   increase_counter -> cozo_query_async -> wrap_in_class -> rewrap_exceptions
create_execution_transition_async = rewrap_exceptions(
    # QueryException, ValidationError and TypeError are each mapped to an
    # HTTPException constructed with status_code=400.
    {
        QueryException: partialclass(HTTPException, status_code=400),
        ValidationError: partialclass(HTTPException, status_code=400),
        TypeError: partialclass(HTTPException, status_code=400),
    }
)(
    wrap_in_class(
        Transition,
        # Reshape a result row `d` before it is handed to `Transition`:
        # expose `transition_id` as `id` and turn the `current` / `next`
        # pairs into {"workflow", "step"} dicts. A falsy `next` is passed
        # through unchanged by the `and`.
        transform=lambda d: {
            **d,
            "id": d["transition_id"],
            "current": {"workflow": d["current"][0], "step": d["current"][1]},
            "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
        },
        one=True,
        _kind="inserted",
    )(
        cozo_query_async(
            increase_counter("create_execution_transition_async")(
                _create_execution_transition
            )
        )
    )
)
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/execution/get_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_execution(
?[id, task_id, status, input, output, error, session_id, metadata, created_at, updated_at] :=
input[execution_id],
*executions {
*executions:execution_id_status_idx {
task_id,
execution_id,
status,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_paused_execution_token(

check_status_query = """
?[execution_id, status] :=
*executions {
*executions:execution_id_status_idx {
execution_id,
status,
},
Expand All @@ -55,7 +55,7 @@ def get_paused_execution_token(
*executions {
execution_id,
},
*transitions {
*transitions:execution_id_type_created_at_idx {
execution_id,
created_at,
task_token,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def list_execution_transitions(

query = f"""
?[id, execution_id, type, current, next, output, metadata, updated_at, created_at] :=
*transitions {{
*transitions:execution_id_type_created_at_idx {{
execution_id,
transition_id: id,
type,
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/models/execution/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ def update_execution(

validate_status_query = """
valid_status[count(status)] :=
*executions {
*executions:execution_id_status_idx {
status,
execution_id: to_uuid($execution_id),
task_id: to_uuid($task_id),
},
status in $valid_previous_statuses
Expand Down Expand Up @@ -124,5 +125,6 @@ def update_execution(
"values": values,
"valid_previous_statuses": valid_previous_statuses,
"execution_id": str(execution_id),
"task_id": task_id,
},
)
Loading

0 comments on commit 875943c

Please sign in to comment.