Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add python expression support to prompt step #795

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
VectorDocSearchRequest,
)
from ..autogen.Tools import SystemDef
from ..common.protocol.developers import Developer
from ..common.protocol.tasks import StepContext
from ..common.storage_handler import auto_blob_store
from ..env import testing
Expand Down
55 changes: 42 additions & 13 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ...common.utils.template import render_template
from ...env import anthropic_api_key, debug
from ..utils import get_handler
from .base_evaluate import base_evaluate

COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"

Expand Down Expand Up @@ -77,43 +78,67 @@ def format_tool(tool: Tool) -> dict:
return formatted


EVAL_PROMPT_PREFIX = "$_ "


@activity.defn
@auto_blob_store
@beartype
async def prompt_step(context: StepContext) -> StepOutcome:
# Get context data
prompt: str | list[dict] = context.current_step.model_dump()["prompt"]
context_data: dict = context.model_dump()
context_data: dict = context.model_dump(include_remote=True)

# Render template messages
prompt = await render_template(
prompt,
context_data,
skip_vars=["developer_id"],
# If the prompt is a string and starts with $_ then we need to evaluate it
should_evaluate_prompt = isinstance(prompt, str) and prompt.startswith(
EVAL_PROMPT_PREFIX
)

if should_evaluate_prompt:
prompt = await base_evaluate(
prompt[len(EVAL_PROMPT_PREFIX) :].strip(), context_data
)

if not isinstance(prompt, (str, list)):
raise ApplicationError(
"Invalid prompt expression, expected a string or list"
)

# Wrap the prompt in a list if it is not already
prompt = (
prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
)

# Render template messages if we didn't evaluate the prompt
if not should_evaluate_prompt:
# Render template messages
prompt = await render_template(
prompt,
context_data,
skip_vars=["developer_id"],
)

# Get settings and run llm
agent_default_settings: dict = (
context.execution_input.agent.default_settings.model_dump()
if context.execution_input.agent.default_settings
else {}
)

agent_model: str = (
context.execution_input.agent.model
if context.execution_input.agent.model
else "gpt-4o"
)

# Get passed settings
if context.current_step.settings:
passed_settings: dict = context.current_step.settings.model_dump(
exclude_unset=True
)
else:
passed_settings: dict = {}

# Wrap the prompt in a list if it is not already
if isinstance(prompt, str):
prompt = [{"role": "user", "content": prompt}]

# Format tools for litellm
formatted_tools = [format_tool(tool) for tool in context.tools]

Expand All @@ -132,11 +157,15 @@ async def prompt_step(context: StepContext) -> StepOutcome:
betas = [COMPUTER_USE_BETA_FLAG]
# Use Anthropic API directly
client = AsyncAnthropic(api_key=anthropic_api_key)
new_prompt = [{"role": "user", "content": prompt[0]["content"]}]

# Reformat the prompt for Anthropic
# Anthropic expects a list of messages with role and content (and no name etc)
prompt = [{"role": "user", "content": message["content"]} for message in prompt]

# Claude Response
claude_response: BetaMessage = await client.beta.messages.create(
model="claude-3-5-sonnet-20241022",
messages=new_prompt,
messages=prompt,
tools=formatted_tools,
max_tokens=1024,
betas=betas,
Expand Down Expand Up @@ -210,7 +239,7 @@ async def prompt_step(context: StepContext) -> StepOutcome:
}

extra_body = {
"cache": {"no-cache": debug},
"cache": {"no-cache": debug or context.current_step.disable_cache},
}

response: ModelResponse = await litellm.acompletion(
Expand Down
40 changes: 15 additions & 25 deletions agents-api/agents_api/autogen/Sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,10 @@ class CreateSessionRequest(BaseModel):
"""
Action to start on context window overflow
"""
forward_tool_results: StrictBool | None = None
auto_run_tools: StrictBool = False
"""
Whether to forward the tool results to the model when available.
"true" => always forward
"false" => never forward
null => forward if applicable (default)
Whether to auto-run the tool and send the tool results to the model when available.
(default: false for sessions, true for tasks)

If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
Expand Down Expand Up @@ -80,12 +78,10 @@ class PatchSessionRequest(BaseModel):
"""
Action to start on context window overflow
"""
forward_tool_results: StrictBool | None = None
auto_run_tools: StrictBool = False
"""
Whether to forward the tool results to the model when available.
"true" => always forward
"false" => never forward
null => forward if applicable (default)
Whether to auto-run the tool and send the tool results to the model when available.
(default: false for sessions, true for tasks)

If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
Expand Down Expand Up @@ -117,12 +113,10 @@ class Session(BaseModel):
"""
Action to start on context window overflow
"""
forward_tool_results: StrictBool | None = None
auto_run_tools: StrictBool = False
"""
Whether to forward the tool results to the model when available.
"true" => always forward
"false" => never forward
null => forward if applicable (default)
Whether to auto-run the tool and send the tool results to the model when available.
(default: false for sessions, true for tasks)

If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
Expand Down Expand Up @@ -190,12 +184,10 @@ class UpdateSessionRequest(BaseModel):
"""
Action to start on context window overflow
"""
forward_tool_results: StrictBool | None = None
auto_run_tools: StrictBool = False
"""
Whether to forward the tool results to the model when available.
"true" => always forward
"false" => never forward
null => forward if applicable (default)
Whether to auto-run the tool and send the tool results to the model when available.
(default: false for sessions, true for tasks)

If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
Expand Down Expand Up @@ -234,12 +226,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest):
"""
Action to start on context window overflow
"""
forward_tool_results: StrictBool | None = None
auto_run_tools: StrictBool = False
"""
Whether to forward the tool results to the model when available.
"true" => always forward
"false" => never forward
null => forward if applicable (default)
Whether to auto-run the tool and send the tool results to the model when available.
(default: false for sessions, true for tasks)

If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
Expand Down
28 changes: 16 additions & 12 deletions agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ class PromptStep(BaseModel):
"""
The prompt to run
"""
tools: Literal["all"] | list[ToolRef | CreateToolRequest] = []
tools: Literal["all"] | list[ToolRef | CreateToolRequest] = "all"
"""
The tools to use for the prompt
"""
Expand All @@ -702,16 +702,18 @@ class PromptStep(BaseModel):
"""
Whether to unwrap the output of the prompt step, equivalent to `response.choices[0].message.content`
"""
forward_tool_results: StrictBool | None = None
auto_run_tools: StrictBool = True
"""
Whether to forward the tool results to the model when available.
"true" => always forward
"false" => never forward
null => forward if applicable (default)
Whether to auto-run the tool and send the tool results to the model when available.
(default: true for prompt steps, false for sessions)

If a tool call is made, the tool's output will be used as the model's input.
If a tool call is not made, the model's output will be used as the next step's input.
"""
disable_cache: StrictBool = False
"""
Whether to disable caching for the prompt step
"""


class PromptStepUpdateItem(BaseModel):
Expand All @@ -730,7 +732,7 @@ class PromptStepUpdateItem(BaseModel):
"""
The prompt to run
"""
tools: Literal["all"] | list[ToolRefUpdateItem | CreateToolRequest] = []
tools: Literal["all"] | list[ToolRefUpdateItem | CreateToolRequest] = "all"
"""
The tools to use for the prompt
"""
Expand All @@ -746,16 +748,18 @@ class PromptStepUpdateItem(BaseModel):
"""
Whether to unwrap the output of the prompt step, equivalent to `response.choices[0].message.content`
"""
forward_tool_results: StrictBool | None = None
auto_run_tools: StrictBool = True
"""
Whether to forward the tool results to the model when available.
"true" => always forward
"false" => never forward
null => forward if applicable (default)
Whether to auto-run the tool and send the tool results to the model when available.
(default: true for prompt steps, false for sessions)

If a tool call is made, the tool's output will be used as the model's input.
If a tool call is not made, the model's output will be used as the next step's input.
"""
disable_cache: StrictBool = False
"""
Whether to disable caching for the prompt step
"""


class ReturnStep(BaseModel):
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/autogen/Tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
BaseModel,
ConfigDict,
Field,
RootModel,
StrictBool,
)

Expand Down
18 changes: 17 additions & 1 deletion agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Annotated, Any
from typing import Annotated, Any, Literal
from uuid import UUID

from beartype import beartype
from temporalio import activity, workflow
from temporalio.exceptions import ApplicationError

with workflow.unsafe.imports_passed_through():
from pydantic import BaseModel, Field, computed_field
Expand All @@ -23,6 +24,7 @@
TaskSpecDef,
TaskToolDef,
Tool,
ToolRef,
TransitionTarget,
TransitionType,
UpdateTaskRequest,
Expand Down Expand Up @@ -154,6 +156,20 @@ def tools(self) -> list[Tool | CreateToolRequest]:
task = execution_input.task
agent_tools = execution_input.agent_tools

step_tools: Literal["all"] | list[ToolRef | CreateToolRequest] = getattr(
self.current_step, "tools", "all"
)

if step_tools != "all":
if not all(
tool and isinstance(tool, CreateToolRequest) for tool in step_tools
):
raise ApplicationError(
"Invalid tools for step (ToolRef not supported yet)"
)

return step_tools

# Need to convert task.tools (list[TaskToolDef]) to list[Tool]
task_tools = []
for tool in task.tools:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def create_or_update_session(
data: CreateOrUpdateSessionRequest,
) -> tuple[list[str], dict]:
data.metadata = data.metadata or {}
session_data = data.model_dump()
session_data = data.model_dump(exclude={"auto_run_tools", "disable_cache"})

user = session_data.pop("user")
agent = session_data.pop("agent")
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/create_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def create_session(
session_id = session_id or uuid4()

data.metadata = data.metadata or {}
session_data = data.model_dump()
session_data = data.model_dump(exclude={"auto_run_tools", "disable_cache"})

user = session_data.pop("user")
agent = session_data.pop("agent")
Expand Down
Loading
Loading