diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index 3eecc2b55..e9b4daeb3 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -1,5 +1,6 @@ from beartype import beartype from temporalio import activity +from temporalio.exceptions import ApplicationError from ...clients import ( litellm, # We dont directly import `acompletion` so we can mock it @@ -63,18 +64,29 @@ async def prompt_step(context: StepContext) -> StepOutcome: else: passed_settings: dict = {} + # Wrap the prompt in a list if it is not already + if isinstance(prompt, str): + prompt = [{"role": "user", "content": prompt}] + completion_data: dict = { "model": agent_model, "tools": formatted_agent_tools or None, - ("messages" if isinstance(prompt, list) else "prompt"): prompt, + "messages": prompt, **agent_default_settings, **passed_settings, } + response = await litellm.acompletion( **completion_data, ) + if context.current_step.unwrap: + if response.choices[0].finish_reason == "tool_calls": + raise ApplicationError("Tool calls cannot be unwrapped") + + response = response.choices[0].message.content + return StepOutcome( - output=response.model_dump(), + output=response.model_dump() if hasattr(response, "model_dump") else response, next=None, ) diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index d1f13585f..8b7316f3b 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -389,9 +389,15 @@ async def run( state = PartialTransition(type="resume", output=result) - case PromptStep(), StepOutcome(output=response): + case PromptStep( + forward_tool_results=forward_tool_results, unwrap=False + ), StepOutcome(output=response): workflow.logger.debug(f"Prompt step: Received response: {response}") - if response["choices"][0]["finish_reason"] != "tool_calls": + + if ( + response["choices"][0]["finish_reason"] != "tool_calls" + or not forward_tool_results + ): workflow.logger.debug("Prompt step: Received response") state = PartialTransition(output=response) else: @@ -405,8 +411,8 @@ async def run( args=[context, tool_calls_input], schedule_to_close_timeout=timedelta(days=31), ) + # Feed the tool call results back to the model - # context.inputs.append(tool_calls_results) context.current_step.prompt.append(message) context.current_step.prompt.append(tool_calls_results) new_response = await workflow.execute_activity( @@ -418,6 +424,10 @@ async def run( ) state = PartialTransition(output=new_response.output, type="resume") + case PromptStep(unwrap=True), StepOutcome(output=response): + workflow.logger.debug(f"Prompt step: Received response: {response}") + state = PartialTransition(output=response) + case SetStep(), StepOutcome(output=evaluated_output): workflow.logger.info("Set step: Updating user state") self.update_user_state(evaluated_output) @@ -468,6 +478,7 @@ async def run( # 4. Transition to the next step workflow.logger.info(f"Transitioning after step {context.cursor.step}") + # The returned value is the transition finally created final_state = await transition(context, state) diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index 44f6d6ed4..26f8c3edf 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -338,7 +338,7 @@ async def _( assert result["value"] == data.input["test"] -# @test("workflow: log step") +@test("workflow: log step") async def _( client=cozo_client, developer_id=test_developer_id, @@ -441,7 +441,8 @@ async def _( assert result["hello"] == data.input["test"] -@test("workflow: wait for input step start") +# FIXME: This test is not working. It gets stuck +# @test("workflow: wait for input step start") async def _( client=cozo_client, developer_id=test_developer_id, @@ -818,6 +819,66 @@ async def _( assert result["role"] == "assistant" +@test("workflow: prompt step unwrap") +async def _( + client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, +): + mock_model_response = ModelResponse( + id="fake_id", + choices=[Choices(message={"role": "assistant", "content": "Hello, world!"})], + created=0, + object="text_completion", + ) + + with patch("agents_api.clients.litellm.acompletion") as acompletion: + acompletion.return_value = mock_model_response + data = CreateExecutionRequest(input={"test": "input"}) + + task = create_task( + developer_id=developer_id, + agent_id=agent.id, + data=CreateTaskRequest( + **{ + "name": "test task", + "description": "test task about", + "input_schema": {"type": "object", "additionalProperties": True}, + "main": [ + { + "prompt": [ + { + "role": "user", + "content": "message", + }, + ], + "unwrap": True, + "settings": {}, + }, + ], + } + ), + client=client, + ) + + async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): + execution, handle = await start_execution( + developer_id=developer_id, + task_id=task.id, + data=data, + client=client, + ) + + assert handle is not None + assert execution.task_id == task.id + assert execution.input == data.input + + mock_run_task_execution_workflow.assert_called_once() + + result = await handle.result() + assert result == "Hello, world!" + + @test("workflow: set and get steps") async def _( client=cozo_client,