Skip to content

Commit

Permalink
fix(agents-api): Fix tool conversion issues
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Oct 31, 2024
1 parent 6e9e8aa commit c281c8c
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 21 deletions.
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Callable
from typing import Any
from uuid import UUID

from beartype import beartype
Expand Down
66 changes: 52 additions & 14 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import os
from datetime import datetime

from anthropic import Anthropic, AsyncAnthropic # Import AsyncAnthropic client
from anthropic import AsyncAnthropic # Import AsyncAnthropic client
from anthropic.types.beta.beta_message import BetaMessage
from beartype import beartype
from langchain_core.tools import BaseTool
from langchain_core.tools.convert import tool as tool_decorator
from litellm import ChatCompletionMessageToolCall, Function, Message
from litellm.types.utils import Choices, ModelResponse
from litellm.utils import CustomStreamWrapper, ModelResponse
from temporalio import activity
from temporalio.exceptions import ApplicationError

Expand All @@ -20,7 +18,6 @@
from ...common.storage_handler import auto_blob_store
from ...common.utils.template import render_template
from ...env import anthropic_api_key, debug
from ...models.tools.list_tools import list_tools
from ..utils import get_handler

COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
Expand All @@ -31,9 +28,12 @@ def format_tool(tool: Tool) -> dict:
return {
"type": tool.type,
"name": tool.name,
"display_width_px": tool.computer_20241022.display_width_px,
"display_height_px": tool.computer_20241022.display_height_px,
"display_number": tool.computer_20241022.display_number,
"display_width_px": tool.computer_20241022
and tool.computer_20241022.display_width_px,
"display_height_px": tool.computer_20241022
and tool.computer_20241022.display_height_px,
"display_number": tool.computer_20241022
and tool.computer_20241022.display_number,
}

if tool.type in ["bash_20241022", "text_editor_20241022"]:
Expand All @@ -45,7 +45,7 @@ def format_tool(tool: Tool) -> dict:
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.function.parameters,
"parameters": tool.function and tool.function.parameters,
},
}

Expand Down Expand Up @@ -117,6 +117,12 @@ async def prompt_step(context: StepContext) -> StepOutcome:
# Format tools for litellm
formatted_tools = [format_tool(tool) for tool in context.tools]

# Map tools to their original objects
tools_mapping: dict[str, Tool] = {
fmt_tool["function"]["name"]: orig_tool
for fmt_tool, orig_tool in zip(formatted_tools, context.tools)
}

# Check if the model is Anthropic
if agent_model.lower().startswith("claude-3.5") and any(
tool["type"] in ["computer_20241022", "bash_20241022", "text_editor_20241022"]
Expand Down Expand Up @@ -207,18 +213,50 @@ async def prompt_step(context: StepContext) -> StepOutcome:
"cache": {"no-cache": debug},
}

response = await litellm.acompletion(
response: ModelResponse = await litellm.acompletion(
**completion_data,
extra_body=extra_body,
)

if context.current_step.unwrap:
if response.choices[0].finish_reason == "tool_calls":
raise ApplicationError("Tool calls cannot be unwrapped")
if context.current_step.unwrap:
if len(response.choices) != 1:
raise ApplicationError("Only one choice is supported")

choice = response.choices[0]
if choice.finish_reason == "tool_calls":
raise ApplicationError("Tool calls cannot be unwrapped")

return StepOutcome(
output=choice.message.content,
next=None,
)

# Re-convert tool-calls to integration/system calls if needed
response_as_dict = response.model_dump()

for choice in response_as_dict["choices"]:
if choice["finish_reason"] == "tool_calls":
calls = choice["message"]["tool_calls"]

for call in calls:
call_name = call["function"]["name"]
call_args = call["function"]["arguments"]

original_tool = tools_mapping.get(call_name)
if not original_tool:
raise ApplicationError(f"Tool {call_name} not found")

if original_tool.type == "function":
continue

response = response.choices[0].message.content
call.pop("function")
call["type"] = original_tool.type
call[original_tool.type] = {
"name": call_name,
"arguments": call_args,
}

return StepOutcome(
output=response.model_dump() if hasattr(response, "model_dump") else response,
output=response_as_dict,
next=None,
)
1 change: 0 additions & 1 deletion agents-api/agents_api/autogen/Tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
BaseModel,
ConfigDict,
Field,
RootModel,
StrictBool,
)

Expand Down
4 changes: 0 additions & 4 deletions agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import json
import re
from typing import List, TypeVar

import arrow
import re2
from beartype import beartype
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2schema import infer, to_json_schema
from jsonschema import validate

from ...activities.utils import ALLOWED_FUNCTIONS, constants, stdlib
from . import yaml

__all__: List[str] = [
"render_template",
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
SleepFor,
SleepStep,
SwitchStep,
TaskToolDef,
ToolCallStep,
TransitionTarget,
WaitForInputStep,
Expand Down

0 comments on commit c281c8c

Please sign in to comment.