From 8c03d93701a454327d6b458177e0626ae165d2a0 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Wed, 2 Oct 2024 21:57:20 -0400 Subject: [PATCH] fix(agents-api): Switch to monkeypatching because everything is shit Signed-off-by: Diwank Singh Tomer --- .../agents_api/autogen/openapi_model.py | 339 +++++++++++------- 1 file changed, 214 insertions(+), 125 deletions(-) diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index 48811eb20..bff2221eb 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -6,7 +6,14 @@ import jinja2 from litellm.utils import _select_tokenizer as select_tokenizer from litellm.utils import token_counter -from pydantic import AwareDatetime, Field, field_validator, model_validator, validator +from pydantic import ( + AwareDatetime, + Field, + computed_field, + field_validator, + model_validator, + validator, +) from ..common.utils.datetime import utcnow from .Agents import * @@ -155,7 +162,7 @@ def from_model_input( # Patch Task Workflow Steps -# -------------------------------------- +# ------------------------- def validate_python_expression(expr: str) -> tuple[bool, str]: @@ -186,145 +193,255 @@ def validate_jinja_template(template: str) -> tuple[bool, str]: return False, f"TemplateSyntaxError in '{template}': {str(e)}" -_EvaluateStep = EvaluateStep +@field_validator("evaluate") +def validate_evaluate_expressions(cls, v): + for key, expr in v.items(): + is_valid, error = validate_python_expression(expr) + if not is_valid: + raise ValueError(f"Invalid Python expression in key '{key}': {error}") + return v + + +EvaluateStep.validate_evaluate_expressions = validate_evaluate_expressions -class EvaluateStep(_EvaluateStep): - @field_validator("evaluate") - def validate_evaluate_expressions(cls, v): +@field_validator("arguments") +def validate_arguments(cls, v): + if isinstance(v, dict): for key, expr in v.items(): - is_valid, error = validate_python_expression(expr) - if not is_valid: - raise ValueError(f"Invalid Python expression in key '{key}': {error}") - return v + if isinstance(expr, str): + is_valid, error = validate_python_expression(expr) + if not is_valid: + raise ValueError( + f"Invalid Python expression in arguments key '{key}': {error}" + ) + return v -_ToolCallStep = ToolCallStep +ToolCallStep.validate_arguments = validate_arguments -class ToolCallStep(_ToolCallStep): - @field_validator("arguments") - def validate_arguments(cls, v): - if isinstance(v, dict): - for key, expr in v.items(): - if isinstance(expr, str): - is_valid, error = validate_python_expression(expr) - if not is_valid: - raise ValueError( - f"Invalid Python expression in arguments key '{key}': {error}" - ) - return v +# Add the new validator function +@field_validator("prompt") +def validate_prompt(cls, v): + if isinstance(v, str): + is_valid, error = validate_jinja_template(v) + if not is_valid: + raise ValueError(f"Invalid Jinja template in prompt: {error}") + elif isinstance(v, list): + for item in v: + if "content" in item: + is_valid, error = validate_jinja_template(item["content"]) + if not is_valid: + raise ValueError( + f"Invalid Jinja template in prompt content: {error}" + ) + return v -_PromptStep = PromptStep +# Patch the original PromptStep class to add the new validator +PromptStep.validate_prompt = validate_prompt -class PromptStep(_PromptStep): - @field_validator("prompt") - def validate_prompt(cls, v): - if isinstance(v, str): - is_valid, error = validate_jinja_template(v) - if not is_valid: - raise ValueError(f"Invalid Jinja template in prompt: {error}") - elif isinstance(v, list): - for item in v: - if "content" in item: - is_valid, error = validate_jinja_template(item["content"]) - if not is_valid: - raise ValueError( - f"Invalid Jinja template in prompt content: {error}" - ) - return v +@field_validator("set") +def validate_set_expressions(cls, v): + for key, expr in v.items(): + is_valid, error = validate_python_expression(expr) + if not is_valid: + raise ValueError(f"Invalid Python expression in set key '{key}': {error}") + return v -_SetStep = SetStep +SetStep.validate_set_expressions = validate_set_expressions -class SetStep(_SetStep): - @field_validator("set") - def validate_set_expressions(cls, v): - for key, expr in v.items(): - is_valid, error = validate_python_expression(expr) - if not is_valid: - raise ValueError( - f"Invalid Python expression in set key '{key}': {error}" - ) - return v +@field_validator("log") +def validate_log_template(cls, v): + is_valid, error = validate_jinja_template(v) + if not is_valid: + raise ValueError(f"Invalid Jinja template in log: {error}") + return v -_LogStep = LogStep +LogStep.validate_log_template = validate_log_template -class LogStep(_LogStep): - @field_validator("log") - def validate_log_template(cls, v): - is_valid, error = validate_jinja_template(v) +@field_validator("return_") +def validate_return_expressions(cls, v): + for key, expr in v.items(): + is_valid, error = validate_python_expression(expr) if not is_valid: - raise ValueError(f"Invalid Jinja template in log: {error}") - return v + raise ValueError( + f"Invalid Python expression in return key '{key}': {error}" + ) + return v -_ReturnStep = ReturnStep +ReturnStep.validate_return_expressions = validate_return_expressions -class ReturnStep(_ReturnStep): - @field_validator("return_") - def validate_return_expressions(cls, v): +@field_validator("arguments") +def validate_yield_arguments(cls, v): + if isinstance(v, dict): for key, expr in v.items(): is_valid, error = validate_python_expression(expr) if not is_valid: raise ValueError( - f"Invalid Python expression in return key '{key}': {error}" + f"Invalid Python expression in yield arguments key '{key}': {error}" ) - return v + return v -_YieldStep = YieldStep +YieldStep.validate_yield_arguments = validate_yield_arguments -class YieldStep(_YieldStep): - @field_validator("arguments") - def validate_yield_arguments(cls, v): - if isinstance(v, dict): - for key, expr in v.items(): - is_valid, error = validate_python_expression(expr) - if not is_valid: - raise ValueError( - f"Invalid Python expression in yield arguments key '{key}': {error}" - ) - return v +@field_validator("if_") +def validate_if_expression(cls, v): + is_valid, error = validate_python_expression(v) + if not is_valid: + raise ValueError(f"Invalid Python expression in if condition: {error}") + return v + +IfElseWorkflowStep.validate_if_expression = validate_if_expression -_IfElseWorkflowStep = IfElseWorkflowStep +@field_validator("over") +def validate_over_expression(cls, v): + is_valid, error = validate_python_expression(v) + if not is_valid: + raise ValueError(f"Invalid Python expression in over: {error}") + return v -class IfElseWorkflowStep(_IfElseWorkflowStep): - @field_validator("if_") - def validate_if_expression(cls, v): + +@field_validator("reduce") +def validate_reduce_expression(cls, v): + if v is not None: is_valid, error = validate_python_expression(v) if not is_valid: - raise ValueError(f"Invalid Python expression in if condition: {error}") - return v + raise ValueError(f"Invalid Python expression in reduce: {error}") + return v -_MapReduceStep = MapReduceStep +MapReduceStep.validate_over_expression = validate_over_expression +MapReduceStep.validate_reduce_expression = validate_reduce_expression -class MapReduceStep(_MapReduceStep): - @field_validator("over") - def validate_over_expression(cls, v): - is_valid, error = validate_python_expression(v) - if not is_valid: - raise ValueError(f"Invalid Python expression in over: {error}") - return v +# Patch workflow +# -------------- - @field_validator("reduce") - def validate_reduce_expression(cls, v): - if v is not None: - is_valid, error = validate_python_expression(v) - if not is_valid: - raise ValueError(f"Invalid Python expression in reduce: {error}") - return v +_CreateTaskRequest = CreateTaskRequest + +CreateTaskRequest.model_config = ConfigDict( + **{ + **_CreateTaskRequest.model_config, + "extra": "allow", + } +) + + +@model_validator(mode="after") +def validate_subworkflows(self): + subworkflows = { + k: v + for k, v in self.model_dump().items() + if k not in _CreateTaskRequest.model_fields + } + + for workflow_name, workflow_definition in subworkflows.items(): + try: + WorkflowType.model_validate(workflow_definition) + setattr(self, workflow_name, WorkflowType(workflow_definition)) + except Exception as e: + raise ValueError(f"Invalid subworkflow '{workflow_name}': {str(e)}") + return self + + +CreateTaskRequest.validate_subworkflows = validate_subworkflows + + +# Custom types (not generated correctly) +# -------------------------------------- + +ChatMLContent = ( + list[ChatMLTextContentPart | ChatMLImageContentPart] + | Tool + | ChosenToolCall + | str + | ToolResponse + | list[ + list[ChatMLTextContentPart | ChatMLImageContentPart] + | Tool + | ChosenToolCall + | str + | ToolResponse + ] +) + +# Extract ChatMLRole +ChatMLRole = BaseEntry.model_fields["role"].annotation + +# Extract ChatMLSource +ChatMLSource = BaseEntry.model_fields["source"].annotation + +# Extract ExecutionStatus +ExecutionStatus = Execution.model_fields["status"].annotation + +# Extract TransitionType +TransitionType = Transition.model_fields["type"].annotation + +# Assertions to ensure consistency (optional, but recommended for runtime checks) +assert ChatMLRole == BaseEntry.model_fields["role"].annotation +assert ChatMLSource == BaseEntry.model_fields["source"].annotation +assert ExecutionStatus == Execution.model_fields["status"].annotation +assert TransitionType == Transition.model_fields["type"].annotation + + +# Create models +# ------------- + + +class CreateTransitionRequest(Transition): + # The following fields are optional in this + + id: UUID | None = None + execution_id: UUID | None = None + created_at: AwareDatetime | None = None + updated_at: AwareDatetime | None = None + metadata: dict[str, Any] | None = None + task_token: str | None = None + + +class CreateEntryRequest(BaseEntry): + timestamp: Annotated[ + float, Field(ge=0.0, default_factory=lambda: utcnow().timestamp()) + ] + + @classmethod + def from_model_input( + cls: Type[Self], + model: str, + *, + role: ChatMLRole, + content: ChatMLContent, + name: str | None = None, + source: ChatMLSource, + **kwargs: dict, + ) -> Self: + tokenizer: dict = select_tokenizer(model=model) + token_count = token_counter( + model=model, messages=[{"role": role, "content": content, "name": name}] + ) + + return cls( + role=role, + content=content, + name=name, + source=source, + tokenizer=tokenizer["type"], + token_count=token_count, + **kwargs, + ) # Workflow related models @@ -427,34 +544,6 @@ class Task(_Task): ] -_CreateTaskRequest = CreateTaskRequest - - -class CreateTaskRequest(_CreateTaskRequest): - model_config = ConfigDict( - **{ - **_CreateTaskRequest.model_config, - "extra": "allow", - } - ) - - @model_validator(mode="after") - def validate_subworkflows(self) -> Self: - subworkflows = { - k: v - for k, v in self.model_dump().items() - if k not in _CreateTaskRequest.model_fields - } - - for workflow_name, workflow_definition in subworkflows.items(): - try: - WorkflowType.model_validate(workflow_definition) - setattr(self, workflow_name, WorkflowType(workflow_definition)) - except Exception as e: - raise ValueError(f"Invalid subworkflow '{workflow_name}': {str(e)}") - return self - - CreateOrUpdateTaskRequest = CreateTaskRequest _PatchTaskRequest = PatchTaskRequest