From d13724bba3ff3885bb729e3a20afcd45be3a4cd4 Mon Sep 17 00:00:00 2001
From: Diwank Singh Tomer
Date: Thu, 31 Oct 2024 17:04:13 -0400
Subject: [PATCH] feat(agents-api): Add python expression support to prompt step

Signed-off-by: Diwank Singh Tomer
---
 .../agents_api/activities/execute_system.py  |  1 -
 .../activities/task_steps/prompt_step.py     | 45 +++++++++++----
 agents-api/agents_api/autogen/Tools.py       |  1 -
 .../agents_api/common/protocol/tasks.py      | 18 +++++-
 .../workflows/task_execution/__init__.py     | 28 ++--------
 agents-api/tests/test_execution_workflow.py  | 56 +++++++++++++++++++
 typespec/tasks/steps.tsp                     |  2 +-
 7 files changed, 115 insertions(+), 36 deletions(-)

diff --git a/agents-api/agents_api/activities/execute_system.py b/agents-api/agents_api/activities/execute_system.py
index a40e02f7e..abc4f1865 100644
--- a/agents-api/agents_api/activities/execute_system.py
+++ b/agents-api/agents_api/activities/execute_system.py
@@ -15,7 +15,6 @@
     VectorDocSearchRequest,
 )
 from ..autogen.Tools import SystemDef
-from ..common.protocol.developers import Developer
 from ..common.protocol.tasks import StepContext
 from ..common.storage_handler import auto_blob_store
 from ..env import testing
diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py
index 7a4605ee0..55ca3d140 100644
--- a/agents-api/agents_api/activities/task_steps/prompt_step.py
+++ b/agents-api/agents_api/activities/task_steps/prompt_step.py
@@ -19,6 +19,7 @@
 from ...common.utils.template import render_template
 from ...env import anthropic_api_key, debug
 from ..utils import get_handler
+from .base_evaluate import base_evaluate
 
 COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
 
@@ -77,32 +78,60 @@ def format_tool(tool: Tool) -> dict:
     return formatted
 
 
+EVAL_PROMPT_PREFIX = "$_ "
+
+
 @activity.defn
 @auto_blob_store
 @beartype
 async def prompt_step(context: StepContext) -> StepOutcome:
     # Get context data
     prompt: str | list[dict] = context.current_step.model_dump()["prompt"]
-    context_data: dict = context.model_dump()
+    context_data: dict = context.model_dump(include_remote=True)
 
-    # Render template messages
-    prompt = await render_template(
-        prompt,
-        context_data,
-        skip_vars=["developer_id"],
+    # If the prompt is a string and starts with $_ then we need to evaluate it
+    should_evaluate_prompt = isinstance(prompt, str) and prompt.startswith(
+        EVAL_PROMPT_PREFIX
     )
+
+    if should_evaluate_prompt:
+        prompt = await base_evaluate(
+            prompt[len(EVAL_PROMPT_PREFIX) :].strip(), context_data
+        )
+
+        if not isinstance(prompt, (str, list)):
+            raise ApplicationError(
+                "Invalid prompt expression, expected a string or list"
+            )
+
+    # Wrap the prompt in a list if it is not already
+    prompt = (
+        prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
+    )
+
+    # Render template messages if we didn't evaluate the prompt
+    if not should_evaluate_prompt:
+        # Render template messages
+        prompt = await render_template(
+            prompt,
+            context_data,
+            skip_vars=["developer_id"],
+        )
+
     # Get settings and run llm
     agent_default_settings: dict = (
         context.execution_input.agent.default_settings.model_dump()
         if context.execution_input.agent.default_settings
         else {}
     )
+
     agent_model: str = (
         context.execution_input.agent.model
         if context.execution_input.agent.model
         else "gpt-4o"
     )
 
+    # Get passed settings
     if context.current_step.settings:
         passed_settings: dict = context.current_step.settings.model_dump(
             exclude_unset=True
@@ -110,10 +139,6 @@ async def prompt_step(context: StepContext) -> StepOutcome:
     else:
         passed_settings: dict = {}
 
-    # Wrap the prompt in a list if it is not already
-    if isinstance(prompt, str):
-        prompt = [{"role": "user", "content": prompt}]
-
     # Format tools for litellm
     formatted_tools = [format_tool(tool) for tool in context.tools]
 
diff --git a/agents-api/agents_api/autogen/Tools.py b/agents-api/agents_api/autogen/Tools.py
index 52007fcea..88c4764fe 100644
--- a/agents-api/agents_api/autogen/Tools.py
+++ b/agents-api/agents_api/autogen/Tools.py
@@ -12,7 +12,6 @@
     BaseModel,
     ConfigDict,
     Field,
-    RootModel,
     StrictBool,
 )
 
diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py
index 66ffd9632..87fd51b33 100644
--- a/agents-api/agents_api/common/protocol/tasks.py
+++ b/agents-api/agents_api/common/protocol/tasks.py
@@ -1,8 +1,9 @@
-from typing import Annotated, Any
+from typing import Annotated, Any, Literal
 from uuid import UUID
 
 from beartype import beartype
 from temporalio import activity, workflow
+from temporalio.exceptions import ApplicationError
 
 with workflow.unsafe.imports_passed_through():
     from pydantic import BaseModel, Field, computed_field
@@ -23,6 +24,7 @@
     TaskSpecDef,
     TaskToolDef,
     Tool,
+    ToolRef,
     TransitionTarget,
     TransitionType,
     UpdateTaskRequest,
@@ -154,6 +156,20 @@ def tools(self) -> list[Tool | CreateToolRequest]:
         task = execution_input.task
         agent_tools = execution_input.agent_tools
 
+        step_tools: Literal["all"] | list[ToolRef | CreateToolRequest] = getattr(
+            self.current_step, "tools", "all"
+        )
+
+        if step_tools != "all":
+            if not all(
+                tool and isinstance(tool, CreateToolRequest) for tool in step_tools
+            ):
+                raise ApplicationError(
+                    "Invalid tools for step (ToolRef not supported yet)"
+                )
+
+            return step_tools
+
         # Need to convert task.tools (list[TaskToolDef]) to list[Tool]
         task_tools = []
         for tool in task.tools:
diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py
index 51318248e..de3c1189e 100644
--- a/agents-api/agents_api/workflows/task_execution/__init__.py
+++ b/agents-api/agents_api/workflows/task_execution/__init__.py
@@ -202,6 +202,9 @@ async def run(
             retry_policy=DEFAULT_RETRY_POLICY,
         )
 
+        # Init state
+        state = None
+
         match context.current_step, outcome:
             # Handle errors (activity returns None)
             case step, StepOutcome(error=error) if error is not None:
@@ -371,22 +374,9 @@ async def run(
 
                 state = PartialTransition(type="resume", output=result)
 
-            case PromptStep(unwrap=True), StepOutcome(output=response):
-                finish_reason = response["choices"][0]["finish_reason"]
-                if finish_reason == "tool_calls":
-                    workflow.logger.error(
-                        "Prompt step: Tool calls not supported in unwrap mode"
-                    )
-
-                    state = PartialTransition(
-                        type="error", output="Tool calls not supported in unwrap mode"
-                    )
-                    await transition(context, state)
-
-                    raise ApplicationError("Tool calls not supported in unwrap mode")
-
-                workflow.logger.debug(f"Prompt step: Received response: {response}")
-                state = PartialTransition(output=response)
+            case PromptStep(unwrap=True), StepOutcome(output=message):
+                workflow.logger.debug(f"Prompt step: Received response: {message}")
+                state = PartialTransition(output=message)
 
             case PromptStep(auto_run_tools=False, unwrap=False), StepOutcome(
                 output=response
@@ -493,12 +483,6 @@
                 )
                 state = PartialTransition(output=response)
 
-            case SetStep(), StepOutcome(output=evaluated_output):
-                workflow.logger.info("Set step: Updating user state")
-
-            case SetStep(), StepOutcome(output=evaluated_output):
-                workflow.logger.info("Set step: Updating user state")
-
             case SetStep(), StepOutcome(output=evaluated_output):
                 workflow.logger.info("Set step: Updating user state")
 
diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py
index d41aa4a6d..dbfa2f8bd 100644
--- a/agents-api/tests/test_execution_workflow.py
+++ b/agents-api/tests/test_execution_workflow.py
@@ -1133,6 +1133,62 @@ async def _(
     ]
 
 
+@test("workflow: prompt step (python expression)")
+async def _(
+    client=cozo_client,
+    developer_id=test_developer_id,
+    agent=test_agent,
+):
+    mock_model_response = ModelResponse(
+        id="fake_id",
+        choices=[Choices(message={"role": "assistant", "content": "Hello, world!"})],
+        created=0,
+        object="text_completion",
+    )
+
+    with patch("agents_api.clients.litellm.acompletion") as acompletion:
+        acompletion.return_value = mock_model_response
+        data = CreateExecutionRequest(input={"test": "input"})
+
+        task = create_task(
+            developer_id=developer_id,
+            agent_id=agent.id,
+            data=CreateTaskRequest(
+                **{
+                    "name": "test task",
+                    "description": "test task about",
+                    "input_schema": {"type": "object", "additionalProperties": True},
+                    "main": [
+                        {
+                            "prompt": "$_ [{'role': 'user', 'content': _.test}]",
+                            "settings": {},
+                        },
+                    ],
+                }
+            ),
+            client=client,
+        )
+
+        async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
+            execution, handle = await start_execution(
+                developer_id=developer_id,
+                task_id=task.id,
+                data=data,
+                client=client,
+            )
+
+            assert handle is not None
+            assert execution.task_id == task.id
+            assert execution.input == data.input
+
+            mock_run_task_execution_workflow.assert_called_once()
+
+            result = await handle.result()
+            result = result["choices"][0]["message"]
+            assert result["content"] == "Hello, world!"
+            assert result["role"] == "assistant"
+
+
 @test("workflow: prompt step")
 async def _(
     client=cozo_client,
diff --git a/typespec/tasks/steps.tsp b/typespec/tasks/steps.tsp
index 5de975247..943dfbc7c 100644
--- a/typespec/tasks/steps.tsp
+++ b/typespec/tasks/steps.tsp
@@ -103,7 +103,7 @@
   prompt: JinjaTemplate | InputChatMLMessage[];
 
   /** The tools to use for the prompt */
-  tools: "all" | (ToolRef | CreateToolRequest)[] = #[];
+  tools: "all" | (ToolRef | CreateToolRequest)[] = "all";
 
   /** The tool choice for the prompt */
   tool_choice?: ToolChoiceOption;
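
Note (reviewer sketch, not part of the patch): the change above lets a prompt step's `prompt` field start with the `$_ ` prefix, in which case the remainder is evaluated as a Python expression against the step context instead of being rendered as a Jinja template. The snippet below is a minimal, self-contained illustration of that control flow. The helper name `resolve_prompt` is hypothetical, and plain `eval` with dict-key access stands in for the sandboxed `base_evaluate` activity (which, per the new test, also supports dotted access such as `_.test`).

# Sketch only: mirrors the prompt-step logic introduced in the patch, under the
# assumptions stated above. `resolve_prompt` is a hypothetical helper; the real
# step awaits the `base_evaluate` activity rather than calling eval() directly.
from typing import Any

EVAL_PROMPT_PREFIX = "$_ "


def resolve_prompt(prompt: str | list[dict], execution_input: dict[str, Any]) -> list[dict]:
    """Turn a prompt-step `prompt` field into a list of chat messages."""
    should_evaluate = isinstance(prompt, str) and prompt.startswith(EVAL_PROMPT_PREFIX)

    if should_evaluate:
        # Evaluate the expression with the execution input exposed as `_`
        # (dict-key access here; the real evaluator also allows `_.test`).
        expression = prompt[len(EVAL_PROMPT_PREFIX):].strip()
        prompt = eval(expression, {"__builtins__": {}}, {"_": execution_input})
        if not isinstance(prompt, (str, list)):
            raise ValueError("Invalid prompt expression, expected a string or list")
    # else: the Jinja template path is unchanged and omitted here.

    # Wrap a bare string in a single user message, as the patched step does.
    return prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]


# Usage mirroring the new test case: the expression builds the message list directly.
messages = resolve_prompt("$_ [{'role': 'user', 'content': _['test']}]", {"test": "input"})
assert messages == [{"role": "user", "content": "input"}]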