diff --git a/agents-api/agents_api/autogen/Tasks.py b/agents-api/agents_api/autogen/Tasks.py index 48dba4ad7..9dd531c47 100644 --- a/agents-api/agents_api/autogen/Tasks.py +++ b/agents-api/agents_api/autogen/Tasks.py @@ -35,10 +35,10 @@ class CaseThen(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep ) """ @@ -63,10 +63,10 @@ class CaseThenUpdateItem(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep ) """ @@ -130,10 +130,10 @@ class CreateTaskRequest(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep | IfElseWorkflowStep | SwitchStep @@ -227,6 +227,7 @@ class ForeachDo(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep ) """ The steps to run for each iteration @@ -251,6 +252,7 @@ class ForeachDoUpdateItem(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep ) """ The steps to run for each iteration @@ -324,10 +326,10 @@ class IfElseWorkflowStep(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep ) """ @@ -342,10 +344,10 @@ class IfElseWorkflowStep(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep | None, Field(None, alias="else"), @@ -376,10 +378,10 @@ class IfElseWorkflowStepUpdateItem(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep ) """ @@ -394,10 +396,10 @@ class IfElseWorkflowStepUpdateItem(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep | None, Field(None, alias="else"), @@ -462,6 +464,7 @@ class Main(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep ) """ The steps to run for each iteration @@ -503,6 +506,7 @@ class MainModel(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep ) """ The steps to run for each iteration @@ -543,6 +547,7 @@ class ParallelStep(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep ], Field(max_length=100), ] @@ -569,6 +574,7 @@ class ParallelStepUpdateItem(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep ], Field(max_length=100), ] @@ -596,10 +602,10 @@ class PatchTaskRequest(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep | IfElseWorkflowStepUpdateItem | SwitchStepUpdateItem @@ -874,10 +880,10 @@ class Task(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep | IfElseWorkflowStep | SwitchStep @@ -1009,10 +1015,10 @@ class UpdateTaskRequest(BaseModel): | LogStep | EmbedStep | SearchStep + | YieldStep | ReturnStep | SleepStep | ErrorWorkflowStep - | YieldStep | WaitForInputStep | IfElseWorkflowStep | SwitchStep diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index 5c5a8c86f..48811eb20 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -1,10 +1,12 @@ # ruff: noqa: F401, F403, F405 +import ast from typing import Annotated, Any, Generic, Literal, Self, Type, TypeVar, get_args from uuid import UUID +import jinja2 from litellm.utils import _select_tokenizer as select_tokenizer from litellm.utils import token_counter -from pydantic import AwareDatetime, Field +from pydantic import AwareDatetime, Field, field_validator, model_validator, validator from ..common.utils.datetime import utcnow from .Agents import * @@ -152,6 +154,179 @@ def from_model_input( ) +# Patch Task Workflow Steps +# -------------------------------------- + + +def validate_python_expression(expr: str) -> tuple[bool, str]: + try: + ast.parse(expr) + return True, "" + except SyntaxError as e: + return False, f"SyntaxError in '{expr}': {str(e)}" + + +def validate_jinja_template(template: str) -> tuple[bool, str]: + env = jinja2.Environment() + try: + parsed_template = env.parse(template) + for node in parsed_template.body: + if isinstance(node, jinja2.nodes.Output): + for child in node.nodes: + if isinstance(child, jinja2.nodes.Name): + # Check if the variable is a valid Python expression + is_valid, error = validate_python_expression(child.name) + if not is_valid: + return ( + False, + f"Invalid Python expression in Jinja template '{template}': {error}", + ) + return True, "" + except jinja2.exceptions.TemplateSyntaxError as e: + return False, f"TemplateSyntaxError in '{template}': {str(e)}" + + +_EvaluateStep = EvaluateStep + + +class EvaluateStep(_EvaluateStep): + @field_validator("evaluate") + def validate_evaluate_expressions(cls, v): + for key, expr in v.items(): + is_valid, error = validate_python_expression(expr) + if not is_valid: + raise ValueError(f"Invalid Python expression in key '{key}': {error}") + return v + + +_ToolCallStep = ToolCallStep + + +class ToolCallStep(_ToolCallStep): + @field_validator("arguments") + def validate_arguments(cls, v): + if isinstance(v, dict): + for key, expr in v.items(): + if isinstance(expr, str): + is_valid, error = validate_python_expression(expr) + if not is_valid: + raise ValueError( + f"Invalid Python expression in arguments key '{key}': {error}" + ) + return v + + +_PromptStep = PromptStep + + +class PromptStep(_PromptStep): + @field_validator("prompt") + def validate_prompt(cls, v): + if isinstance(v, str): + is_valid, error = validate_jinja_template(v) + if not is_valid: + raise ValueError(f"Invalid Jinja template in prompt: {error}") + elif isinstance(v, list): + for item in v: + if "content" in item: + is_valid, error = validate_jinja_template(item["content"]) + if not is_valid: + raise ValueError( + f"Invalid Jinja template in prompt content: {error}" + ) + return v + + +_SetStep = SetStep + + +class SetStep(_SetStep): + @field_validator("set") + def validate_set_expressions(cls, v): + for key, expr in v.items(): + is_valid, error = validate_python_expression(expr) + if not is_valid: + raise ValueError( + f"Invalid Python expression in set key '{key}': {error}" + ) + return v + + +_LogStep = LogStep + + +class LogStep(_LogStep): + @field_validator("log") + def validate_log_template(cls, v): + is_valid, error = validate_jinja_template(v) + if not is_valid: + raise ValueError(f"Invalid Jinja template in log: {error}") + return v + + +_ReturnStep = ReturnStep + + +class ReturnStep(_ReturnStep): + @field_validator("return_") + def validate_return_expressions(cls, v): + for key, expr in v.items(): + is_valid, error = validate_python_expression(expr) + if not is_valid: + raise ValueError( + f"Invalid Python expression in return key '{key}': {error}" + ) + return v + + +_YieldStep = YieldStep + + +class YieldStep(_YieldStep): + @field_validator("arguments") + def validate_yield_arguments(cls, v): + if isinstance(v, dict): + for key, expr in v.items(): + is_valid, error = validate_python_expression(expr) + if not is_valid: + raise ValueError( + f"Invalid Python expression in yield arguments key '{key}': {error}" + ) + return v + + +_IfElseWorkflowStep = IfElseWorkflowStep + + +class IfElseWorkflowStep(_IfElseWorkflowStep): + @field_validator("if_") + def validate_if_expression(cls, v): + is_valid, error = validate_python_expression(v) + if not is_valid: + raise ValueError(f"Invalid Python expression in if condition: {error}") + return v + + +_MapReduceStep = MapReduceStep + + +class MapReduceStep(_MapReduceStep): + @field_validator("over") + def validate_over_expression(cls, v): + is_valid, error = validate_python_expression(v) + if not is_valid: + raise ValueError(f"Invalid Python expression in over: {error}") + return v + + @field_validator("reduce") + def validate_reduce_expression(cls, v): + if v is not None: + is_valid, error = validate_python_expression(v) + if not is_valid: + raise ValueError(f"Invalid Python expression in reduce: {error}") + return v + + # Workflow related models # ----------------------- @@ -228,6 +403,29 @@ class Task(_Task): # Patch some models to allow extra fields # -------------------------------------- +WorkflowType = RootModel[ + list[ + EvaluateStep + | ToolCallStep + | PromptStep + | GetStep + | SetStep + | LogStep + | EmbedStep + | SearchStep + | ReturnStep + | SleepStep + | ErrorWorkflowStep + | YieldStep + | WaitForInputStep + | IfElseWorkflowStep + | SwitchStep + | ForeachStep + | ParallelStep + | MapReduceStep + ] +] + _CreateTaskRequest = CreateTaskRequest @@ -240,6 +438,22 @@ class CreateTaskRequest(_CreateTaskRequest): } ) + @model_validator(mode="after") + def validate_subworkflows(self) -> Self: + subworkflows = { + k: v + for k, v in self.model_dump().items() + if k not in _CreateTaskRequest.model_fields + } + + for workflow_name, workflow_definition in subworkflows.items(): + try: + WorkflowType.model_validate(workflow_definition) + setattr(self, workflow_name, WorkflowType(workflow_definition)) + except Exception as e: + raise ValueError(f"Invalid subworkflow '{workflow_name}': {str(e)}") + return self + CreateOrUpdateTaskRequest = CreateTaskRequest diff --git a/typespec/tasks/steps.tsp b/typespec/tasks/steps.tsp index 3495def1b..2267ae320 100644 --- a/typespec/tasks/steps.tsp +++ b/typespec/tasks/steps.tsp @@ -49,7 +49,8 @@ alias MappableWorkflowStep = | SetStep | LogStep | EmbedStep - | SearchStep; + | SearchStep + | YieldStep; alias NonConditionalWorkflowStep = | MappableWorkflowStep