feat(agents-api): Add python expression support to prompt step
Signed-off-by: Diwank Singh Tomer <[email protected]>
creatorrr committed Oct 31, 2024
1 parent dc4d8ea commit d13724b
Showing 7 changed files with 115 additions and 36 deletions.
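The change lets a prompt step's prompt field carry a Python expression instead of a Jinja template: when the string starts with the "$_ " prefix, prompt_step evaluates the remainder against the step context rather than rendering it. As a rough illustration of the task-author side (the "prompt" value and field names are copied from the test added in this commit; the surrounding dict is only a sketch, not a full task definition):

# Sketch of a task definition using the new "$_ " prompt expression syntax.
# The "prompt" value is taken verbatim from the test added in this commit;
# `_.test` refers to the execution input at evaluation time.
task_definition = {
    "name": "test task",
    "description": "test task about",
    "input_schema": {"type": "object", "additionalProperties": True},
    "main": [
        {
            # Evaluated as a Python expression instead of a Jinja template
            "prompt": "$_ [{'role': 'user', 'content': _.test}]",
            "settings": {},
        },
    ],
}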
1 change: 0 additions & 1 deletion agents-api/agents_api/activities/execute_system.py
@@ -15,7 +15,6 @@
     VectorDocSearchRequest,
 )
 from ..autogen.Tools import SystemDef
-from ..common.protocol.developers import Developer
 from ..common.protocol.tasks import StepContext
 from ..common.storage_handler import auto_blob_store
 from ..env import testing
45 changes: 35 additions & 10 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
@@ -19,6 +19,7 @@
 from ...common.utils.template import render_template
 from ...env import anthropic_api_key, debug
 from ..utils import get_handler
+from .base_evaluate import base_evaluate
 
 COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
 
@@ -77,43 +78,67 @@ def format_tool(tool: Tool) -> dict:
     return formatted
 
 
+EVAL_PROMPT_PREFIX = "$_ "
+
+
 @activity.defn
 @auto_blob_store
 @beartype
 async def prompt_step(context: StepContext) -> StepOutcome:
     # Get context data
     prompt: str | list[dict] = context.current_step.model_dump()["prompt"]
-    context_data: dict = context.model_dump()
+    context_data: dict = context.model_dump(include_remote=True)
 
-    # Render template messages
-    prompt = await render_template(
-        prompt,
-        context_data,
-        skip_vars=["developer_id"],
+    # If the prompt is a string and starts with $_ then we need to evaluate it
+    should_evaluate_prompt = isinstance(prompt, str) and prompt.startswith(
+        EVAL_PROMPT_PREFIX
     )
 
+    if should_evaluate_prompt:
+        prompt = await base_evaluate(
+            prompt[len(EVAL_PROMPT_PREFIX) :].strip(), context_data
+        )
+
+    if not isinstance(prompt, (str, list)):
+        raise ApplicationError(
+            "Invalid prompt expression, expected a string or list"
+        )
+
+    # Wrap the prompt in a list if it is not already
+    prompt = (
+        prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
+    )
+
+    # Render template messages if we didn't evaluate the prompt
+    if not should_evaluate_prompt:
+        # Render template messages
+        prompt = await render_template(
+            prompt,
+            context_data,
+            skip_vars=["developer_id"],
+        )
 
     # Get settings and run llm
     agent_default_settings: dict = (
         context.execution_input.agent.default_settings.model_dump()
         if context.execution_input.agent.default_settings
         else {}
     )
 
     agent_model: str = (
         context.execution_input.agent.model
         if context.execution_input.agent.model
         else "gpt-4o"
     )
 
     # Get passed settings
     if context.current_step.settings:
         passed_settings: dict = context.current_step.settings.model_dump(
             exclude_unset=True
         )
     else:
         passed_settings: dict = {}
 
-    # Wrap the prompt in a list if it is not already
-    if isinstance(prompt, str):
-        prompt = [{"role": "user", "content": prompt}]
-
     # Format tools for litellm
     formatted_tools = [format_tool(tool) for tool in context.tools]
 
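The heart of the change is the dispatch at the top of prompt_step: a "$_ "-prefixed string is stripped of the prefix and evaluated as a Python expression against the context data, while anything else keeps going through Jinja template rendering. Below is a minimal, self-contained sketch of that dispatch; evaluate_expression and render_jinja are stand-ins for the real base_evaluate and render_template helpers (the real evaluator is sandboxed, not a bare eval), and resolve_prompt is a hypothetical name used only for this illustration.

import asyncio
from types import SimpleNamespace

EVAL_PROMPT_PREFIX = "$_ "

async def evaluate_expression(expr: str, context: dict):
    # Stand-in for agents_api's base_evaluate; the real helper evaluates the
    # expression in a sandbox, this illustration just uses eval().
    return eval(expr, {}, context)

async def render_jinja(messages, context: dict, skip_vars=None):
    # Stand-in for agents_api's render_template; passes messages through untouched.
    return messages

async def resolve_prompt(prompt, context_data: dict):
    # Mirrors the dispatch added to prompt_step in this commit.
    should_evaluate = isinstance(prompt, str) and prompt.startswith(EVAL_PROMPT_PREFIX)

    if should_evaluate:
        prompt = await evaluate_expression(
            prompt[len(EVAL_PROMPT_PREFIX):].strip(), context_data
        )
        if not isinstance(prompt, (str, list)):
            raise ValueError("Invalid prompt expression, expected a string or list")

    # Wrap a bare string into a single user message
    if isinstance(prompt, str):
        prompt = [{"role": "user", "content": prompt}]

    # Template rendering only applies when the prompt was not an expression
    if not should_evaluate:
        prompt = await render_jinja(prompt, context_data, skip_vars=["developer_id"])

    return prompt

messages = asyncio.run(
    resolve_prompt(
        "$_ [{'role': 'user', 'content': _.test}]",
        {"_": SimpleNamespace(test="input")},
    )
)
print(messages)  # [{'role': 'user', 'content': 'input'}]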
1 change: 0 additions & 1 deletion agents-api/agents_api/autogen/Tools.py
@@ -12,7 +12,6 @@
     BaseModel,
     ConfigDict,
     Field,
-    RootModel,
     StrictBool,
 )
 
18 changes: 17 additions & 1 deletion agents-api/agents_api/common/protocol/tasks.py
@@ -1,8 +1,9 @@
-from typing import Annotated, Any
+from typing import Annotated, Any, Literal
 from uuid import UUID
 
 from beartype import beartype
 from temporalio import activity, workflow
+from temporalio.exceptions import ApplicationError
 
 with workflow.unsafe.imports_passed_through():
     from pydantic import BaseModel, Field, computed_field
@@ -23,6 +24,7 @@
     TaskSpecDef,
     TaskToolDef,
     Tool,
+    ToolRef,
     TransitionTarget,
     TransitionType,
     UpdateTaskRequest,
@@ -154,6 +156,20 @@ def tools(self) -> list[Tool | CreateToolRequest]:
         task = execution_input.task
         agent_tools = execution_input.agent_tools
 
+        step_tools: Literal["all"] | list[ToolRef | CreateToolRequest] = getattr(
+            self.current_step, "tools", "all"
+        )
+
+        if step_tools != "all":
+            if not all(
+                tool and isinstance(tool, CreateToolRequest) for tool in step_tools
+            ):
+                raise ApplicationError(
+                    "Invalid tools for step (ToolRef not supported yet)"
+                )
+
+            return step_tools
+
         # Need to convert task.tools (list[TaskToolDef]) to list[Tool]
         task_tools = []
         for tool in task.tools:
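The new guard in StepContext.tools pairs with the typespec change further down: a step may declare its own tools list, but only inline CreateToolRequest entries are accepted for now, and ToolRef entries are rejected. A standalone sketch of the same check, using stand-in dataclasses in place of the generated autogen models:

from dataclasses import dataclass
from typing import Literal

@dataclass
class CreateToolRequest:  # stand-in for the generated autogen model
    name: str

@dataclass
class ToolRef:  # stand-in for the generated autogen model
    ref: str

def validate_step_tools(step_tools: Literal["all"] | list) -> None:
    # Same shape as the new check in StepContext.tools: references to existing
    # tools cannot be resolved yet, so only inline tool definitions pass.
    if step_tools == "all":
        return
    if not all(tool and isinstance(tool, CreateToolRequest) for tool in step_tools):
        raise ValueError("Invalid tools for step (ToolRef not supported yet)")

validate_step_tools("all")                               # ok
validate_step_tools([CreateToolRequest(name="search")])  # ok
# validate_step_tools([ToolRef(ref="tool_123")])         # would raise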
28 changes: 6 additions & 22 deletions agents-api/agents_api/workflows/task_execution/__init__.py
@@ -202,6 +202,9 @@ async def run(
             retry_policy=DEFAULT_RETRY_POLICY,
         )
 
+        # Init state
+        state = None
+
         match context.current_step, outcome:
             # Handle errors (activity returns None)
             case step, StepOutcome(error=error) if error is not None:
@@ -371,22 +374,9 @@
 
                 state = PartialTransition(type="resume", output=result)
 
-            case PromptStep(unwrap=True), StepOutcome(output=response):
-                finish_reason = response["choices"][0]["finish_reason"]
-                if finish_reason == "tool_calls":
-                    workflow.logger.error(
-                        "Prompt step: Tool calls not supported in unwrap mode"
-                    )
-
-                    state = PartialTransition(
-                        type="error", output="Tool calls not supported in unwrap mode"
-                    )
-                    await transition(context, state)
-
-                    raise ApplicationError("Tool calls not supported in unwrap mode")
-
-                workflow.logger.debug(f"Prompt step: Received response: {response}")
-                state = PartialTransition(output=response)
+            case PromptStep(unwrap=True), StepOutcome(output=message):
+                workflow.logger.debug(f"Prompt step: Received response: {message}")
+                state = PartialTransition(output=message)
 
             case PromptStep(auto_run_tools=False, unwrap=False), StepOutcome(
                 output=response
@@ -493,12 +483,6 @@
                 )
                 state = PartialTransition(output=response)
 
-            case SetStep(), StepOutcome(output=evaluated_output):
-                workflow.logger.info("Set step: Updating user state")
-
-            case SetStep(), StepOutcome(output=evaluated_output):
-                workflow.logger.info("Set step: Updating user state")
-
             case SetStep(), StepOutcome(output=evaluated_output):
                 workflow.logger.info("Set step: Updating user state")
 
56 changes: 56 additions & 0 deletions agents-api/tests/test_execution_workflow.py
@@ -1133,6 +1133,62 @@ async def _(
     ]
 
 
+@test("workflow: prompt step (python expression)")
+async def _(
+    client=cozo_client,
+    developer_id=test_developer_id,
+    agent=test_agent,
+):
+    mock_model_response = ModelResponse(
+        id="fake_id",
+        choices=[Choices(message={"role": "assistant", "content": "Hello, world!"})],
+        created=0,
+        object="text_completion",
+    )
+
+    with patch("agents_api.clients.litellm.acompletion") as acompletion:
+        acompletion.return_value = mock_model_response
+        data = CreateExecutionRequest(input={"test": "input"})
+
+        task = create_task(
+            developer_id=developer_id,
+            agent_id=agent.id,
+            data=CreateTaskRequest(
+                **{
+                    "name": "test task",
+                    "description": "test task about",
+                    "input_schema": {"type": "object", "additionalProperties": True},
+                    "main": [
+                        {
+                            "prompt": "$_ [{'role': 'user', 'content': _.test}]",
+                            "settings": {},
+                        },
+                    ],
+                }
+            ),
+            client=client,
+        )
+
+        async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
+            execution, handle = await start_execution(
+                developer_id=developer_id,
+                task_id=task.id,
+                data=data,
+                client=client,
+            )
+
+            assert handle is not None
+            assert execution.task_id == task.id
+            assert execution.input == data.input
+
+            mock_run_task_execution_workflow.assert_called_once()
+
+            result = await handle.result()
+            result = result["choices"][0]["message"]
+            assert result["content"] == "Hello, world!"
+            assert result["role"] == "assistant"
+
+
 @test("workflow: prompt step")
 async def _(
     client=cozo_client,
2 changes: 1 addition & 1 deletion typespec/tasks/steps.tsp
@@ -103,7 +103,7 @@ model PromptStepDef {
   prompt: JinjaTemplate | InputChatMLMessage<JinjaTemplate>[];
 
   /** The tools to use for the prompt */
-  tools: "all" | (ToolRef | CreateToolRequest)[] = #[];
+  tools: "all" | (ToolRef | CreateToolRequest)[] = "all";
 
   /** The tool choice for the prompt */
   tool_choice?: ToolChoiceOption;
