Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add support for claude computer-use #787

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,9 @@ async def execute_system(
return create_session_query(**arguments)
elif system.operation == "update":
return update_session_query(**arguments)
elif system.operation == "delete":
return update_session_query(**arguments)
# FIXME: @hamada needs to be fixed
# elif system.operation == "delete":
# return update_session_query(**arguments)
elif system.operation == "delete":
return delete_session_query(**arguments)
# TASKS
Expand Down
124 changes: 94 additions & 30 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import os
from datetime import datetime

from anthropic import Anthropic # Import Anthropic client
from anthropic import Anthropic, AsyncAnthropic # Import AsyncAnthropic client
from anthropic.types.beta.beta_message import BetaMessage
from beartype import beartype
from litellm import ChatCompletionMessageToolCall, Function, Message
from litellm.types.utils import Choices, ModelResponse
from litellm.utils import CustomStreamWrapper, ModelResponse
from temporalio import activity
from temporalio.exceptions import ApplicationError

Expand All @@ -20,29 +25,29 @@

# FIXME: This shouldn't be here.
def format_agent_tool(tool: Tool) -> dict:
if tool.function:
if tool.type == "function":
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.function.parameters,
"parameters": tool.parameters,
},
}
elif tool.computer_20241022:
elif tool.type == "computer_20241022":
return {
"type": tool.type,
"name": tool.name,
"display_width_px": tool.display_width_px,
"display_height_px": tool.display_height_px,
"display_number": tool.display_number,
"display_width_px": tool.spec["display_width_px"],
"display_height_px": tool.spec["display_height_px"],
"display_number": tool.spec["display_number"],
}
elif tool.bash_20241022:
elif tool.type == "bash_20241022":
return {
"type": tool.type,
"name": tool.name,
}
elif tool.text_editor_20241022:
elif tool.type == "text_editor_20241022":
return {
"type": tool.type,
"name": tool.name,
Expand Down Expand Up @@ -86,7 +91,7 @@ async def prompt_step(context: StepContext) -> StepOutcome:
sort_by="created_at",
direction="desc",
)

# grab the tools from context.current_step.tools and then append it to the agent_tools
if context.current_step.settings:
passed_settings: dict = context.current_step.settings.model_dump(
exclude_unset=True
Expand All @@ -99,32 +104,93 @@ async def prompt_step(context: StepContext) -> StepOutcome:
prompt = [{"role": "user", "content": prompt}]

# Format agent_tools for litellm
formatted_agent_tools = [
format_agent_tool(tool) for tool in agent_tools if format_agent_tool(tool)
]
# Initialize the formatted_agent_tools with context tools
task_tools = context.tools
formatted_agent_tools = [format_agent_tool(tool) for tool in task_tools]
# Add agent_tools if they are not already in formatted_agent_tools
for agent_tool in agent_tools:
if (
format_agent_tool(agent_tool)
and format_agent_tool(agent_tool) not in formatted_agent_tools
):
formatted_agent_tools.append(format_agent_tool(agent_tool))

# Check if the model is Anthropic
if "claude-3.5-sonnet-20241022" == agent_model.lower():
if agent_model.lower().startswith("claude-3.5") and any(
tool["type"] in ["computer_20241022", "bash_20241022", "text_editor_20241022"]
for tool in formatted_agent_tools
):
# Retrieve the API key from the environment variable
betas = [COMPUTER_USE_BETA_FLAG]
# Use Anthropic API directly
client = Anthropic(api_key=anthropic_api_key)

client = AsyncAnthropic(api_key=anthropic_api_key)
new_prompt = [{"role": "user", "content": prompt[0]["content"]}]
# Claude Response
response = await client.beta.messages.create(
model=agent_model,
messages=prompt,
claude_response: BetaMessage = await client.beta.messages.create(
model="claude-3-5-sonnet-20241022",
messages=new_prompt,
tools=formatted_agent_tools,
max_tokens=1024,
betas=betas,
)

return StepOutcome(
output=response.model_dump()
if hasattr(response, "model_dump")
else response,
next=None,
# Claude returns [ToolUse | TextBlock]
# We need to convert tool_use to tool_calls
# And set content = TextBlock.text
# But we need to ensure no more than one text block is returned
if (
len([block for block in claude_response.content if block.type == "text"])
> 1
):
raise ApplicationError("Claude should only return one message")

text_block = next(
(block for block in claude_response.content if block.type == "text"),
None,
)

stop_reason = claude_response.stop_reason

if stop_reason == "tool_use":
choice = Choices(
message=Message(
role="assistant",
content=text_block.text if text_block else None,
tool_calls=[
ChatCompletionMessageToolCall(
type="function",
function=Function(
name=block.name,
arguments=block.input,
),
)
for block in claude_response.content
if block.type == "tool_use"
],
),
finish_reason="tool_calls",
)
else:
assert (
text_block
), "Claude should always return a text block for stop_reason=stop"

choice = Choices(
message=Message(
role="assistant",
content=text_block.text,
),
finish_reason="stop",
)

response: ModelResponse = ModelResponse(
id=claude_response.id,
choices=[choice],
created=int(datetime.now().timestamp()),
model=claude_response.model,
object="text_completion",
)

else:
# Use litellm for other models
completion_data: dict = {
Expand All @@ -150,9 +216,7 @@ async def prompt_step(context: StepContext) -> StepOutcome:

response = response.choices[0].message.content

return StepOutcome(
output=response.model_dump()
if hasattr(response, "model_dump")
else response,
next=None,
)
return StepOutcome(
output=response.model_dump() if hasattr(response, "model_dump") else response,
next=None,
)
51 changes: 51 additions & 0 deletions agents-api/agents_api/autogen/Tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,18 @@ class CreateToolRequest(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
type: Literal[
"function",
"integration",
"system",
"api_call",
"computer_20241022",
"text_editor_20241022",
"bash_20241022",
]
"""
Type of the tool
"""
description: str | None = None
"""
Description of the tool
Expand Down Expand Up @@ -953,6 +965,21 @@ class PatchToolRequest(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
type: (
Literal[
"function",
"integration",
"system",
"api_call",
"computer_20241022",
"text_editor_20241022",
"bash_20241022",
]
| None
) = None
"""
Type of the tool
"""
description: str | None = None
"""
Description of the tool
Expand Down Expand Up @@ -1380,6 +1407,18 @@ class Tool(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
type: Literal[
"function",
"integration",
"system",
"api_call",
"computer_20241022",
"text_editor_20241022",
"bash_20241022",
]
"""
Type of the tool
"""
description: str | None = None
"""
Description of the tool
Expand Down Expand Up @@ -1457,6 +1496,18 @@ class UpdateToolRequest(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
type: Literal[
"function",
"integration",
"system",
"api_call",
"computer_20241022",
"text_editor_20241022",
"bash_20241022",
]
"""
Type of the tool
"""
description: str | None = None
"""
Description of the tool
Expand Down
11 changes: 5 additions & 6 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,12 @@ def task_to_spec(
for tool in task.tools:
tool_spec = getattr(tool, tool.type)

tools.append(
TaskToolDef(
type=tool.type,
spec=tool_spec.model_dump(),
**tool.model_dump(exclude={"type"}),
)
tool_obj = dict(
type=tool.type,
spec=tool_spec.model_dump(),
**tool.model_dump(exclude={"type"}),
)
tools.append(TaskToolDef(**tool_obj))

workflows = [Workflow(name="main", steps=task_data.pop("main"))]

Expand Down
51 changes: 51 additions & 0 deletions integrations-service/integrations/autogen/Tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,18 @@ class CreateToolRequest(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
type: Literal[
"function",
"integration",
"system",
"api_call",
"computer_20241022",
"text_editor_20241022",
"bash_20241022",
]
"""
Type of the tool
"""
description: str | None = None
"""
Description of the tool
Expand Down Expand Up @@ -953,6 +965,21 @@ class PatchToolRequest(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
type: (
Literal[
"function",
"integration",
"system",
"api_call",
"computer_20241022",
"text_editor_20241022",
"bash_20241022",
]
| None
) = None
"""
Type of the tool
"""
description: str | None = None
"""
Description of the tool
Expand Down Expand Up @@ -1380,6 +1407,18 @@ class Tool(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
type: Literal[
"function",
"integration",
"system",
"api_call",
"computer_20241022",
"text_editor_20241022",
"bash_20241022",
]
"""
Type of the tool
"""
description: str | None = None
"""
Description of the tool
Expand Down Expand Up @@ -1457,6 +1496,18 @@ class UpdateToolRequest(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
type: Literal[
"function",
"integration",
"system",
"api_call",
"computer_20241022",
"text_editor_20241022",
"bash_20241022",
]
"""
Type of the tool
"""
description: str | None = None
"""
Description of the tool
Expand Down
3 changes: 3 additions & 0 deletions typespec/tools/models.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ model Tool {
/** Name of the tool (must be unique for this agent and a valid Python identifier string) */
name: validPythonIdentifier;

/** Type of the tool */
type: ToolType;

/** Description of the tool */
description?: string;

Expand Down
Loading
Loading