Skip to content

Commit

Permalink
feat: Update execution transition
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Jul 2, 2024
1 parent d76768e commit 0a20b95
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 176 deletions.
96 changes: 63 additions & 33 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-06-27T12:10:13+00:00
# timestamp: 2024-07-02T07:16:16+00:00

from __future__ import annotations

Expand Down Expand Up @@ -413,8 +413,8 @@ class ResponseFormat(BaseModel):
"""


class Stop(RootModel[List[str]]):
root: Annotated[List[str], Field(max_length=4, min_length=1)]
class Stop(RootModel[List[Any]]):
root: Annotated[List[Any], Field(max_length=4, min_length=1)]
"""
Up to 4 sequences where the API will stop generating further tokens.
Expand Down Expand Up @@ -859,17 +859,14 @@ class ChatMLImageContentPart(BaseModel):


class CELObject(BaseModel):
pass
model_config = ConfigDict(
extra="allow",
)


class YieldWorkflowStep(BaseModel):
workflow: str
arguments: CELObject


class YieldWorkflowStep(CELObject):
pass


class ToolCallWorkflowStep(BaseModel):
tool_id: str
arguments: CELObject
Expand Down Expand Up @@ -911,6 +908,48 @@ class ToolResponse(BaseModel):
output: Dict[str, Any]


class Type3(str, Enum):
"""
Transition type
"""

finish = "finish"
wait = "wait"
error = "error"
step = "step"


class UpdateExecutionTransitionRequest(BaseModel):
"""
Update execution transition request schema
"""

type: Type3
"""
Transition type
"""
from_: Annotated[List[str | int], Field(alias="from", max_length=2, min_length=2)]
"""
From state
"""
to: Annotated[List[str | int] | None, Field(None, max_length=2, min_length=2)]
"""
To state
"""
output: Dict[str, Any]
"""
Execution output
"""
task_token: str | None = None
"""
Task token
"""
metadata: Dict[str, Any] | None = None
"""
Custom metadata
"""


class Agent(BaseModel):
name: str
"""
Expand Down Expand Up @@ -1143,43 +1182,34 @@ class PatchToolRequest(BaseModel):
class Execution(BaseModel):
id: UUID
task_id: UUID
created_at: UUID
arguments: Dict[str, Any]
"""
JSON Schema of parameters
"""
status: Annotated[
str,
Field(pattern="^(queued|starting|running|awaiting_input|succeeded|failed)$"),
]
"""
Execution Status
"""
arguments: Dict[str, Any]
"""
JSON of parameters
"""
user_id: UUID | None = None
session_id: UUID | None = None
created_at: AwareDatetime
updated_at: AwareDatetime


class ExecutionTransition(BaseModel):
id: UUID
execution_id: UUID
type: Annotated[str, Field(pattern="^(finish|wait|error|step)$")]
"""
Execution Status
"""
from_: Annotated[List[str | int], Field(alias="from")]
to: List[str | int]
task_token: str | None = None
created_at: AwareDatetime
outputs: Dict[str, Any]
"""
Outputs from an Execution Transition
"""
metadata: Dict[str, Any] | None = None
from_: Annotated[List[str | int], Field(alias="from")]
to: List[str | int]
type: Annotated[str, Field(pattern="^(finish|wait|error|step)$")]
"""
(Optional) metadata
Execution Status
"""
created_at: AwareDatetime
updated_at: AwareDatetime | None = None


class PromptWorkflowStep(BaseModel):
Expand Down Expand Up @@ -1236,9 +1266,6 @@ class Task(BaseModel):
Describes a Task
"""

model_config = ConfigDict(
extra="allow",
)
name: str
"""
Name of the Task
Expand Down Expand Up @@ -1271,5 +1298,8 @@ class Task(BaseModel):
ID of the Task
"""
created_at: AwareDatetime
updated_at: AwareDatetime | None = None
agent_id: UUID


CELObject.model_rebuild()
YieldWorkflowStep.model_rebuild()
10 changes: 8 additions & 2 deletions agents-api/agents_api/models/execution/create_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ def create_execution_query(
execution_id: UUID,
session_id: UUID | None = None,
status: Literal[
"queued", "starting", "running", "awaiting_input", "succeeded", "failed"
] = "queued",
"pending",
"queued",
"starting",
"running",
"awaiting_input",
"succeeded",
"failed",
] = "pending",
arguments: Dict[str, Any] = {},
) -> tuple[str, dict]:
# TODO: Check for agent in developer ID; Assert whether dev can access agent and by relation the task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,31 @@

@cozo_query
@beartype
def list_execution_transitions_query(execution_id: UUID) -> tuple[str, dict]:

def list_execution_transitions_query(
execution_id: UUID, limit: int = 100, offset: int = 0
) -> tuple[str, dict]:
query = """
{
?[transition_id, type, from, to, output, updated_at, created_at] := *transitions {
execution_id: to_uuid($execution_id),
transition_id,
type,
from,
to,
output,
updated_at,
created_at,
{
?[transition_id, type, from, to, output, updated_at, created_at] := *transitions {
execution_id: to_uuid($execution_id),
transition_id,
type,
from,
to,
output,
updated_at,
created_at,
}
:limit $limit
:offset $offset
}
:limit 100
:offset 0
}
"""
"""

return (
query,
{
"execution_id": str(execution_id),
"limit": limit,
"offset": offset,
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,21 @@ def update_execution_transition_query(
}
)
query = f"""
{{
input[{transition_update_cols}] <- $transition_update_vals
?[{transition_update_cols}] := input[{transition_update_cols}],
*transitions {{
execution_id: to_uuid($execution_id),
transition_id: to_uuid($transition_id),
}}
:update transitions {{
{transition_update_cols}
{{
input[{transition_update_cols}] <- $transition_update_vals
?[{transition_update_cols}] := input[{transition_update_cols}],
*transitions {{
execution_id: to_uuid($execution_id),
transition_id: to_uuid($transition_id),
}}
:update transitions {{
{transition_update_cols}
}}
:returning
}}
:returning
}}
"""
"""

return (
query,
Expand Down
3 changes: 2 additions & 1 deletion agents-api/agents_api/models/task/get_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def get_task_query(
@ 'NOW'
},
updated_at = to_int(updated_at_ms) / 1000,
id = to_uuid($task_id),
id = to_uuid($task_id)
:limit 1
"""

Expand Down
43 changes: 29 additions & 14 deletions agents-api/agents_api/routers/tasks/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from agents_api.models.execution.list_execution_transitions import (
list_execution_transitions_query,
)
from agents_api.models.execution.update_execution_transition import (
update_execution_transition_query,
)
from agents_api.models.task.create_task import create_task_query
from agents_api.models.task.get_task import get_task_query
from agents_api.models.task.list_tasks import list_tasks_query
Expand All @@ -27,6 +30,7 @@
ExecutionTransition,
ResourceCreatedResponse,
ResourceUpdatedResponse,
UpdateExecutionTransitionRequest,
)
from agents_api.dependencies.developer_id import get_developer_id

Expand Down Expand Up @@ -150,17 +154,14 @@ async def create_task_execution(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
# TODO: Do thorough validation of the input against task input schema
# DO NOT let the user specify the status
# status should be set as pending

resp = create_execution_query(
agent_id=agent_id,
task_id=task_id,
execution_id=uuid4(),
developer_id=x_developer_id,
status=request.status,
arguments=request.arguments,
)

return ResourceCreatedResponse(
id=resp["execution_id"][0], created_at=resp["created_at"][0]
)
Expand Down Expand Up @@ -193,7 +194,6 @@ async def get_execution(task_id: UUID4, execution_id: UUID4) -> Execution:
)



@router.get("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"])
async def get_execution_transition(
execution_id: UUID4,
Expand All @@ -218,20 +218,35 @@ async def get_execution_transition(
# TODO: Ask for a task token to resume a waiting transition
@router.put("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"])
async def update_execution_transition(
execution_id: UUID4, transition_id: UUID4, request
execution_id: UUID4,
transition_id: UUID4,
request: UpdateExecutionTransitionRequest,
) -> ResourceUpdatedResponse:
# try:
# resp = update_execution_transition_query(execution_id, transition_id, request)

# OpenAPI Model doesn't have update execution transition
try:
resp = update_execution_transition_query(
execution_id, transition_id, **request.model_dump()
)

raise NotImplementedError("Not implemented yet")
return ResourceUpdatedResponse(
id=resp["transition_id"][0],
updated_at=resp["updated_at"][0][0],
)
except (IndexError, KeyError):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Transition not found",
)


@router.get("/executions/{execution_id}/transitions", tags=["tasks"])
async def list_execution_transitions(execution_id: UUID4) -> ExecutionTransitionList:
# lists out the execution transitions
res = list_execution_transitions_query(execution_id)
async def list_execution_transitions(
execution_id: UUID4,
limit: int = 100,
offset: int = 0,
) -> ExecutionTransitionList:
res = list_execution_transitions_query(
execution_id=execution_id, limit=limit, offset=offset
)
return ExecutionTransitionList(
items=[ExecutionTransition(**row.to_dict()) for _, row in res.iterrows()]
)
Loading

0 comments on commit 0a20b95

Please sign in to comment.