Skip to content

Commit

Permalink
fix: Apply various small fixes to task execution logic (#436)
Browse files Browse the repository at this point in the history
* fix: Apply various small fixes to task execution logic

* fix: Fix workflows serialization
  • Loading branch information
whiterabbit1983 authored Jul 26, 2024
1 parent c17617f commit 787ba7b
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 158 deletions.
41 changes: 24 additions & 17 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import asyncio

# import celpy
from simpleeval import simple_eval
from openai.types.chat.chat_completion import ChatCompletion
from temporalio import activity
from uuid import uuid4

from ...autogen.openapi_model import (
PromptWorkflowStep,
# EvaluateWorkflowStep,
# YieldWorkflowStep,
ToolCallWorkflowStep,
YieldWorkflowStep,
# ToolCallWorkflowStep,
# ErrorWorkflowStep,
IfElseWorkflowStep,
InputChatMLMessage,
Expand Down Expand Up @@ -79,26 +80,26 @@ async def prompt_step(context: StepContext) -> dict:
# return {"result": result}


# @activity.defn
# async def yield_step(context: StepContext) -> dict:
# if not isinstance(context.definition, YieldWorkflowStep):
# return {}

# # TODO: implement

# return {"test": "result"}


@activity.defn
async def tool_call_step(context: StepContext) -> dict:
if not isinstance(context.definition, ToolCallWorkflowStep):
async def yield_step(context: StepContext) -> dict:
if not isinstance(context.definition, YieldWorkflowStep):
return {}

# TODO: implement

return {"test": "result"}


# @activity.defn
# async def tool_call_step(context: StepContext) -> dict:
# assert isinstance(context.definition, ToolCallWorkflowStep)

# context.definition.tool_id
# context.definition.arguments
# # get tool by id
# # call tool


# @activity.defn
# async def error_step(context: StepContext) -> dict:
# if not isinstance(context.definition, ErrorWorkflowStep):
Expand All @@ -109,10 +110,16 @@ async def tool_call_step(context: StepContext) -> dict:

@activity.defn
async def if_else_step(context: StepContext) -> dict:
if not isinstance(context.definition, IfElseWorkflowStep):
return {}
assert isinstance(context.definition, IfElseWorkflowStep)

return {"test": "result"}
context_data: dict = context.model_dump()
next_workflow = (
context.definition.then
if simple_eval(context.definition.if_, names=context_data)
else context.definition.else_
)

return {"goto_workflow": next_workflow}


@activity.defn
Expand Down
112 changes: 57 additions & 55 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-07-10T09:10:51+00:00
# timestamp: 2024-07-17T06:45:55+00:00

from __future__ import annotations

Expand Down Expand Up @@ -858,13 +858,25 @@ class ChatMLImageContentPart(BaseModel):
"""


class ToolResponse(BaseModel):
id: UUID
"""
Optional Tool ID
"""
output: Dict[str, Any]


class CELObject(BaseModel):
model_config = ConfigDict(
extra="allow",
)
workflow: str
arguments: CELObject
arguments: Dict[str, Any]


class YieldWorkflowStep(CELObject):
pass
class YieldWorkflowStep(BaseModel):
workflow: str
arguments: Dict[str, Any]


class ToolCallWorkflowStep(BaseModel):
Expand All @@ -885,61 +897,59 @@ class IfElseWorkflowStep(BaseModel):
else_: Annotated[YieldWorkflowStep, Field(alias="else")]


class CreateExecution(BaseModel):
task_id: UUID
arguments: Dict[str, Any]
"""
JSON Schema of parameters
class TransitionType(str, Enum):
"""


class ToolResponse(BaseModel):
id: UUID
"""
Optional Tool ID
"""
output: Dict[str, Any]


class Type3(str, Enum):
"""
Transition type
Execution Status
"""

finish = "finish"
wait = "wait"
error = "error"
step = "step"
cancelled = "cancelled"


class UpdateExecutionTransitionRequest(BaseModel):
class ExecutionStatus(str, Enum):
"""
Update execution transition request schema
Execution Status
"""

type: Type3
"""
Transition type
"""
from_: Annotated[List[str | int], Field(alias="from", max_length=2, min_length=2)]
queued = "queued"
starting = "starting"
running = "running"
awaiting_input = "awaiting_input"
succeeded = "succeeded"
failed = "failed"
cancelled = "cancelled"


class CreateExecution(BaseModel):
task_id: UUID
arguments: Dict[str, Any]
"""
From state
JSON Schema of parameters
"""
to: Annotated[List[str | int] | None, Field(None, max_length=2, min_length=2)]


class StopExecution(BaseModel):
status: Literal["cancelled"] = "cancelled"
"""
To state
Stop Execution Status
"""
output: Dict[str, Any]


class ResumeExecutionTransitionRequest(BaseModel):
"""
Execution output
Update execution transition request schema
"""
task_token: str | None = None

task_token: str
"""
Task token
"""
metadata: Dict[str, Any] | None = None
output: Dict[str, Any]
"""
Custom metadata
Output of the execution
"""


Expand Down Expand Up @@ -1175,34 +1185,26 @@ class PatchToolRequest(BaseModel):
class Execution(BaseModel):
id: UUID
task_id: UUID
created_at: UUID
created_at: AwareDatetime
arguments: Dict[str, Any]
"""
JSON Schema of parameters
"""
status: Annotated[
str,
Field(pattern="^(queued|starting|running|awaiting_input|succeeded|failed)$"),
]
"""
Execution Status
"""
status: ExecutionStatus


class ExecutionTransition(BaseModel):
id: UUID
execution_id: UUID
created_at: AwareDatetime
updated_at: AwareDatetime
outputs: Dict[str, Any]
"""
Outputs from an Execution Transition
"""
from_: Annotated[List[str | int], Field(alias="from")]
to: List[str | int]
type: Annotated[str, Field(pattern="^(finish|wait|error|step)$")]
"""
Execution Status
"""
current: List[str | int]
next: List[str | int]
type: TransitionType


class PromptWorkflowStep(BaseModel):
Expand Down Expand Up @@ -1259,6 +1261,9 @@ class Task(BaseModel):
Describes a Task
"""

model_config = ConfigDict(
extra="allow",
)
name: str
"""
Name of the Task
Expand Down Expand Up @@ -1291,8 +1296,5 @@ class Task(BaseModel):
ID of the Task
"""
created_at: AwareDatetime
updated_at: AwareDatetime
agent_id: UUID


CELObject.model_rebuild()
YieldWorkflowStep.model_rebuild()
2 changes: 2 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
temporal_private_key,
)
from ..common.protocol.tasks import ExecutionInput
from ..worker.codec import pydantic_data_converter


async def get_client():
Expand All @@ -22,6 +23,7 @@ async def get_client():
temporal_worker_url,
namespace=temporal_namespace,
tls=tls_config,
data_converter=pydantic_data_converter,
)


Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/common/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from uuid import UUID
from typing import Any
from pydantic import BaseModel


class CustomJSONEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -37,6 +38,9 @@ def default(self, obj):
if isinstance(obj, UUID):
return str(obj)

if isinstance(obj, BaseModel):
return obj.model_dump()

return obj


Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/execution/create_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ def create_execution_query(
execution_id: UUID,
session_id: UUID | None = None,
status: Literal[
"pending",
"queued",
"starting",
"running",
"awaiting_input",
"succeeded",
"failed",
] = "pending",
"cancelled",
] = "queued",
arguments: Dict[str, Any] = {},
) -> tuple[str, dict]:
# TODO: Check for agent in developer ID; Assert whether dev can access agent and by relation the task
Expand Down
18 changes: 12 additions & 6 deletions agents-api/agents_api/routers/agents/create_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Annotated
from uuid import uuid4

from fastapi import Depends
from pydantic import UUID4
Expand All @@ -7,7 +8,6 @@
from ...dependencies.developer_id import get_developer_id
from ...models.agent.create_agent import create_agent_query
from ...autogen.openapi_model import CreateAgentRequest, ResourceCreatedResponse
from ...common.utils.datetime import utcnow

from .router import router

Expand All @@ -17,13 +17,19 @@ async def create_agent(
request: CreateAgentRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
agent_id = create_agent_query(
new_agent_id = uuid4()

resp = create_agent_query(
developer_id=x_developer_id,
agent_id=new_agent_id,
name=request.name,
about=request.about,
instructions=request.instructions,
instructions=request.instructions or [],
model=request.model,
default_settings=request.default_settings,
metadata=request.metadata,
default_settings=request.default_settings or {},
metadata=request.metadata or {},
)
return ResourceCreatedResponse(id=agent_id, created_at=utcnow())

resp.iterrows()

return ResourceCreatedResponse(id=new_agent_id, created_at=resp["created_at"])
Loading

0 comments on commit 787ba7b

Please sign in to comment.