Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine Message model and encapsulate tool calls for AI provider integration #2

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aisuite/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .client import Client
from .framework.message import Message
1 change: 1 addition & 0 deletions aisuite/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .provider_interface import ProviderInterface
from .chat_completion_response import ChatCompletionResponse
from .message import Message
6 changes: 5 additions & 1 deletion aisuite/framework/choice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from aisuite.framework.message import Message
from typing import Literal, Optional


class Choice:
    """One completion choice in an OpenAI-style chat response."""

    def __init__(self):
        # Mirrors OpenAI's finish_reason values; None until the provider
        # sets it ("stop" for a normal completion, "tool_calls" when the
        # model requested a tool invocation).
        self.finish_reason: Optional[Literal["stop", "tool_calls"]] = None
        # Start with an empty assistant message; providers overwrite it
        # in normalize_response(). All fields are passed explicitly
        # because Message is a pydantic model.
        self.message = Message(
            content=None, tool_calls=None, role="assistant", refusal=None
        )
22 changes: 19 additions & 3 deletions aisuite/framework/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
"""Interface to hold contents of api responses when they do not conform to the OpenAI style response"""

from pydantic import BaseModel
from typing import Literal, Optional

class Message:
def __init__(self):
self.content = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Utilize enums to enforce constraints on the 'role' and 'type' fields, enhancing robustness and maintaining type safety.

Use Enumerations to enforce valid values and enhance type safety in role and type fields.

Suggested change
+ content: Optional[str]
+ tool_calls: Optional[list[ChatCompletionMessageToolCall]]
+ role: Optional[Literal["user", "assistant", "system"]] = Field(..., description="Role of the message sender", enum=["user", "assistant", "system"])
+ refusal: Optional[str]

class Function(BaseModel):
    """Name and JSON-encoded arguments of a requested tool/function call."""

    arguments: str  # JSON string of the call's arguments
    name: str


class ChatCompletionMessageToolCall(BaseModel):
    """A single tool call requested by the assistant (OpenAI-compatible shape)."""

    id: str
    function: Function
    type: Literal["function"]


class Message(BaseModel):
    """OpenAI-style chat message used to normalize provider responses.

    Fields default to None so plain text messages can be constructed
    without spelling out every field; existing call sites that pass all
    fields keep working unchanged.
    """

    content: Optional[str] = None
    tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None
    # "tool" is included because providers convert tool-result messages,
    # which carry role="tool" plus a tool_call_id.
    role: Optional[Literal["user", "assistant", "system", "tool"]] = None
    # Set only on role="tool" messages: the id of the tool call this
    # message is the result of.
    tool_call_id: Optional[str] = None
    refusal: Optional[str] = None
176 changes: 171 additions & 5 deletions aisuite/providers/anthropic_provider.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import anthropic
import json
from aisuite.provider import Provider
from aisuite.framework import ChatCompletionResponse
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function

# Define a constant for the default max_tokens value
DEFAULT_MAX_TOKENS = 4096

# Links:
# Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use


class AnthropicProvider(Provider):
def __init__(self, **config):
Expand All @@ -27,14 +32,175 @@ def chat_completions_create(self, model, messages, **kwargs):
if "max_tokens" not in kwargs:
kwargs["max_tokens"] = DEFAULT_MAX_TOKENS

return self.normalize_response(
self.client.messages.create(
model=model, system=system_message, messages=messages, **kwargs
)
# Handle tool calls. Convert from OpenAI tool calls to Anthropic tool calls.
if "tools" in kwargs:
kwargs["tools"] = convert_openai_tools_to_anthropic(kwargs["tools"])

# Convert tool results from OpenAI format to Anthropic format
converted_messages = []
for msg in messages:
if isinstance(msg, dict):
if msg["role"] == "tool":
# Convert tool result message
converted_msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg["tool_call_id"],
"content": msg["content"],
}
],
}
converted_messages.append(converted_msg)
elif msg["role"] == "assistant" and "tool_calls" in msg:
# Handle assistant messages with tool calls
content = []
if msg.get("content"):
content.append({"type": "text", "text": msg["content"]})
for tool_call in msg["tool_calls"]:
content.append(
{
"type": "tool_use",
"id": tool_call["id"],
"name": tool_call["function"]["name"],
"input": json.loads(tool_call["function"]["arguments"]),
}
)
converted_messages.append({"role": "assistant", "content": content})
else:
# Keep other messages as is
converted_messages.append(
{"role": msg["role"], "content": msg["content"]}
)
else:
# Handle Message objects
if msg.role == "tool":
converted_msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg.tool_call_id,
"content": msg.content,
}
],
}
converted_messages.append(converted_msg)
elif msg.role == "assistant" and msg.tool_calls:
# Handle Message objects with tool calls
content = []
if msg.content:
content.append({"type": "text", "text": msg.content})
for tool_call in msg.tool_calls:
content.append(
{
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": json.loads(tool_call.function.arguments),
}
)
converted_messages.append({"role": "assistant", "content": content})
else:
converted_messages.append(
{"role": msg.role, "content": msg.content}
)

print(converted_messages)
response = self.client.messages.create(
model=model, system=system_message, messages=converted_messages, **kwargs
)
print(response)
return self.normalize_response(response)

def normalize_response(self, response):
"""Normalize the response from the Anthropic API to match OpenAI's response format."""
normalized_response = ChatCompletionResponse()
normalized_response.choices[0].message.content = response.content[0].text

# Map Anthropic stop_reason to OpenAI finish_reason
finish_reason_mapping = {
"end_turn": "stop",
"max_tokens": "length",
"tool_use": "tool_calls",
# Add more mappings as needed
}
normalized_response.choices[0].finish_reason = finish_reason_mapping.get(
response.stop_reason, "stop"
)

# Add usage information
normalized_response.usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
Comment on lines +133 to +136

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add null checks to tool call properties to avoid access errors.

Handling conversion for tools should include a null check before accessing the tool's name and arguments to prevent the application from crashing.

Suggested change
normalized_response.usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+ "type": "tool_use",
+ "id": tool_call["id"],
+ "name": tool_call.get("function", {}).get("name", ""),
+ "input": json.loads(tool_call.get("function", {}).get("arguments", "{}")),

}

# Check if the response contains tool usage
if response.stop_reason == "tool_use":
# Find the tool_use content
tool_call = next(
(content for content in response.content if content.type == "tool_use"),
None,
)

if tool_call:
function = Function(
name=tool_call.name, arguments=json.dumps(tool_call.input)
)
tool_call_obj = ChatCompletionMessageToolCall(
id=tool_call.id, function=function, type="function"
)
# Get the text content if any
text_content = next(
(
content.text
for content in response.content
if content.type == "text"
),
"",
)

message = Message(
content=text_content or None,
tool_calls=[tool_call_obj] if tool_call else None,
role="assistant",
refusal=None,
)
normalized_response.choices[0].message = message
return normalized_response

# Handle regular text response
message = Message(
content=response.content[0].text,
Comment on lines 32 to +175

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reduces code duplication and enhances maintainability by refactoring repetitive conversion code into a separate function.

Abstract the repetitive conversion logic into separate functions for clarity and reuse.

Suggested change
if "max_tokens" not in kwargs:
kwargs["max_tokens"] = DEFAULT_MAX_TOKENS
return self.normalize_response(
self.client.messages.create(
model=model, system=system_message, messages=messages, **kwargs
)
# Handle tool calls. Convert from OpenAI tool calls to Anthropic tool calls.
if "tools" in kwargs:
kwargs["tools"] = convert_openai_tools_to_anthropic(kwargs["tools"])
# Convert tool results from OpenAI format to Anthropic format
converted_messages = []
for msg in messages:
if isinstance(msg, dict):
if msg["role"] == "tool":
# Convert tool result message
converted_msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg["tool_call_id"],
"content": msg["content"],
}
],
}
converted_messages.append(converted_msg)
elif msg["role"] == "assistant" and "tool_calls" in msg:
# Handle assistant messages with tool calls
content = []
if msg.get("content"):
content.append({"type": "text", "text": msg["content"]})
for tool_call in msg["tool_calls"]:
content.append(
{
"type": "tool_use",
"id": tool_call["id"],
"name": tool_call["function"]["name"],
"input": json.loads(tool_call["function"]["arguments"]),
}
)
converted_messages.append({"role": "assistant", "content": content})
else:
# Keep other messages as is
converted_messages.append(
{"role": msg["role"], "content": msg["content"]}
)
else:
# Handle Message objects
if msg.role == "tool":
converted_msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg.tool_call_id,
"content": msg.content,
}
],
}
converted_messages.append(converted_msg)
elif msg.role == "assistant" and msg.tool_calls:
# Handle Message objects with tool calls
content = []
if msg.content:
content.append({"type": "text", "text": msg.content})
for tool_call in msg.tool_calls:
content.append(
{
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": json.loads(tool_call.function.arguments),
}
)
converted_messages.append({"role": "assistant", "content": content})
else:
converted_messages.append(
{"role": msg.role, "content": msg.content}
)
print(converted_messages)
response = self.client.messages.create(
model=model, system=system_message, messages=converted_messages, **kwargs
)
print(response)
return self.normalize_response(response)
def normalize_response(self, response):
    """Normalize an Anthropic API response to OpenAI's response format.

    Maps stop_reason to finish_reason, copies token usage, and builds an
    OpenAI-style Message. When the model requested a tool call
    (stop_reason == "tool_use"), the tool_use content block is converted
    into a ChatCompletionMessageToolCall; otherwise the first content
    block's text becomes the message content.
    """
    normalized_response = ChatCompletionResponse()

    # Map Anthropic stop_reason to OpenAI finish_reason.
    finish_reason_mapping = {
        "end_turn": "stop",
        "max_tokens": "length",
        "tool_use": "tool_calls",
        # Add more mappings as needed
    }
    normalized_response.choices[0].finish_reason = finish_reason_mapping.get(
        response.stop_reason, "stop"
    )

    # Add usage information (OpenAI-style keys).
    normalized_response.usage = {
        "prompt_tokens": response.usage.input_tokens,
        "completion_tokens": response.usage.output_tokens,
        "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
    }

    # Check if the response contains tool usage. Do this BEFORE touching
    # response.content[0].text: for tool-use responses the first content
    # block may be a tool_use block, which has no .text attribute.
    if response.stop_reason == "tool_use":
        tool_call = next(
            (content for content in response.content if content.type == "tool_use"),
            None,
        )

        if tool_call:
            function = Function(
                name=tool_call.name, arguments=json.dumps(tool_call.input)
            )
            tool_call_obj = ChatCompletionMessageToolCall(
                id=tool_call.id, function=function, type="function"
            )
            # Get the accompanying text content, if the model produced any.
            text_content = next(
                (
                    content.text
                    for content in response.content
                    if content.type == "text"
                ),
                "",
            )

            normalized_response.choices[0].message = Message(
                content=text_content or None,
                tool_calls=[tool_call_obj],
                role="assistant",
                refusal=None,
            )
            return normalized_response

    # Handle regular text response.
    normalized_response.choices[0].message = Message(
        content=response.content[0].text,
        role="assistant",
        tool_calls=None,
        refusal=None,
    )
    return normalized_response


def convert_openai_tools_to_anthropic(openai_tools):
    """Convert OpenAI-style tool definitions to Anthropic's tool format.

    Args:
        openai_tools: Iterable of OpenAI tool dicts, e.g.
            {"type": "function",
             "function": {"name": ..., "description": ...,
                          "parameters": {"properties": {...},
                                         "required": [...]}}}.

    Returns:
        list[dict]: Anthropic tool dicts with "name", "description" and
        "input_schema" keys. Tools whose type is not "function" are
        skipped, since Anthropic only supports function-style tools.
    """
    anthropic_tools = []

    for tool in openai_tools:
        # Only handle function-type tools from OpenAI.
        if tool.get("type") != "function":
            continue

        function = tool["function"]
        # "description" and "parameters" are optional in the OpenAI tool
        # schema; use .get() so their absence does not raise KeyError.
        parameters = function.get("parameters", {})

        anthropic_tool = {
            "name": function["name"],
            "description": function.get("description", ""),
            "input_schema": {
                "type": "object",
                "properties": parameters.get("properties", {}),
                "required": parameters.get("required", []),
            },
        }

        anthropic_tools.append(anthropic_tool)

    return anthropic_tools
Loading