feat(agents-api): claude response bug fix in prompt step
Vedantsahai18 committed Oct 31, 2024
1 parent 60ab65b commit 5867ecb
Showing 1 changed file with 67 additions and 17 deletions.
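In short: for Claude computer-use requests, the prompt step used to return Anthropic's raw BetaMessage, which downstream consumers expecting a litellm-style ModelResponse could not read. This commit converts the BetaMessage into a litellm ModelResponse before building the StepOutcome. A minimal sketch of the kind of consumer this normalization serves (names here are hypothetical, not from this repo):

# Hypothetical downstream consumer: after this fix, both the Claude branch and
# the litellm branch hand it the same ModelResponse shape.
def extract_step_result(response):
    choice = response.choices[0]
    if choice.finish_reason == "tool_calls":
        # OpenAI-style tool calls: (name, JSON-string arguments) pairs
        return [
            (tc.function.name, tc.function.arguments)
            for tc in choice.message.tool_calls
        ]
    return choice.message.content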
84 changes: 67 additions & 17 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
@@ -1,7 +1,11 @@
import os
from datetime import datetime

from anthropic import Anthropic, AsyncAnthropic # Import AsyncAnthropic client
from anthropic.types.beta.beta_message import BetaMessage
from beartype import beartype
from litellm import ChatCompletionMessageToolCall, Function, Message
from litellm.types.utils import Choices, ModelResponse
from litellm.utils import CustomStreamWrapper, ModelResponse
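# Note: ModelResponse is imported twice; litellm.utils appears to re-export the
# same class from litellm.types.utils, so the duplication is redundant but harmless.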
from temporalio import activity
from temporalio.exceptions import ApplicationError
@@ -27,7 +31,7 @@ def format_agent_tool(tool: Tool) -> dict:
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.function.parameters,
"parameters": tool.parameters,
},
}
elif tool.type == "computer_20241022":
@@ -122,23 +126,71 @@ async def prompt_step(context: StepContext) -> StepOutcome:
client = AsyncAnthropic(api_key=anthropic_api_key)
new_prompt = [{"role": "user", "content": prompt[0]["content"]}]
# Claude Response
response = await client.beta.messages.create(
claude_response: BetaMessage = await client.beta.messages.create(
model="claude-3-5-sonnet-20241022",
messages=new_prompt,
tools=formatted_agent_tools,
max_tokens=1024,
betas=betas,
)
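# Assumption: `betas` (defined outside this hunk) carries the computer-use beta
# flag, e.g. "computer-use-2024-10-22", required by the computer_20241022 tool type.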
print("-" * 100)
# print("Model Response: ", model_response)
print("Response: ", response)

return StepOutcome(
output=response.model_dump()
if hasattr(response, "model_dump")
else response,
next=None,

# Claude returns [ToolUse | TextBlock]
# We need to convert tool_use to tool_calls
# And set content = TextBlock.text
# But we need to ensure no more than one text block is returned
if (
len([block for block in claude_response.content if block.type == "text"])
> 1
):
raise ApplicationError("Claude should only return one message")

text_block = next(
(block for block in claude_response.content if block.type == "text"),
None,
)

stop_reason = claude_response.stop_reason

if stop_reason == "tool_use":
choice = Choices(
message=Message(
role="assistant",
content=text_block.text if text_block else None,
tool_calls=[
ChatCompletionMessageToolCall(
type="function",
function=Function(
name=block.name,
arguments=block.input,
),
)
for block in claude_response.content
if block.type == "tool_use"
],
),
finish_reason="tool_calls",
)
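# Note: block.input arrives as a dict; litellm's Function is assumed to
# serialize it into the OpenAI-style JSON-string `arguments` field.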
else:
assert (
text_block
), "Claude should always return a text block for stop_reason=stop"

choice = Choices(
message=Message(
role="assistant",
content=text_block.text,
),
finish_reason="stop",
)

response: ModelResponse = ModelResponse(
id=claude_response.id,
choices=[choice],
created=int(datetime.now().timestamp()),
model=claude_response.model,
object="text_completion",
)
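# Re-packaging the Claude reply as a litellm ModelResponse keeps the
# StepOutcome payload uniform with the litellm branch below.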

else:
# Use litellm for other models
completion_data: dict = {
@@ -164,9 +216,7 @@ async def prompt_step(context: StepContext) -> StepOutcome:

response = response.choices[0].message.content

return StepOutcome(
output=response.model_dump()
if hasattr(response, "model_dump")
else response,
next=None,
)
return StepOutcome(
output=response.model_dump() if hasattr(response, "model_dump") else response,
next=None,
)
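For the Claude branch, response.model_dump() should yield a payload roughly like the following (illustrative values only, inferred from the ModelResponse construction above; extra ModelResponse fields omitted):

# Sketch of the Claude-branch StepOutcome.output
{
    "id": "msg_01...",                      # claude_response.id
    "created": 1730368800,                  # int(datetime.now().timestamp())
    "model": "claude-3-5-sonnet-20241022",  # claude_response.model
    "object": "text_completion",
    "choices": [
        {
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "..."},
        }
    ],
}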
