diff --git a/agents-api/agents_api/activities/task_steps/base_evaluate.py b/agents-api/agents_api/activities/task_steps/base_evaluate.py index 4b98cd13a..e65f1fe66 100644 --- a/agents-api/agents_api/activities/task_steps/base_evaluate.py +++ b/agents-api/agents_api/activities/task_steps/base_evaluate.py @@ -1,6 +1,8 @@ from typing import Any from beartype import beartype +from box import Box +from openai import BaseModel from temporalio import activity from ...env import testing @@ -15,6 +17,15 @@ async def base_evaluate( input_len = 1 if isinstance(exprs, str) else len(exprs) assert input_len > 0, "exprs must be a non-empty string, list or dict" + # Turn the nested dict values from pydantic to dicts where possible + values = { + k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items() + } + + # TODO: We should make this frozen_box=True, but we need to make sure that + # we don't break anything + values = Box(values, frozen_box=False, conversion_box=False) + evaluator = get_evaluator(names=values) try: diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index 231dee595..21c6b3675 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -8,6 +8,7 @@ from yaml import CSafeLoader ALLOWED_FUNCTIONS = { + "zip": zip, "len": len, "load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader), "match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)), diff --git a/agents-api/agents_api/workflows/task_execution.py b/agents-api/agents_api/workflows/task_execution.py index b331a52c0..72c7393ed 100644 --- a/agents-api/agents_api/workflows/task_execution.py +++ b/agents-api/agents_api/workflows/task_execution.py @@ -294,15 +294,13 @@ async def transition(**kwargs) -> None: previous_inputs + [item], ] + # TODO: We should parallelize this # Execute the chosen branch and come back here output = await workflow.execute_child_workflow( TaskExecutionWorkflow.run, args=map_reduce_args, ) - if hasattr(output, "model_dump"): - output = output.model_dump() - initial = await execute_activity( task_steps.base_evaluate, args=[ @@ -365,7 +363,7 @@ async def transition(**kwargs) -> None: ) case PromptStep(), StepOutcome(output=response): - state.output = response.get("choices", [{}])[0].get("message") + state.output = response case _: raise ApplicationError("Not implemented") diff --git a/agents-api/poetry.lock b/agents-api/poetry.lock index e9eba34b2..b703a0814 100644 --- a/agents-api/poetry.lock +++ b/agents-api/poetry.lock @@ -3076,6 +3076,41 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "python-box" +version = "7.2.0" +description = "Advanced Python dictionaries with dot notation access" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python_box-7.2.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:6bdeec791e25258351388b3029a3ec5da302bb9ed3be175493c43cdc6c47f5e3"}, + {file = "python_box-7.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c449f7b3756a71479fa9c61a86e344ac00ed782a66d7662590f0afa294249d18"}, + {file = "python_box-7.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:6b0d61f182d394106d963232854e495b51edc178faa5316a797be1178212d7e0"}, + {file = "python_box-7.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e2d752de8c1204255bf7b0c814c59ef48293c187a7e9fdcd2fefa28024b72032"}, + {file = "python_box-7.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8a6c35ea356a386077935958a5debcd5b229b9a1b3b26287a52dfe1a7e65d99"}, + {file = "python_box-7.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:32ed58ec4d9e5475efe69f9c7d773dfea90a6a01979e776da93fd2b0a5d04429"}, + {file = "python_box-7.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:2a2d664c6a27f7515469b6f1e461935a2038ee130b7d194b4b4db4e85d363618"}, + {file = "python_box-7.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8a5a7365db1aaf600d3e8a2747fcf6833beb5d45439a54318548f02e302e3ec"}, + {file = "python_box-7.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:739f827056ea148cbea3122d4617c994e829b420b1331183d968b175304e3a4f"}, + {file = "python_box-7.2.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:2617ef3c3d199f55f63c908f540a4dc14ced9b18533a879e6171c94a6a436f23"}, + {file = "python_box-7.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffd866bed03087b1d8340014da8c3aaae19135767580641df1b4ae6fff6ac0aa"}, + {file = "python_box-7.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:9681f059e7e92bdf20782cd9ea6e533d4711fc7b8c57a462922a025d46add4d0"}, + {file = "python_box-7.2.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:6b59b1e2741c9ceecdf5a5bd9b90502c24650e609cd824d434fed3b6f302b7bb"}, + {file = "python_box-7.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23fae825d809ae7520fdeac88bb52be55a3b63992120a00e381783669edf589"}, + {file = "python_box-7.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:573b1abdcb7bd745fa404444f060ee62fc35a74f067181e55dcb43cfe92f2827"}, + {file = "python_box-7.2.0-py3-none-any.whl", hash = "sha256:a3c90832dd772cb0197fdb5bc06123b6e1b846899a1b53d9c39450d27a584829"}, + {file = "python_box-7.2.0.tar.gz", hash = "sha256:551af20bdab3a60a2a21e3435120453c4ca32f7393787c3a5036e1d9fc6a0ede"}, +] + +[package.extras] +all = ["msgpack", "ruamel.yaml (>=0.17)", "toml"] +msgpack = ["msgpack"] +pyyaml = ["PyYAML"] +ruamel-yaml = ["ruamel.yaml (>=0.17)"] +toml = ["toml"] +tomli = ["tomli", "tomli-w"] +yaml = ["ruamel.yaml (>=0.17)"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -4544,4 +4579,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "f2e6680de8e96b10ec7a9e5edf15223fd4842342db950778ff1d4bb1940e7ae0" +content-hash = "ce6fb9e63ff83a508ad9660217393b84e7174493f6b30f21f6a23fe9b4262fc8" diff --git a/agents-api/pyproject.toml b/agents-api/pyproject.toml index d74d9dba9..eceeba341 100644 --- a/agents-api/pyproject.toml +++ b/agents-api/pyproject.toml @@ -34,6 +34,7 @@ lz4 = "^4.3.3" pyyaml = "^6.0.2" google-re2 = "^1.1.20240702" +python-box = "^7.2.0" [tool.poetry.group.dev.dependencies] ipython = "^8.26.0" ruff = "^0.5.5" diff --git a/agents-api/tests/sample_tasks/find_selector.yaml b/agents-api/tests/sample_tasks/find_selector.yaml index 2c94df23f..465a2b578 100644 --- a/agents-api/tests/sample_tasks/find_selector.yaml +++ b/agents-api/tests/sample_tasks/find_selector.yaml @@ -63,11 +63,11 @@ main: image_url: url: "{{inputs[0].screenshot_base64}}" - over: _.parameters + over: _["parameters"] reduce: >- results + [ - yaml.safe_load(_["choices"][0]["message"]["content"].trim()) + load_yaml(_["choices"][0]["message"].content.strip()) ] - evaluate: @@ -75,16 +75,12 @@ main: [ {"value": result["value"], "network_request": request} for request in inputs[0]["network_requests"] - if result["value"] in nr["response"]["body"] for result in _ - if result["found"] + if result["found"] and result["value"] in request["response"]["body"] ] - if: len(_["result"]) > 0 then: - workflow: find_selectors - arguments: - results: list(zip(_, execution.input.network_requests)) - parameters: execution.input.parameters + log: list(zip(_, inputs[0]["network_requests"])) else: - error: "Could not find the selector in any of the network requests" + error: "Could not find the selector in any of the network requests" \ No newline at end of file diff --git a/agents-api/tests/sample_tasks/test_find_selector.py b/agents-api/tests/sample_tasks/test_find_selector.py index caebd7547..67ad88607 100644 --- a/agents-api/tests/sample_tasks/test_find_selector.py +++ b/agents-api/tests/sample_tasks/test_find_selector.py @@ -85,9 +85,9 @@ async def _( agent_id = str(agent.id) task_id = str(uuid4()) - with patch_embed_acompletion(), open( - f"{this_dir}/find_selector.yaml", "r" - ) as sample_file: + with patch_embed_acompletion( + output={"role": "assistant", "content": "found: true\nvalue: 'Gaga'"} + ), open(f"{this_dir}/find_selector.yaml", "r") as sample_file: task_def = sample_file.read() async with patch_http_client_with_temporal(