From a51441ac209aa6e4f2ac7e936b892c4f5a7a169d Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Fri, 11 Oct 2024 05:37:04 -0400 Subject: [PATCH] Misc fixes (#627) - **fix(agents-api): Allow url etc to be overridden by arguments** - **fix(agents-api): base64 encode http content** - **feat(agents-api): Extend stdlib** - **fix(agents-api): Fix interceptors and simplify get/set steps** - **fix(agents-api): Fix the bug where execution.output was not being set** ---- > [!IMPORTANT] > This PR enhances API call execution, extends the standard library, improves error handling, refines execution transitions, and updates workflow execution in the `agents-api` module. > > - **API Call Execution**: > - `execute_api_call` in `excecute_api_call.py` now supports overriding `url` and `headers` via `RequestArgs`. > - HTTP content is base64 encoded before being returned. > - **Standard Library Extension**: > - `utils.py` extended with classes for `re`, `json`, `yaml`, `time`, `random`, `itertools`, `functools`, `base64`, `urllib`, `string`, `zoneinfo`, `datetime`, `math`, and `statistics`. > - **Error Handling**: > - Updated `is_non_retryable_error` in `exceptions/tasks.py` to accept `BaseException`. > - `CustomActivityInterceptor` and `CustomWorkflowInterceptor` in `interceptors.py` now handle `BaseException`. > - **Execution Transition**: > - Fixed bug in `create_execution_transition.py` where `execution.output` was not set correctly for non-error transitions. > - **Workflow Execution**: > - Removed user state management from `TaskExecutionWorkflow` in `task_execution/__init__.py`. > - Updated `continue_as_child` in `helpers.py` to handle user state via workflow memo. > - **Testing**: > - Updated tests in `test_execution_workflow.py` to cover new functionalities and ensure correct behavior. > > This description was created by [Ellipsis](https://www.ellipsis.dev?ref=julep-ai%2Fjulep&utm_source=github&utm_medium=referral) for 2b785597f98c66380d79a79d249ba95ddf79dcf0. It will automatically update as commits are pushed. --------- Signed-off-by: Diwank Singh Tomer Co-authored-by: creatorrr --- .../activities/excecute_api_call.py | 14 +- agents-api/agents_api/activities/utils.py | 189 ++++++++++++++++-- .../agents_api/common/exceptions/tasks.py | 8 +- agents-api/agents_api/common/interceptors.py | 4 +- .../agents_api/common/protocol/tasks.py | 3 +- .../execution/create_execution_transition.py | 2 +- .../workflows/task_execution/__init__.py | 43 +--- .../workflows/task_execution/helpers.py | 24 ++- agents-api/tests/test_execution_workflow.py | 6 +- 9 files changed, 218 insertions(+), 75 deletions(-) diff --git a/agents-api/agents_api/activities/excecute_api_call.py b/agents-api/agents_api/activities/excecute_api_call.py index 88fabce89..e7752aa06 100644 --- a/agents-api/agents_api/activities/excecute_api_call.py +++ b/agents-api/agents_api/activities/excecute_api_call.py @@ -1,3 +1,4 @@ +import base64 from typing import Annotated, Any, Optional, TypedDict, Union import httpx @@ -20,6 +21,8 @@ class RequestArgs(TypedDict): json_: Optional[dict[str, Any]] cookies: Optional[dict[str, str]] params: Optional[Union[str, dict[str, Any]]] + url: Optional[str] + headers: Optional[dict[str, str]] @beartype @@ -29,18 +32,23 @@ async def execute_api_call( ) -> Any: try: async with httpx.AsyncClient() as client: + arg_url = request_args.pop("url", None) + arg_headers = request_args.pop("headers", None) + response = await client.request( method=api_call.method, - url=str(api_call.url), - headers=api_call.headers, + url=arg_url or str(api_call.url), + headers=arg_headers or api_call.headers, follow_redirects=api_call.follow_redirects, **request_args, ) + content_base64 = base64.b64encode(response.content).decode("ascii") + response_dict = { "status_code": response.status_code, "headers": dict(response.headers), - "content": response.content, + "content": content_base64, "json": response.json(), } diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index f9f7ded12..fca62578a 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -1,20 +1,33 @@ +import base64 +import datetime as dt +import functools +import itertools import json -from functools import reduce -from itertools import accumulate -from random import random -from time import time -from typing import Any, Callable +import math +import random +import statistics +import string +import time +import urllib.parse +from typing import Any, Callable, ParamSpec, Type, TypeVar, cast import re2 import yaml +import zoneinfo from beartype import beartype from simpleeval import EvalWithCompoundTypes, SimpleEval -from yaml import CSafeLoader +from yaml import CSafeDumper, CSafeLoader + +T = TypeVar("T") + + +P = ParamSpec("P") +R = TypeVar("R") + # TODO: We need to make sure that we dont expose any security issues ALLOWED_FUNCTIONS = { "abs": abs, - "accumulate": accumulate, "all": all, "any": any, "bool": bool, @@ -25,23 +38,169 @@ "int": int, "len": len, "list": list, - "load_json": json.loads, - "load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader), "map": map, - "match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)), "max": max, "min": min, - "random": random, "range": range, - "reduce": reduce, "round": round, - "search_regex": lambda pattern, string: re2.search(pattern, string), "set": set, "str": str, "sum": sum, - "time": time, "tuple": tuple, + "reduce": functools.reduce, "zip": zip, + "search_regex": lambda pattern, string: re2.search(pattern, string), + "load_json": json.loads, + "load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader), + "match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)), +} + + +class stdlib_re: + fullmatch = re2.fullmatch + search = re2.search + escape = re2.escape + findall = re2.findall + finditer = re2.finditer + match = re2.match + split = re2.split + sub = re2.sub + subn = re2.subn + + +class stdlib_json: + loads = json.loads + dumps = json.dumps + + +class stdlib_yaml: + load = lambda string: yaml.load(string, Loader=CSafeLoader) # noqa: E731 + dump = lambda value: yaml.dump(value, Dumper=CSafeDumper) # noqa: E731 + + +class stdlib_time: + strftime = time.strftime + strptime = time.strptime + time = time + + +class stdlib_random: + choice = random.choice + choices = random.choices + sample = random.sample + shuffle = random.shuffle + randrange = random.randrange + randint = random.randint + random = random.random + + +class stdlib_itertools: + accumulate = itertools.accumulate + + +class stdlib_functools: + partial = functools.partial + reduce = functools.reduce + + +class stdlib_base64: + b64encode = base64.b64encode + b64decode = base64.b64decode + + +class stdlib_urllib: + class parse: + urlparse = urllib.parse.urlparse + urlencode = urllib.parse.urlencode + unquote = urllib.parse.unquote + quote = urllib.parse.quote + parse_qs = urllib.parse.parse_qs + parse_qsl = urllib.parse.parse_qsl + urlsplit = urllib.parse.urlsplit + urljoin = urllib.parse.urljoin + unwrap = urllib.parse.unwrap + + +class stdlib_string: + ascii_letters = string.ascii_letters + ascii_lowercase = string.ascii_lowercase + ascii_uppercase = string.ascii_uppercase + digits = string.digits + hexdigits = string.hexdigits + octdigits = string.octdigits + punctuation = string.punctuation + whitespace = string.whitespace + printable = string.printable + + +class stdlib_zoneinfo: + ZoneInfo = zoneinfo.ZoneInfo + + +class stdlib_datetime: + class timezone: + class utc: + utc = dt.timezone.utc + + class datetime: + now = dt.datetime.now + datetime = dt.datetime + timedelta = dt.timedelta + date = dt.date + time = dt.time + + timedelta = dt.timedelta + + +class stdlib_math: + sqrt = math.sqrt + exp = math.exp + ceil = math.ceil + floor = math.floor + isinf = math.isinf + isnan = math.isnan + log = math.log + log10 = math.log10 + log2 = math.log2 + pow = math.pow + sin = math.sin + cos = math.cos + tan = math.tan + asin = math.asin + acos = math.acos + atan = math.atan + atan2 = math.atan2 + + pi = math.pi + e = math.e + + +class stdlib_statistics: + mean = statistics.mean + stdev = statistics.stdev + geometric_mean = statistics.geometric_mean + median = statistics.median + median_low = statistics.median_low + median_high = statistics.median_high + mode = statistics.mode + quantiles = statistics.quantiles + + +stdlib = { + "re": stdlib_re, + "json": stdlib_json, + "yaml": stdlib_yaml, + "time": stdlib_time, + "random": stdlib_random, + "itertools": stdlib_itertools, + "functools": stdlib_functools, + "base64": stdlib_base64, + "urllib": stdlib_urllib, + "string": stdlib_string, + "zoneinfo": stdlib_zoneinfo, + "datetime": stdlib_datetime, + "math": stdlib_math, + "statistics": stdlib_statistics, } @@ -50,7 +209,7 @@ def get_evaluator( names: dict[str, Any], extra_functions: dict[str, Callable] | None = None ) -> SimpleEval: evaluator = EvalWithCompoundTypes( - names=names, functions=ALLOWED_FUNCTIONS | (extra_functions or {}) + names=names | stdlib, functions=ALLOWED_FUNCTIONS | (extra_functions or {}) ) return evaluator diff --git a/agents-api/agents_api/common/exceptions/tasks.py b/agents-api/agents_api/common/exceptions/tasks.py index 8ead1e7e2..81331234c 100644 --- a/agents-api/agents_api/common/exceptions/tasks.py +++ b/agents-api/agents_api/common/exceptions/tasks.py @@ -20,7 +20,7 @@ import temporalio.exceptions # List of error types that should not be retried -NON_RETRYABLE_ERROR_TYPES = [ +NON_RETRYABLE_ERROR_TYPES = ( # Temporal-specific errors temporalio.exceptions.WorkflowAlreadyStartedError, temporalio.exceptions.TerminatedError, @@ -99,10 +99,10 @@ litellm.exceptions.ServiceUnavailableError, litellm.exceptions.OpenAIError, litellm.exceptions.APIError, -] +) -def is_non_retryable_error(error: Exception) -> bool: +def is_non_retryable_error(error: BaseException) -> bool: """ Determines if the given error is non-retryable. @@ -115,4 +115,4 @@ def is_non_retryable_error(error: Exception) -> bool: Returns: bool: True if the error is non-retryable, False otherwise. """ - return isinstance(error, tuple(NON_RETRYABLE_ERROR_TYPES)) + return isinstance(error, NON_RETRYABLE_ERROR_TYPES) diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py index 2fb077c45..c6e8e2eaf 100644 --- a/agents-api/agents_api/common/interceptors.py +++ b/agents-api/agents_api/common/interceptors.py @@ -31,7 +31,7 @@ class CustomActivityInterceptor(ActivityInboundInterceptor): async def execute_activity(self, input: ExecuteActivityInput): try: return await super().execute_activity(input) - except Exception as e: + except BaseException as e: if is_non_retryable_error(e): raise ApplicationError( str(e), @@ -53,7 +53,7 @@ class CustomWorkflowInterceptor(WorkflowInboundInterceptor): async def execute_workflow(self, input: ExecuteWorkflowInput): try: return await super().execute_workflow(input) - except Exception as e: + except BaseException as e: if is_non_retryable_error(e): raise ApplicationError( str(e), diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index bbb5c28d3..bd4aaa5a2 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -118,7 +118,8 @@ } # type: ignore -PartialTransition: Type[BaseModel] = create_partial_model(CreateTransitionRequest) +class PartialTransition(create_partial_model(CreateTransitionRequest)): + user_state: dict[str, Any] = Field(default_factory=dict) class ExecutionInput(BaseModel): diff --git a/agents-api/agents_api/models/execution/create_execution_transition.py b/agents-api/agents_api/models/execution/create_execution_transition.py index f40395126..2b1c09ae8 100644 --- a/agents-api/agents_api/models/execution/create_execution_transition.py +++ b/agents-api/agents_api/models/execution/create_execution_transition.py @@ -176,7 +176,7 @@ def create_execution_transition( data=UpdateExecutionRequest( status=transition_to_execution_status[data.type] ), - output=data.output if data.type == "finish" else None, + output=data.output if data.type != "error" else None, error=str(data.output) if data.type == "error" and data.output else None, diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index edf54fb12..155b49397 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -118,30 +118,6 @@ # Main workflow definition @workflow.defn class TaskExecutionWorkflow: - user_state: dict[str, Any] = {} - - def __init__(self) -> None: - self.user_state = {} - - # TODO: Add endpoints for getting and setting user state for an execution - # Query methods for user state - @workflow.query - def get_user_state(self) -> dict[str, Any]: - return self.user_state - - @workflow.query - def get_user_state_by_key(self, key: str) -> Any: - return self.user_state.get(key) - - # Signal methods for updating user state - @workflow.signal - def set_user_state(self, key: str, value: Any) -> None: - self.user_state[key] = value - - @workflow.signal - def update_user_state(self, values: dict[str, Any]) -> None: - self.user_state.update(values) - # Main workflow run method @workflow.run async def run( @@ -149,11 +125,7 @@ async def run( execution_input: ExecutionInput, start: TransitionTarget = TransitionTarget(workflow="main", step=0), previous_inputs: list[Any] = [], - user_state: dict[str, Any] = {}, ) -> Any: - # Set the initial user state - self.user_state = user_state - workflow.logger.info( f"TaskExecutionWorkflow for task {execution_input.task.id}" f" [LOC {start.workflow}.{start.step}]" @@ -258,7 +230,6 @@ async def run( switch=switch, index=index, previous_inputs=previous_inputs, - user_state=self.user_state, ) state = PartialTransition(output=result) @@ -276,7 +247,6 @@ async def run( else_branch=else_branch, condition=condition, previous_inputs=previous_inputs, - user_state=self.user_state, ) state = PartialTransition(output=result) @@ -288,7 +258,6 @@ async def run( do_step=do_step, items=items, previous_inputs=previous_inputs, - user_state=self.user_state, ) state = PartialTransition(output=result) @@ -303,7 +272,6 @@ async def run( reduce=reduce, initial=initial, previous_inputs=previous_inputs, - user_state=self.user_state, ) state = PartialTransition(output=result) @@ -316,7 +284,6 @@ async def run( map_defn=map_defn, items=items, previous_inputs=previous_inputs, - user_state=self.user_state, initial=initial, reduce=reduce, parallelism=parallelism, @@ -376,7 +343,6 @@ async def run( context, start=yield_next_target, previous_inputs=[output], - user_state=self.user_state, ) state = PartialTransition(output=result) @@ -439,14 +405,15 @@ async def run( case SetStep(), StepOutcome(output=evaluated_output): workflow.logger.info("Set step: Updating user state") - self.update_user_state(evaluated_output) # Pass along the previous output unchanged - state = PartialTransition(output=context.current_input) + state = PartialTransition( + output=context.current_input, user_state=evaluated_output + ) case GetStep(get=key), _: workflow.logger.info(f"Get step: Fetching '{key}' from user state") - value = self.get_user_state_by_key(key) + value = workflow.memo_value(key, default=None) workflow.logger.debug(f"Retrieved value: {value}") state = PartialTransition(output=value) @@ -596,5 +563,5 @@ def model_dump(obj): context.execution_input, start=final_state.next, previous_inputs=previous_inputs + [final_state.output], - user_state=self.user_state, + user_state=state.user_state, ) diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 04449db58..271f33dbf 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -27,15 +27,23 @@ async def continue_as_child( previous_inputs: list[Any], user_state: dict[str, Any] = {}, ) -> Any: - return await workflow.execute_child_workflow( - "TaskExecutionWorkflow", + info = workflow.info() + + if info.is_continue_as_new_suggested(): + run = workflow.continue_as_new + else: + run = lambda *args, **kwargs: workflow.execute_child_workflow( # noqa: E731 + info.workflow_type, *args, **kwargs + ) + + return await run( args=[ execution_input, start, previous_inputs, - user_state, ], retry_policy=DEFAULT_RETRY_POLICY, + memo=workflow.memo() | user_state, ) @@ -46,7 +54,7 @@ async def execute_switch_branch( switch: list, index: int, previous_inputs: list[Any], - user_state: dict[str, Any], + user_state: dict[str, Any] = {}, ) -> Any: workflow.logger.info(f"Switch step: Chose branch {index}") chosen_branch = switch[index] @@ -77,7 +85,7 @@ async def execute_if_else_branch( else_branch: WorkflowStep, condition: bool, previous_inputs: list[Any], - user_state: dict[str, Any], + user_state: dict[str, Any] = {}, ) -> Any: workflow.logger.info(f"If-Else step: Condition evaluated to {condition}") chosen_branch = then_branch if condition else else_branch @@ -108,7 +116,7 @@ async def execute_foreach_step( do_step: WorkflowStep, items: list[Any], previous_inputs: list[Any], - user_state: dict[str, Any], + user_state: dict[str, Any] = {}, ) -> Any: workflow.logger.info(f"Foreach step: Iterating over {len(items)} items") results = [] @@ -142,7 +150,7 @@ async def execute_map_reduce_step( map_defn: WorkflowStep, items: list[Any], previous_inputs: list[Any], - user_state: dict[str, Any], + user_state: dict[str, Any] = {}, reduce: str | None = None, initial: Any = [], ) -> Any: @@ -185,7 +193,7 @@ async def execute_map_reduce_step_parallel( map_defn: WorkflowStep, items: list[Any], previous_inputs: list[Any], - user_state: dict[str, Any], + user_state: dict[str, Any] = {}, initial: Any = [], reduce: str | None = None, parallelism: int = task_max_parallelism, diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index e5ef7110a..f8a89cb62 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -819,9 +819,9 @@ async def _( "input_schema": {"type": "object", "additionalProperties": True}, "main": [ { - "if": "True", + "if": "False", "then": {"evaluate": {"hello": '"world"'}}, - "else": {"evaluate": {"hello": '"nope"'}}, + "else": {"evaluate": {"hello": "random.randint(0, 10)"}}, }, ], } @@ -849,7 +849,7 @@ async def _( mock_run_task_execution_workflow.assert_called_once() result = await handle.result() - assert result["hello"] == "world" + assert result["hello"] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] @test("workflow: switch step")