From 277c168f27b2c4c07b6cac508630383f8aaac59a Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Tue, 23 Jul 2024 20:10:37 +0300 Subject: [PATCH] fix: Apply various small fixes to task execution logic --- .../activities/task_steps/__init__.py | 41 ++++--- .../agents_api/autogen/openapi_model.py | 112 +++++++++--------- agents-api/agents_api/clients/temporal.py | 2 + .../models/execution/create_execution.py | 4 +- .../agents_api/routers/agents/create_agent.py | 17 ++- .../agents_api/routers/tasks/routers.py | 12 +- agents-api/agents_api/worker/__main__.py | 4 +- .../agents_api/workflows/task_execution.py | 51 ++++---- agents-api/poetry.lock | 62 ++-------- agents-api/pyproject.toml | 2 +- openapi.yaml | 9 +- 11 files changed, 160 insertions(+), 156 deletions(-) diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index e376c4782..78a889e76 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -1,6 +1,7 @@ import asyncio # import celpy +from simpleeval import simple_eval from openai.types.chat.chat_completion import ChatCompletion from temporalio import activity from uuid import uuid4 @@ -8,10 +9,10 @@ from ...autogen.openapi_model import ( PromptWorkflowStep, # EvaluateWorkflowStep, - # YieldWorkflowStep, + YieldWorkflowStep, # ToolCallWorkflowStep, # ErrorWorkflowStep, - # IfElseWorkflowStep, + IfElseWorkflowStep, InputChatMLMessage, ) @@ -79,20 +80,24 @@ async def prompt_step(context: StepContext) -> dict: # return {"result": result} -# @activity.defn -# async def yield_step(context: StepContext) -> dict: -# if not isinstance(context.definition, YieldWorkflowStep): -# return {} +@activity.defn +async def yield_step(context: StepContext) -> dict: + if not isinstance(context.definition, YieldWorkflowStep): + return {} -# # TODO: implement + # TODO: implement -# return {"test": "result"} + return {"test": "result"} # @activity.defn # async def tool_call_step(context: StepContext) -> dict: -# if not isinstance(context.definition, ToolCallWorkflowStep): -# return {} +# assert isinstance(context.definition, ToolCallWorkflowStep) + +# context.definition.tool_id +# context.definition.arguments +# # get tool by id +# # call tool # # TODO: implement @@ -107,12 +112,18 @@ async def prompt_step(context: StepContext) -> dict: # return {"error": context.definition.error} -# @activity.defn -# async def if_else_step(context: StepContext) -> dict: -# if not isinstance(context.definition, IfElseWorkflowStep): -# return {} +@activity.defn +async def if_else_step(context: StepContext) -> dict: + assert isinstance(context.definition, IfElseWorkflowStep) -# return {"test": "result"} + context_data: dict = context.model_dump() + next_workflow = ( + context.definition.then + if simple_eval(context.definition.if_, names=context_data) + else context.definition.else_ + ) + + return {"workflow": next_workflow} @activity.defn diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index 142f9886a..77d50710f 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2024-07-10T09:10:51+00:00 +# timestamp: 2024-07-17T06:45:55+00:00 from __future__ import annotations @@ -858,13 +858,25 @@ class ChatMLImageContentPart(BaseModel): """ +class ToolResponse(BaseModel): + id: UUID + """ + Optional Tool ID + """ + output: Dict[str, Any] + + class CELObject(BaseModel): + model_config = ConfigDict( + extra="allow", + ) workflow: str - arguments: CELObject + arguments: Dict[str, Any] -class YieldWorkflowStep(CELObject): - pass +class YieldWorkflowStep(BaseModel): + workflow: str + arguments: Dict[str, Any] class ToolCallWorkflowStep(BaseModel): @@ -885,61 +897,59 @@ class IfElseWorkflowStep(BaseModel): else_: Annotated[YieldWorkflowStep, Field(alias="else")] -class CreateExecution(BaseModel): - task_id: UUID - arguments: Dict[str, Any] - """ - JSON Schema of parameters +class TransitionType(str, Enum): """ - - -class ToolResponse(BaseModel): - id: UUID - """ - Optional Tool ID - """ - output: Dict[str, Any] - - -class Type3(str, Enum): - """ - Transition type + Execution Status """ finish = "finish" wait = "wait" error = "error" step = "step" + cancelled = "cancelled" -class UpdateExecutionTransitionRequest(BaseModel): +class ExecutionStatus(str, Enum): """ - Update execution transition request schema + Execution Status """ - type: Type3 - """ - Transition type - """ - from_: Annotated[List[str | int], Field(alias="from", max_length=2, min_length=2)] + queued = "queued" + starting = "starting" + running = "running" + awaiting_input = "awaiting_input" + succeeded = "succeeded" + failed = "failed" + cancelled = "cancelled" + + +class CreateExecution(BaseModel): + task_id: UUID + arguments: Dict[str, Any] """ - From state + JSON Schema of parameters """ - to: Annotated[List[str | int] | None, Field(None, max_length=2, min_length=2)] + + +class StopExecution(BaseModel): + status: Literal["cancelled"] = "cancelled" """ - To state + Stop Execution Status """ - output: Dict[str, Any] + + +class ResumeExecutionTransitionRequest(BaseModel): """ - Execution output + Update execution transition request schema """ - task_token: str | None = None + + task_token: str """ Task token """ - metadata: Dict[str, Any] | None = None + output: Dict[str, Any] """ - Custom metadata + Output of the execution """ @@ -1175,34 +1185,26 @@ class PatchToolRequest(BaseModel): class Execution(BaseModel): id: UUID task_id: UUID - created_at: UUID + created_at: AwareDatetime arguments: Dict[str, Any] """ JSON Schema of parameters """ - status: Annotated[ - str, - Field(pattern="^(queued|starting|running|awaiting_input|succeeded|failed)$"), - ] - """ - Execution Status - """ + status: ExecutionStatus class ExecutionTransition(BaseModel): id: UUID execution_id: UUID created_at: AwareDatetime + updated_at: AwareDatetime outputs: Dict[str, Any] """ Outputs from an Execution Transition """ - from_: Annotated[List[str | int], Field(alias="from")] - to: List[str | int] - type: Annotated[str, Field(pattern="^(finish|wait|error|step)$")] - """ - Execution Status - """ + current: List[str | int] + next: List[str | int] + type: TransitionType class PromptWorkflowStep(BaseModel): @@ -1259,6 +1261,9 @@ class Task(BaseModel): Describes a Task """ + model_config = ConfigDict( + extra="allow", + ) name: str """ Name of the Task @@ -1291,8 +1296,5 @@ class Task(BaseModel): ID of the Task """ created_at: AwareDatetime + updated_at: AwareDatetime agent_id: UUID - - -CELObject.model_rebuild() -YieldWorkflowStep.model_rebuild() diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index 9ccac2938..8c8dfa7d5 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -7,6 +7,7 @@ temporal_private_key, ) from ..common.protocol.tasks import ExecutionInput +from ..worker.codec import pydantic_data_converter async def get_client(): @@ -22,6 +23,7 @@ async def get_client(): temporal_worker_url, namespace=temporal_namespace, tls=tls_config, + data_converter=pydantic_data_converter, ) diff --git a/agents-api/agents_api/models/execution/create_execution.py b/agents-api/agents_api/models/execution/create_execution.py index ce84ea2b2..c759990d8 100644 --- a/agents-api/agents_api/models/execution/create_execution.py +++ b/agents-api/agents_api/models/execution/create_execution.py @@ -15,14 +15,14 @@ def create_execution_query( execution_id: UUID, session_id: UUID | None = None, status: Literal[ - "pending", "queued", "starting", "running", "awaiting_input", "succeeded", "failed", - ] = "pending", + "cancelled", + ] = "queued", arguments: Dict[str, Any] = {}, ) -> tuple[str, dict]: # TODO: Check for agent in developer ID; Assert whether dev can access agent and by relation the task diff --git a/agents-api/agents_api/routers/agents/create_agent.py b/agents-api/agents_api/routers/agents/create_agent.py index 991a6a437..f3ea9d1ab 100644 --- a/agents-api/agents_api/routers/agents/create_agent.py +++ b/agents-api/agents_api/routers/agents/create_agent.py @@ -1,4 +1,5 @@ from typing import Annotated +from uuid import uuid4 from fastapi import Depends from pydantic import UUID4 @@ -17,13 +18,19 @@ async def create_agent( request: CreateAgentRequest, x_developer_id: Annotated[UUID4, Depends(get_developer_id)], ) -> ResourceCreatedResponse: - agent_id = create_agent_query( + new_agent_id = uuid4() + + resp = create_agent_query( developer_id=x_developer_id, + agent_id=new_agent_id, name=request.name, about=request.about, - instructions=request.instructions, + instructions=request.instructions or [], model=request.model, - default_settings=request.default_settings, - metadata=request.metadata, + default_settings=request.default_settings or {}, + metadata=request.metadata or {}, ) - return ResourceCreatedResponse(id=agent_id, created_at=utcnow()) + + resp.iterrows() + + return ResourceCreatedResponse(id=new_agent_id, created_at=resp["created_at"]) diff --git a/agents-api/agents_api/routers/tasks/routers.py b/agents-api/agents_api/routers/tasks/routers.py index 1efb68cd0..e26c71054 100644 --- a/agents-api/agents_api/routers/tasks/routers.py +++ b/agents-api/agents_api/routers/tasks/routers.py @@ -1,3 +1,4 @@ +import logging from typing import Annotated from uuid import uuid4 from jsonschema import validate @@ -34,7 +35,6 @@ ExecutionTransition, ResourceCreatedResponse, ResourceUpdatedResponse, - UpdateExecutionTransitionRequest, CreateExecution, ) from agents_api.dependencies.developer_id import get_developer_id @@ -43,6 +43,10 @@ from agents_api.clients.cozo import client as cozo_client +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + class TaskList(BaseModel): items: list[Task] @@ -206,7 +210,9 @@ async def create_task_execution( execution_input=execution_input, job_id=uuid4(), ) - except Exception: + except Exception as e: + logger.exception(e) + update_execution_status_query( task_id=task_id, execution_id=execution_id, @@ -276,7 +282,7 @@ async def get_execution_transition( async def update_execution_transition( execution_id: UUID4, transition_id: UUID4, - request: UpdateExecutionTransitionRequest, + request: ExecutionTransition, ) -> ResourceUpdatedResponse: try: resp = update_execution_transition_query( diff --git a/agents-api/agents_api/worker/__main__.py b/agents-api/agents_api/worker/__main__.py index cb57490bb..f4daea8e4 100644 --- a/agents-api/agents_api/worker/__main__.py +++ b/agents-api/agents_api/worker/__main__.py @@ -21,7 +21,7 @@ from ..activities.task_steps import ( prompt_step, # evaluate_step, - # yield_step, + yield_step, # tool_call_step, # error_step, # if_else_step, @@ -77,7 +77,7 @@ async def main(): task_activities = [ prompt_step, # evaluate_step, - # yield_step, + yield_step, # tool_call_step, # error_step, # if_else_step, diff --git a/agents-api/agents_api/workflows/task_execution.py b/agents-api/agents_api/workflows/task_execution.py index 5daa53d95..44e4b44cc 100644 --- a/agents-api/agents_api/workflows/task_execution.py +++ b/agents-api/agents_api/workflows/task_execution.py @@ -9,16 +9,17 @@ from ..activities.task_steps import ( prompt_step, transition_step, + if_else_step, ) from ..common.protocol.tasks import ( ExecutionInput, PromptWorkflowStep, # EvaluateWorkflowStep, - # YieldWorkflowStep, + YieldWorkflowStep, # ToolCallWorkflowStep, # ErrorWorkflowStep, - # IfElseWorkflowStep, + IfElseWorkflowStep, StepContext, TransitionInfo, ) @@ -44,10 +45,10 @@ async def run( developer_id=execution_input.developer_id, execution=execution_input.execution, task=execution_input.task, - # agent=execution_input.agent, - # user=execution_input.user, - # session=execution_input.session, - # tools=execution_input.tools, + agent=execution_input.agent, + user=execution_input.user, + session=execution_input.session, + tools=execution_input.tools, arguments=execution_input.arguments, definition=step, inputs=previous_inputs, @@ -67,22 +68,19 @@ async def run( # if outputs.tool_calls is not None: # should_wait = True - is_last = step_idx + 1 == len(current_workflow) - # case EvaluateWorkflowStep(): # result = await workflow.execute_activity( # evaluate_step, # context, # schedule_to_close_timeout=timedelta(seconds=600), # ) - # case YieldWorkflowStep(): - # result = await workflow.execute_activity( - # yield_step, - # context, - # schedule_to_close_timeout=timedelta(seconds=600), - # ) + case YieldWorkflowStep(): + outputs = await workflow.execute_child_workflow( + TaskExecutionWorkflow.run, + args=[execution_input, (step.workflow, 0), previous_inputs], + ) # case ToolCallWorkflowStep(): - # result = await workflow.execute_activity( + # outputs = await workflow.execute_activity( # tool_call_step, # context, # schedule_to_close_timeout=timedelta(seconds=600), @@ -93,13 +91,24 @@ async def run( # context, # schedule_to_close_timeout=timedelta(seconds=600), # ) - # case IfElseWorkflowStep(): - # result = await workflow.execute_activity( - # if_else_step, - # context, - # schedule_to_close_timeout=timedelta(seconds=600), - # ) + case IfElseWorkflowStep(): + outputs = await workflow.execute_activity( + if_else_step, + context, + schedule_to_close_timeout=timedelta(seconds=600), + ) + workflow_step: YieldWorkflowStep = outputs["workflow"] + + outputs = await workflow.execute_child_workflow( + TaskExecutionWorkflow.run, + args=[ + execution_input, + (workflow_step.workflow, 0), + previous_inputs, + ], + ) + is_last = step_idx + 1 == len(current_workflow) # Transition type transition_type = ( "awaiting_input" if should_wait else ("finish" if is_last else "step") diff --git a/agents-api/poetry.lock b/agents-api/poetry.lock index b2302ca02..33e7364c0 100644 --- a/agents-api/poetry.lock +++ b/agents-api/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -452,26 +452,6 @@ files = [ {file = "cachetools-5.3.3.tar.gz", hash = "sha256:ba29e2dfa0b8b556606f097407ed1aa62080ee108ab0dc5ec9d6a723a007d105"}, ] -[[package]] -name = "cel-python" -version = "0.1.5" -description = "Pure Python CEL Implementation" -optional = false -python-versions = ">=3.7, <4" -files = [ - {file = "cel-python-0.1.5.tar.gz", hash = "sha256:d3911bb046bc3ed12792bd88ab453f72d98c66923b72a2fa016bcdffd96e2f98"}, - {file = "cel_python-0.1.5-py3-none-any.whl", hash = "sha256:ac81fab8ba08b633700a45d84905be2863529c6a32935c9da7ef53fc06844f1a"}, -] - -[package.dependencies] -babel = ">=2.9.0" -jmespath = ">=0.10.0" -lark-parser = ">=0.10.1" -python-dateutil = ">=2.8.1" -pyyaml = ">=5.4.1" -requests = ">=2.25.1" -urllib3 = ">=1.26.4" - [[package]] name = "certifi" version = "2024.6.2" @@ -2090,17 +2070,6 @@ files = [ [package.dependencies] Jinja2 = ">=2.2" -[[package]] -name = "jmespath" -version = "1.0.1" -description = "JSON Matching Expressions" -optional = false -python-versions = ">=3.7" -files = [ - {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, - {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, -] - [[package]] name = "json5" version = "0.9.25" @@ -2619,22 +2588,6 @@ orjson = ">=3.9.14,<4.0.0" pydantic = ">=1,<3" requests = ">=2,<3" -[[package]] -name = "lark-parser" -version = "0.12.0" -description = "a modern parsing library" -optional = false -python-versions = "*" -files = [ - {file = "lark-parser-0.12.0.tar.gz", hash = "sha256:15967db1f1214013dca65b1180745047b9be457d73da224fcda3d9dd4e96a138"}, - {file = "lark_parser-0.12.0-py2.py3-none-any.whl", hash = "sha256:0eaf30cb5ba787fe404d73a7d6e61df97b21d5a63ac26c5008c78a494373c675"}, -] - -[package.extras] -atomic-cache = ["atomicwrites"] -nearley = ["js2py"] -regex = ["regex"] - [[package]] name = "libcst" version = "1.4.0" @@ -4834,6 +4787,17 @@ files = [ {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, ] +[[package]] +name = "simpleeval" +version = "0.9.13" +description = "A simple, safe single expression evaluator library." +optional = false +python-versions = "*" +files = [ + {file = "simpleeval-0.9.13-py2.py3-none-any.whl", hash = "sha256:22a2701a5006e4188d125d34accf2405c2c37c93f6b346f2484b6422415ae54a"}, + {file = "simpleeval-0.9.13.tar.gz", hash = "sha256:4a30f9cc01825fe4c719c785e3762623e350c4840d5e6855c2a8496baaa65fac"}, +] + [[package]] name = "six" version = "1.16.0" @@ -5902,4 +5866,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.11" -content-hash = "a90aa6e665c0bdd8cfca719a93ff27adbb86f66ed14b81a554e1aa6629fba86d" +content-hash = "c4658426f4e83c6905cc8656e04efdfb389b930771e4fbf541e12bc18ef08f8f" diff --git a/agents-api/pyproject.toml b/agents-api/pyproject.toml index 7f12f971e..13a1065b9 100644 --- a/agents-api/pyproject.toml +++ b/agents-api/pyproject.toml @@ -34,7 +34,7 @@ tenacity = "^8.3.0" beartype = "^0.18.5" -cel-python = "^0.1.5" +simpleeval = "^0.9.13" [tool.poetry.group.dev.dependencies] ipython = "^8.18.1" black = "^24.4.0" diff --git a/openapi.yaml b/openapi.yaml index ff277180b..756ec1ee2 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -2791,6 +2791,9 @@ components: created_at: type: string format: date-time + updated_at: + type: string + format: date-time agent_id: type: string format: uuid @@ -2859,7 +2862,7 @@ components: format: uuid created_at: type: string - format: uuid + format: date-time arguments: type: object properties: {} @@ -2944,7 +2947,7 @@ components: workflow: type: string arguments: - $ref: '#/components/schemas/CELObject' + type: object additionalProperties: true required: @@ -2956,7 +2959,7 @@ components: workflow: type: string arguments: - $ref: '#/components/schemas/CELObject' + type: object required: - workflow - arguments