fix(agents-api): updated the prompt step
Vedantsahai18 committed Nov 22, 2024
1 parent bea2681 commit dbe6ae9
Showing 3 changed files with 42 additions and 128 deletions.
168 changes: 42 additions & 126 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
@@ -1,13 +1,10 @@
from datetime import datetime
from typing import Callable

from anthropic import AsyncAnthropic # Import AsyncAnthropic client
from anthropic.types.beta.beta_message import BetaMessage
from beartype import beartype
from langchain_core.tools import BaseTool
from langchain_core.tools.convert import tool as tool_decorator
from litellm import ChatCompletionMessageToolCall, Function, Message
from litellm.types.utils import Choices, ModelResponse
from litellm.types.utils import ModelResponse
from temporalio import activity
from temporalio.exceptions import ApplicationError

@@ -18,30 +15,13 @@
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...common.utils.template import render_template
from ...env import anthropic_api_key, debug
from ..utils import get_handler_with_filtered_params
from .base_evaluate import base_evaluate

COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"


def format_tool(tool: Tool) -> dict:
# FIXME: Wrong format for computer_20241022 for litellm
if tool.type == "computer_20241022":
return {
"type": tool.type,
"name": tool.name,
"display_width_px": tool.computer_20241022
and tool.computer_20241022.display_width_px,
"display_height_px": tool.computer_20241022
and tool.computer_20241022.display_height_px,
"display_number": tool.computer_20241022
and tool.computer_20241022.display_number,
}

if tool.type in ["bash_20241022", "text_editor_20241022"]:
return tool.model_dump(include={"type", "name"})

if tool.type == "function":
return {
"type": "function",
@@ -163,114 +143,50 @@ async def prompt_step(context: StepContext) -> StepOutcome:
for fmt_tool, orig_tool in zip(formatted_tools, context.tools)
}

# Check if the model is Anthropic
if agent_model.lower().startswith("claude-3.5") and any(
tool["type"] in ["computer_20241022", "bash_20241022", "text_editor_20241022"]
for tool in formatted_tools
):
# Retrieve the API key from the environment variable
betas = [COMPUTER_USE_BETA_FLAG]
# Use Anthropic API directly
client = AsyncAnthropic(api_key=anthropic_api_key)

# Reformat the prompt for Anthropic
# Anthropic expects a list of messages with role and content (and no name etc)
prompt = [{"role": "user", "content": message["content"]} for message in prompt]

# Filter tools for specific types
filtered_tools = [
tool
for tool in formatted_tools
if tool["type"]
in ["computer_20241022", "bash_20241022", "text_editor_20241022"]
]

# Claude Response
claude_response: BetaMessage = await client.beta.messages.create(
model="claude-3-5-sonnet-20241022",
messages=prompt,
tools=filtered_tools,
max_tokens=1024,
betas=betas,
)

# Claude returns [ToolUse | TextBlock]
# We need to convert tool_use to tool_calls
# And set content = TextBlock.text
# But we need to ensure no more than one text block is returned
if (
len([block for block in claude_response.content if block.type == "text"])
> 1
):
raise ApplicationError("Claude should only return one message")

text_block = next(
(block for block in claude_response.content if block.type == "text"),
None,
)

stop_reason = claude_response.stop_reason

if stop_reason == "tool_use":
choice = Choices(
message=Message(
role="assistant",
content=text_block.text if text_block else None,
tool_calls=[
ChatCompletionMessageToolCall(
type="function",
function=Function(
name=block.name,
arguments=block.input,
),
)
for block in claude_response.content
if block.type == "tool_use"
],
),
finish_reason="tool_calls",
)
else:
assert (
text_block
), "Claude should always return a text block for stop_reason=stop"

choice = Choices(
message=Message(
role="assistant",
content=text_block.text,
),
finish_reason="stop",
)

response: ModelResponse = ModelResponse(
id=claude_response.id,
choices=[choice],
created=int(datetime.now().timestamp()),
model=claude_response.model,
object="text_completion",
)

else:
# FIXME: hardcoded tool to a None value as the tool calls are not implemented yet
# Check if using Claude model and has specific tool types
is_claude_model = agent_model.lower().startswith("claude-3.5")

# FIXME: Hack to make the computer use tools compatible with litellm
# Issue was: litellm expects type to be `computer_20241022` and spec to be
# `function` (see: https://docs.litellm.ai/docs/providers/anthropic#computer-tools)
# but we don't allow that (spec should match type).
formatted_tools = []
for i, tool in enumerate(context.tools):
if tool.type == "computer_20241022" and tool.computer_20241022:
function = tool.computer_20241022
tool = {
"type": tool.type,
"function": {
"name": tool.name,
"parameters": {
k: v
for k, v in function.model_dump().items()
if k not in ["name", "type"]
},
},
}
formatted_tools.append(tool)

if not is_claude_model:
formatted_tools = None
# Use litellm for other models
completion_data: dict = {
"model": agent_model,
"tools": formatted_tools or None,
"messages": prompt,
**agent_default_settings,
**passed_settings,
}

extra_body = {
"cache": {"no-cache": debug or context.current_step.disable_cache},
}
# Use litellm for other models
completion_data: dict = {
"model": agent_model,
"tools": formatted_tools or None,
"messages": prompt,
**agent_default_settings,
**passed_settings,
}

response: ModelResponse = await litellm.acompletion(
**completion_data,
extra_body=extra_body,
)
extra_body = {
"cache": {"no-cache": debug or context.current_step.disable_cache},
}

response: ModelResponse = await litellm.acompletion(
**completion_data,
extra_body=extra_body,
)

if context.current_step.unwrap:
if len(response.choices) != 1:
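The FIXME in the new code path explains the reshaping: litellm expects a computer-use tool to keep its type as computer_20241022 but to carry the spec under a "function" key, whereas the Tool schema here stores the spec under the matching computer_20241022 field. Below is a minimal sketch of that transformation, using a plain dict as a stand-in for the Pydantic Tool model and made-up display values; names and values are illustrative only.

    # Sketch of the computer-use tool re-wrapping for litellm.
    # The nested spec and the display values are hypothetical stand-ins
    # for the Pydantic Tool model used in the real code.
    tool = {
        "type": "computer_20241022",
        "name": "computer",
        "computer_20241022": {
            "display_width_px": 1024,   # assumed example value
            "display_height_px": 768,   # assumed example value
            "display_number": 1,        # assumed example value
        },
    }

    # litellm wants the spec under "function", so the per-tool fields are
    # copied into "parameters" (dropping "name"/"type" if present).
    formatted_tool = {
        "type": tool["type"],
        "function": {
            "name": tool["name"],
            "parameters": {
                k: v
                for k, v in tool["computer_20241022"].items()
                if k not in ["name", "type"]
            },
        },
    }

    print(formatted_tool)
    # {'type': 'computer_20241022',
    #  'function': {'name': 'computer',
    #               'parameters': {'display_width_px': 1024,
    #                              'display_height_px': 768,
    #                              'display_number': 1}}}

The resulting dict is what lands in formatted_tools and is then passed to litellm.acompletion along with the other completion settings shown in the diff.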
1 change: 0 additions & 1 deletion agents-api/agents_api/env.py
@@ -25,7 +25,6 @@
hostname: str = env.str("AGENTS_API_HOSTNAME", default="localhost")
public_port: int = env.int("AGENTS_API_PUBLIC_PORT", default=80)
api_prefix: str = env.str("AGENTS_API_PREFIX", default="")
anthropic_api_key: str = env.str("ANTHROPIC_API_KEY", default=None)

# Tasks
# -----
1 change: 0 additions & 1 deletion agents-api/docker-compose.yml
@@ -28,7 +28,6 @@ x--shared-environment: &shared-environment
S3_ENDPOINT: ${S3_ENDPOINT:-http://seaweedfs:8333}
S3_ACCESS_KEY: ${S3_ACCESS_KEY}
S3_SECRET_KEY: ${S3_SECRET_KEY}
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY}

x--base-agents-api: &base-agents-api
image: julepai/agents-api:${TAG:-dev}
