From af121a2929cad3a7dbb0b408e3fce488a3a5f63f Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Sat, 24 Aug 2024 23:37:30 +0300 Subject: [PATCH] fix: Fix templates rendering for different kind of messages --- .../activities/task_steps/prompt_step.py | 71 ++++++++++++++++--- agents-api/agents_api/activities/utils.py | 1 + .../agents_api/common/utils/template.py | 6 +- .../routers/tasks/create_task_execution.py | 2 +- .../tests/sample_tasks/test_find_selector.py | 12 +--- agents-api/tests/test_execution_workflow.py | 2 +- 6 files changed, 70 insertions(+), 24 deletions(-) diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index 003809909..84a569ee1 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -3,7 +3,12 @@ from beartype import beartype from temporalio import activity -from ...autogen.openapi_model import ChatSettings, InputChatMLMessage +from ...autogen.openapi_model import ( + ChatSettings, + Content, + ContentModel, + InputChatMLMessage, +) from ...clients import ( litellm, # We dont directly import `acompletion` so we can mock it ) @@ -11,6 +16,32 @@ from ...common.utils.template import render_template +def _content_to_dict( + content: str | list[str] | list[Content | ContentModel], role: str +) -> str | list[dict]: + if isinstance(content, str): + return content + + result = [] + for s in content: + if isinstance(s, str): + result.append({"content": {"type": "text", "text": s, "role": role}}) + elif isinstance(s, Content): + result.append({"content": {"type": s.type, "text": s.text, "role": role}}) + elif isinstance(s, ContentModel): + result.append( + { + "content": { + "type": s.type, + "image_url": {"url": s.image_url.url}, + "role": role, + } + } + ) + + return result + + @activity.defn @beartype async def prompt_step(context: StepContext) -> StepOutcome: @@ -27,26 +58,44 @@ async def prompt_step(context: StepContext) -> StepOutcome: template_messages: list[InputChatMLMessage] = prompt messages = await asyncio.gather( *[ - render_template(msg.content, context_data, skip_vars=["developer_id"]) + render_template( + _content_to_dict(msg.content, msg.role), + context_data, + skip_vars=["developer_id"], + ) for msg in template_messages ] ) - messages = [ - ( - InputChatMLMessage(role="user", content=m) - if isinstance(m, str) - else InputChatMLMessage(**m) - ) - for m in messages - ] + result_messages = [] + for m in messages: + if isinstance(m, str): + msg = InputChatMLMessage(role="user", content=m) + else: + msg = [] + for d in m: + role = d["content"].get("role") + d["content"] = [d["content"]] + d["role"] = role + msg.append(InputChatMLMessage(**d)) + + result_messages.append(msg) + + # messages = [ + # ( + # InputChatMLMessage(role="user", content=m) + # if isinstance(m, str) + # else [InputChatMLMessage(**d) for d in m] + # ) + # for m in messages + # ] # Get settings and run llm settings: ChatSettings = context.current_step.settings or ChatSettings() settings_data: dict = settings.model_dump() response = await litellm.acompletion( - messages=messages, + messages=result_messages, **settings_data, ) diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index 3a7c74f1c..231dee595 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -15,6 +15,7 @@ "load_json": json.loads, } + @beartype def get_evaluator(names: dict[str, Any]) -> SimpleEval: evaluator = EvalWithCompoundTypes(names=names, functions=ALLOWED_FUNCTIONS) diff --git a/agents-api/agents_api/common/utils/template.py b/agents-api/agents_api/common/utils/template.py index 1fcc143b3..35ae2c350 100644 --- a/agents-api/agents_api/common/utils/template.py +++ b/agents-api/agents_api/common/utils/template.py @@ -70,7 +70,11 @@ async def render_template_parts( # Parse template # FIXME: should template_strings contain a list of ChatMLTextContentPart? Should we handle it somehow? templates = [ - (jinja_env.from_string(msg["text"]) if msg["type"] == "text" else None) + ( + jinja_env.from_string(msg["content"]["text"]) + if msg["content"]["type"] == "text" + else None + ) for msg in template_strings ] diff --git a/agents-api/agents_api/routers/tasks/create_task_execution.py b/agents-api/agents_api/routers/tasks/create_task_execution.py index ec17146af..0497777bf 100644 --- a/agents-api/agents_api/routers/tasks/create_task_execution.py +++ b/agents-api/agents_api/routers/tasks/create_task_execution.py @@ -59,7 +59,7 @@ async def start_execution( client=client, ) - job_id=uuid4() + job_id = uuid4() try: handle = await run_task_execution_workflow( diff --git a/agents-api/tests/sample_tasks/test_find_selector.py b/agents-api/tests/sample_tasks/test_find_selector.py index 40b649ca3..caebd7547 100644 --- a/agents-api/tests/sample_tasks/test_find_selector.py +++ b/agents-api/tests/sample_tasks/test_find_selector.py @@ -85,7 +85,6 @@ async def _( agent_id = str(agent.id) task_id = str(uuid4()) - with patch_embed_acompletion(), open( f"{this_dir}/find_selector.yaml", "r" ) as sample_file: @@ -106,12 +105,7 @@ async def _( input = dict( screenshot_base64="iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA", - network_requests=[{ - "request": {}, - "response": { - "body": "Lady Gaga" - } - }], + network_requests=[{"request": {}, "response": {"body": "Lady Gaga"}}], parameters=["name"], ) execution_data = dict(input=input) @@ -122,8 +116,6 @@ async def _( json=execution_data, ).json() - handle = temporal_client.get_workflow_handle( - execution_created["jobs"][0] - ) + handle = temporal_client.get_workflow_handle(execution_created["jobs"][0]) await handle.result() diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index d55b62362..395f77de9 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -729,7 +729,7 @@ async def _( ) as task_file: task_definition = yaml.safe_load(task_file) acompletion.return_value = mock_model_response - data = CreateExecutionRequest(input={"test": "input"}) + data = CreateExecutionRequest(input={"parameters": ["param1", "param2"]}) task = create_task( developer_id=developer_id,