Async streamed response to api message conversion #405

Merged: 7 commits, Jan 27, 2025
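
Summary of the cumulative diff (all 7 commits): both chat model modules gain an async singledispatch converter, `async_message_to_anthropic_message` and `async_message_to_openai_message`, whose default implementation delegates to the existing sync converter. Each registers an `AssistantMessage` overload that consumes `AsyncStreamedResponse` content with `async for`, awaiting `AsyncStreamedStr.to_string()` for text and collecting `FunctionCall` items as tool calls. The repeated tool-call dict construction is factored into a `_function_call_to_tool_call_block` helper in each module, and `acomplete` now awaits the async converter when building the request messages.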
src/magentic/chat_model/anthropic_chat_model.py (79 changes: 41 additions & 38 deletions)
@@ -6,12 +6,11 @@
from typing import Any, Generic, TypeVar, cast, overload

from magentic._parsing import contains_parallel_function_call_type, contains_string_type
from magentic._streamed_response import StreamedResponse
from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse
from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream
from magentic.chat_model.function_schema import (
BaseFunctionSchema,
FunctionCallFunctionSchema,
FunctionSchema,
function_schema_for_type,
get_async_function_schemas,
get_function_schemas,
@@ -35,7 +34,7 @@
StreamState,
)
from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id
from magentic.streaming import StreamedStr
from magentic.streaming import AsyncStreamedStr, StreamedStr
from magentic.vision import UserImageMessage

try:
@@ -69,6 +68,12 @@ def message_to_anthropic_message(message: Message[Any]) -> MessageParam:
raise NotImplementedError(type(message))


@singledispatch
async def async_message_to_anthropic_message(message: Message[Any]) -> MessageParam:
"""Async version of `message_to_anthropic_message`."""
return message_to_anthropic_message(message)


@message_to_anthropic_message.register(_RawMessage)
def _(message: _RawMessage[Any]) -> MessageParam:
# TODO: Validate the message content
@@ -136,6 +141,18 @@ def _(message: UserImageMessage[Any]) -> MessageParam:
}


def _function_call_to_tool_call_block(
function_call: FunctionCall[Any],
) -> ToolUseBlockParam:
function_schema = FunctionCallFunctionSchema(function_call.function)
return {
"type": "tool_use",
"id": function_call._unique_id,
"name": function_schema.name,
"input": json.loads(function_schema.serialize_args(function_call)),
}


@message_to_anthropic_message.register(AssistantMessage)
def _(message: AssistantMessage[Any]) -> MessageParam:
if isinstance(message.content, str):
@@ -144,38 +161,17 @@ def _(message: AssistantMessage[Any]) -> MessageParam:
"content": message.content,
}

function_schema: FunctionSchema[Any]

if isinstance(message.content, FunctionCall):
function_schema = FunctionCallFunctionSchema(message.content.function)
return {
"role": AnthropicMessageRole.ASSISTANT.value,
"content": [
{
"type": "tool_use",
"id": message.content._unique_id,
"name": function_schema.name,
"input": json.loads(
function_schema.serialize_args(message.content)
),
}
],
"content": [_function_call_to_tool_call_block(message.content)],
}

if isinstance(message.content, ParallelFunctionCall):
return {
"role": AnthropicMessageRole.ASSISTANT.value,
"content": [
{
"type": "tool_use",
"id": function_call._unique_id,
"name": FunctionCallFunctionSchema(function_call.function).name,
"input": json.loads(
FunctionCallFunctionSchema(
function_call.function
).serialize_args(function_call)
),
}
_function_call_to_tool_call_block(function_call)
for function_call in message.content
],
}
@@ -184,18 +180,9 @@ def _(message: AssistantMessage[Any]) -> MessageParam:
content_blocks: list[TextBlockParam | ToolUseBlockParam] = []
for item in message.content:
if isinstance(item, StreamedStr):
content_blocks.append({"type": "text", "text": str(item)})
content_blocks.append({"type": "text", "text": item.to_string()})
elif isinstance(item, FunctionCall):
function_schema = FunctionCallFunctionSchema(item.function)
content_blocks.append(
{
"type": "tool_use",
"id": item._unique_id,
"name": function_schema.name,
"input": json.loads(function_schema.serialize_args(item)),
}
)

content_blocks.append(_function_call_to_tool_call_block(item))
return {
"role": AnthropicMessageRole.ASSISTANT.value,
"content": content_blocks,
@@ -216,6 +203,22 @@ def _(message: AssistantMessage[Any]) -> MessageParam:
}


@async_message_to_anthropic_message.register(AssistantMessage)
async def _(message: AssistantMessage[Any]) -> MessageParam:
if isinstance(message.content, AsyncStreamedResponse):
content_blocks: list[TextBlockParam | ToolUseBlockParam] = []
async for item in message.content:
if isinstance(item, AsyncStreamedStr):
content_blocks.append({"type": "text", "text": await item.to_string()})
elif isinstance(item, FunctionCall):
content_blocks.append(_function_call_to_tool_call_block(item))
return {
"role": AnthropicMessageRole.ASSISTANT.value,
"content": content_blocks,
}
return message_to_anthropic_message(message)


@message_to_anthropic_message.register(ToolResultMessage)
def _(message: ToolResultMessage[Any]) -> MessageParam:
if isinstance(message.content, str):
@@ -514,7 +517,7 @@ async def acomplete(
] = await self._async_client.messages.stream(
model=self.model,
messages=_combine_messages(
[message_to_anthropic_message(m) for m in messages]
[await async_message_to_anthropic_message(m) for m in messages]
),
max_tokens=self.max_tokens,
stop_sequences=_if_given(stop),
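
A minimal usage sketch of the new Anthropic-side converter, assuming `AssistantMessage` is importable from the top-level `magentic` package as in the library's docs; the message text is made up for illustration:

import asyncio

from magentic import AssistantMessage
from magentic.chat_model.anthropic_chat_model import (
    async_message_to_anthropic_message,
)


async def main() -> None:
    # Plain (non-streamed) content hits the @singledispatch default, which
    # simply delegates to the existing sync converter.
    param = await async_message_to_anthropic_message(AssistantMessage("Hello!"))
    print(param)  # {'role': 'assistant', 'content': 'Hello!'}

    # An AssistantMessage whose content is an AsyncStreamedResponse takes the
    # new overload instead: each AsyncStreamedStr is awaited into a text block
    # and each FunctionCall becomes a tool_use block, in stream order.


asyncio.run(main())
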
src/magentic/chat_model/openai_chat_model.py (89 changes: 51 additions & 38 deletions)
@@ -9,6 +9,7 @@
ChatCompletionChunk,
ChatCompletionContentPartParam,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionNamedToolChoiceParam,
ChatCompletionStreamOptionsParam,
ChatCompletionToolChoiceOptionParam,
@@ -17,12 +18,11 @@
)

from magentic._parsing import contains_parallel_function_call_type, contains_string_type
from magentic._streamed_response import StreamedResponse
from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse
from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream
from magentic.chat_model.function_schema import (
BaseFunctionSchema,
FunctionCallFunctionSchema,
FunctionSchema,
function_schema_for_type,
get_async_function_schemas,
get_function_schemas,
@@ -46,7 +46,7 @@
StreamState,
)
from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id
from magentic.streaming import StreamedStr
from magentic.streaming import AsyncStreamedStr, StreamedStr
from magentic.vision import UserImageMessage


@@ -64,6 +64,14 @@ def message_to_openai_message(message: Message[Any]) -> ChatCompletionMessageParam:
raise NotImplementedError(type(message))


@singledispatch
async def async_message_to_openai_message(
message: Message[Any],
) -> ChatCompletionMessageParam:
"""Async version of `message_to_openai_message`."""
return message_to_openai_message(message)


@message_to_openai_message.register(_RawMessage)
def _(message: _RawMessage[Any]) -> ChatCompletionMessageParam:
assert isinstance(message.content, dict)
@@ -122,43 +130,36 @@ def _(message: UserImageMessage[Any]) -> ChatCompletionUserMessageParam:
}


def _function_call_to_tool_call_block(
function_call: FunctionCall[Any],
) -> ChatCompletionMessageToolCallParam:
function_schema = FunctionCallFunctionSchema(function_call.function)
return {
"id": function_call._unique_id,
"type": "function",
"function": {
"name": function_schema.name,
"arguments": function_schema.serialize_args(function_call),
},
}


@message_to_openai_message.register(AssistantMessage)
def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam:
if isinstance(message.content, str):
return {"role": OpenaiMessageRole.ASSISTANT.value, "content": message.content}

function_schema: FunctionSchema[Any]

if isinstance(message.content, FunctionCall):
function_schema = FunctionCallFunctionSchema(message.content.function)
return {
"role": OpenaiMessageRole.ASSISTANT.value,
"tool_calls": [
{
"id": message.content._unique_id,
"type": "function",
"function": {
"name": function_schema.name,
"arguments": function_schema.serialize_args(message.content),
},
}
],
"tool_calls": [_function_call_to_tool_call_block(message.content)],
}

if isinstance(message.content, ParallelFunctionCall):
return {
"role": OpenaiMessageRole.ASSISTANT.value,
"tool_calls": [
{
"id": function_call._unique_id,
"type": "function",
"function": {
"name": FunctionCallFunctionSchema(function_call.function).name,
"arguments": FunctionCallFunctionSchema(
function_call.function
).serialize_args(function_call),
},
}
_function_call_to_tool_call_block(function_call)
for function_call in message.content
],
}
@@ -168,23 +169,14 @@ def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam:
function_calls: list[FunctionCall[Any]] = []
for item in message.content:
if isinstance(item, StreamedStr):
content.append(str(item))
content.append(item.to_string())
elif isinstance(item, FunctionCall):
function_calls.append(item)
return {
"role": OpenaiMessageRole.ASSISTANT.value,
"content": " ".join(content),
"tool_calls": [
{
"id": function_call._unique_id,
"type": "function",
"function": {
"name": FunctionCallFunctionSchema(function_call.function).name,
"arguments": FunctionCallFunctionSchema(
function_call.function
).serialize_args(function_call),
},
}
_function_call_to_tool_call_block(function_call)
for function_call in function_calls
],
}
@@ -206,6 +198,27 @@ def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam:
}


@async_message_to_openai_message.register(AssistantMessage)
async def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam:
if isinstance(message.content, AsyncStreamedResponse):
content: list[str] = []
function_calls: list[FunctionCall[Any]] = []
async for item in message.content:
if isinstance(item, AsyncStreamedStr):
content.append(await item.to_string())
elif isinstance(item, FunctionCall):
function_calls.append(item)
return {
"role": OpenaiMessageRole.ASSISTANT.value,
"content": " ".join(content),
"tool_calls": [
_function_call_to_tool_call_block(function_call)
for function_call in function_calls
],
}
return message_to_openai_message(message)


@message_to_openai_message.register(ToolResultMessage)
def _(message: ToolResultMessage[Any]) -> ChatCompletionMessageParam:
if isinstance(message.content, str):
@@ -546,7 +559,7 @@ async def acomplete(
] = await self._async_client.chat.completions.create(
model=self.model,
messages=_add_missing_tool_calls_responses(
[message_to_openai_message(m) for m in messages]
[await async_message_to_openai_message(m) for m in messages]
),
max_tokens=_if_given(self.max_tokens),
seed=_if_given(self.seed),
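
And a matching sketch for the OpenAI module, showing how the refactored `_function_call_to_tool_call_block` helper serializes a tool call. The `plus` function is a hypothetical tool used only for illustration; the `id` is generated per call, and the exact `arguments` JSON is assumed from `FunctionCallFunctionSchema.serialize_args`:

import asyncio

from magentic import AssistantMessage, FunctionCall
from magentic.chat_model.openai_chat_model import async_message_to_openai_message


def plus(a: int, b: int) -> int:
    """Hypothetical tool function, for illustration only."""
    return a + b


async def main() -> None:
    # FunctionCall content also falls through to the sync converter; the
    # helper renders it as a single tool_calls entry, roughly:
    #   {"role": "assistant",
    #    "tool_calls": [{"id": "<generated>", "type": "function",
    #                    "function": {"name": "plus",
    #                                 "arguments": '{"a": 1, "b": 2}'}}]}
    message = AssistantMessage(FunctionCall(plus, 1, 2))
    param = await async_message_to_openai_message(message)
    print(param)


asyncio.run(main())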