From d706ce8dd03fdb7ecc0fedf81a6feec81a45e011 Mon Sep 17 00:00:00 2001 From: Ishan Anand Date: Fri, 24 Jan 2025 16:23:34 +0530 Subject: [PATCH 1/7] add async streamed response to llm message conversion --- .../chat_model/anthropic_chat_model.py | 71 ++++++++-------- src/magentic/chat_model/openai_chat_model.py | 82 +++++++++++-------- src/magentic/utilities.py | 39 +++++++++ 3 files changed, 121 insertions(+), 71 deletions(-) create mode 100644 src/magentic/utilities.py diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index 426fd62..4c24090 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -1,3 +1,4 @@ +import asyncio import json from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence from enum import Enum @@ -6,7 +7,7 @@ from typing import Any, Generic, TypeVar, cast, overload from magentic._parsing import contains_parallel_function_call_type, contains_string_type -from magentic._streamed_response import StreamedResponse +from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream from magentic.chat_model.function_schema import ( BaseFunctionSchema, @@ -35,7 +36,7 @@ StreamState, ) from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id -from magentic.streaming import StreamedStr +from magentic.streaming import AsyncStreamedStr, StreamedStr from magentic.vision import UserImageMessage try: @@ -136,6 +137,16 @@ def _(message: UserImageMessage[Any]) -> MessageParam: } +def _function_call_to_tool_call_block(function_call: FunctionCall) -> ToolUseBlockParam: + function_schema = FunctionCallFunctionSchema(function_call.function) + return { + "type": "tool_use", + "id": function_call._unique_id, + "name": function_schema.name, + "input": json.loads(function_schema.serialize_args(function_call)), + } + + @message_to_anthropic_message.register(AssistantMessage) def _(message: AssistantMessage[Any]) -> MessageParam: if isinstance(message.content, str): @@ -144,38 +155,17 @@ def _(message: AssistantMessage[Any]) -> MessageParam: "content": message.content, } - function_schema: FunctionSchema[Any] - if isinstance(message.content, FunctionCall): - function_schema = FunctionCallFunctionSchema(message.content.function) return { "role": AnthropicMessageRole.ASSISTANT.value, - "content": [ - { - "type": "tool_use", - "id": message.content._unique_id, - "name": function_schema.name, - "input": json.loads( - function_schema.serialize_args(message.content) - ), - } - ], + "content": [_function_call_to_tool_call_block(message.content)], } if isinstance(message.content, ParallelFunctionCall): return { "role": AnthropicMessageRole.ASSISTANT.value, "content": [ - { - "type": "tool_use", - "id": function_call._unique_id, - "name": FunctionCallFunctionSchema(function_call.function).name, - "input": json.loads( - FunctionCallFunctionSchema( - function_call.function - ).serialize_args(function_call) - ), - } + _function_call_to_tool_call_block(function_call) for function_call in message.content ], } @@ -184,23 +174,34 @@ def _(message: AssistantMessage[Any]) -> MessageParam: content_blocks: list[TextBlockParam | ToolUseBlockParam] = [] for item in message.content: if isinstance(item, StreamedStr): - content_blocks.append({"type": "text", "text": str(item)}) + content_blocks.append({"type": "text", "text": item.to_string()}) elif 
isinstance(item, FunctionCall): - function_schema = FunctionCallFunctionSchema(item.function) - content_blocks.append( - { - "type": "tool_use", - "id": item._unique_id, - "name": function_schema.name, - "input": json.loads(function_schema.serialize_args(item)), - } - ) + content_blocks.append(_function_call_to_tool_call_block(item)) return { "role": AnthropicMessageRole.ASSISTANT.value, "content": content_blocks, } + if isinstance(message.content, AsyncStreamedResponse): + from magentic.utilities import ASYNC_RUNNER + + async def collect_content_blocks(): + content_blocks: list[TextBlockParam | ToolUseBlockParam] = [] + async for item in message.content: + if isinstance(item, AsyncStreamedStr): + content_blocks.append( + {"type": "text", "text": await item.to_string()} + ) + elif isinstance(item, FunctionCall): + content_blocks.append(_function_call_to_tool_call_block(item)) + return content_blocks + + return { + "role": AnthropicMessageRole.ASSISTANT.value, + "content": ASYNC_RUNNER.run_coroutine(collect_content_blocks()), + } + function_schema = function_schema_for_type(type(message.content)) return { "role": AnthropicMessageRole.ASSISTANT.value, diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 0c505ab..970596d 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -9,6 +9,7 @@ ChatCompletionChunk, ChatCompletionContentPartParam, ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, ChatCompletionNamedToolChoiceParam, ChatCompletionStreamOptionsParam, ChatCompletionToolChoiceOptionParam, @@ -17,7 +18,7 @@ ) from magentic._parsing import contains_parallel_function_call_type, contains_string_type -from magentic._streamed_response import StreamedResponse +from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream from magentic.chat_model.function_schema import ( BaseFunctionSchema, @@ -46,7 +47,7 @@ StreamState, ) from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id -from magentic.streaming import StreamedStr +from magentic.streaming import AsyncStreamedStr, StreamedStr from magentic.vision import UserImageMessage @@ -122,43 +123,36 @@ def _(message: UserImageMessage[Any]) -> ChatCompletionUserMessageParam: } +def _function_call_to_tool_call_block( + function_call: FunctionCall[Any], +) -> ChatCompletionMessageToolCallParam: + function_schema = FunctionCallFunctionSchema(function_call.function) + return { + "id": function_call._unique_id, + "type": "function", + "function": { + "name": function_schema.name, + "arguments": function_schema.serialize_args(function_call), + }, + } + + @message_to_openai_message.register(AssistantMessage) def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam: if isinstance(message.content, str): return {"role": OpenaiMessageRole.ASSISTANT.value, "content": message.content} - function_schema: FunctionSchema[Any] - if isinstance(message.content, FunctionCall): - function_schema = FunctionCallFunctionSchema(message.content.function) return { "role": OpenaiMessageRole.ASSISTANT.value, - "tool_calls": [ - { - "id": message.content._unique_id, - "type": "function", - "function": { - "name": function_schema.name, - "arguments": function_schema.serialize_args(message.content), - }, - } - ], + "tool_calls": [_function_call_to_tool_call_block(message.content)], } if isinstance(message.content, 
ParallelFunctionCall): return { "role": OpenaiMessageRole.ASSISTANT.value, "tool_calls": [ - { - "id": function_call._unique_id, - "type": "function", - "function": { - "name": FunctionCallFunctionSchema(function_call.function).name, - "arguments": FunctionCallFunctionSchema( - function_call.function - ).serialize_args(function_call), - }, - } + _function_call_to_tool_call_block(function_call) for function_call in message.content ], } @@ -168,23 +162,39 @@ def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam: function_calls: list[FunctionCall[Any]] = [] for item in message.content: if isinstance(item, StreamedStr): - content.append(str(item)) + content.append(item.to_string()) elif isinstance(item, FunctionCall): function_calls.append(item) return { "role": OpenaiMessageRole.ASSISTANT.value, "content": " ".join(content), "tool_calls": [ - { - "id": function_call._unique_id, - "type": "function", - "function": { - "name": FunctionCallFunctionSchema(function_call.function).name, - "arguments": FunctionCallFunctionSchema( - function_call.function - ).serialize_args(function_call), - }, - } + _function_call_to_tool_call_block(function_call) + for function_call in function_calls + ], + } + + if isinstance(message.content, AsyncStreamedResponse): + from magentic.utilities import ASYNC_RUNNER + + async def collect_content_and_function_calls(): + content: list[str] = [] + function_calls: list[FunctionCall[Any]] = [] + async for item in message.content: + if isinstance(item, AsyncStreamedStr): + content.append(await item.to_string()) + elif isinstance(item, FunctionCall): + function_calls.append(item) + return content, function_calls + + content, function_calls = ASYNC_RUNNER.run_coroutine( + collect_content_and_function_calls() + ) + return { + "role": OpenaiMessageRole.ASSISTANT.value, + "content": " ".join(content), + "tool_calls": [ + _function_call_to_tool_call_block(function_call) for function_call in function_calls ], } diff --git a/src/magentic/utilities.py b/src/magentic/utilities.py new file mode 100644 index 0000000..39d9c76 --- /dev/null +++ b/src/magentic/utilities.py @@ -0,0 +1,39 @@ +"""Utilities for the magentic package.""" + +import asyncio +import atexit +import threading +from collections.abc import Coroutine +from concurrent.futures import ThreadPoolExecutor +from typing import Any + + +class _AsyncRunner: + """Manages thread pool and event loops for running async code in sync contexts.""" + + def __init__(self, max_workers: int = 2): + self._thread_pool = ThreadPoolExecutor(max_workers=max_workers) + self._thread_local = threading.local() + atexit.register(self._cleanup) + + def _cleanup(self) -> None: + """Cleanup the thread pool and event loops on exit.""" + self._thread_pool.shutdown(wait=False, cancel_futures=True) + + def _run_coro_in_thread(self, coro: Coroutine[Any, Any, Any]) -> Any: + """Run a coroutine in the thread pool's event loop.""" + if not hasattr(self._thread_local, "loop"): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._thread_local.loop = loop + else: + loop = self._thread_local.loop + return loop.run_until_complete(coro) + + def run_coroutine(self, coro: Coroutine[Any, Any, Any]) -> Any: + """Run the coroutine in a separate thread.""" + return self._thread_pool.submit(self._run_coro_in_thread, coro).result() + + +# Global instance for the package +ASYNC_RUNNER = _AsyncRunner() From 5e79b29987e5018b23f74b3f4c45ea9f5cb15f10 Mon Sep 17 00:00:00 2001 From: Ishan Anand Date: Fri, 24 Jan 2025 17:00:42 +0530 Subject: [PATCH 
2/7] add tests --- .../chat_model/anthropic_chat_model.py | 41 +++++++------- src/magentic/chat_model/openai_chat_model.py | 53 +++++++++---------- src/magentic/utilities.py | 39 -------------- tests/chat_model/test_anthropic_chat_model.py | 26 +++++++++ tests/chat_model/test_openai_chat_model.py | 25 +++++++++ 5 files changed, 96 insertions(+), 88 deletions(-) delete mode 100644 src/magentic/utilities.py diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index 4c24090..b3db2ce 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -1,4 +1,3 @@ -import asyncio import json from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence from enum import Enum @@ -12,7 +11,6 @@ from magentic.chat_model.function_schema import ( BaseFunctionSchema, FunctionCallFunctionSchema, - FunctionSchema, function_schema_for_type, get_async_function_schemas, get_function_schemas, @@ -183,25 +181,6 @@ def _(message: AssistantMessage[Any]) -> MessageParam: "content": content_blocks, } - if isinstance(message.content, AsyncStreamedResponse): - from magentic.utilities import ASYNC_RUNNER - - async def collect_content_blocks(): - content_blocks: list[TextBlockParam | ToolUseBlockParam] = [] - async for item in message.content: - if isinstance(item, AsyncStreamedStr): - content_blocks.append( - {"type": "text", "text": await item.to_string()} - ) - elif isinstance(item, FunctionCall): - content_blocks.append(_function_call_to_tool_call_block(item)) - return content_blocks - - return { - "role": AnthropicMessageRole.ASSISTANT.value, - "content": ASYNC_RUNNER.run_coroutine(collect_content_blocks()), - } - function_schema = function_schema_for_type(type(message.content)) return { "role": AnthropicMessageRole.ASSISTANT.value, @@ -236,6 +215,24 @@ def _(message: ToolResultMessage[Any]) -> MessageParam: } +async def async_message_to_anthropic_message(message: Message[Any]) -> MessageParam: + """Convert a Message to an Anthropic message (async version).""" + if isinstance(message.content, AsyncStreamedResponse): + content_blocks: list[TextBlockParam | ToolUseBlockParam] = [] + async for item in message.content: + if isinstance(item, AsyncStreamedStr): + content_blocks.append({"type": "text", "text": await item.to_string()}) + elif isinstance(item, FunctionCall): + content_blocks.append(_function_call_to_tool_call_block(item)) + + return { + "role": AnthropicMessageRole.ASSISTANT.value, + "content": content_blocks, + } + else: # noqa: RET505 + return message_to_anthropic_message(message) + + # TODO: Move this to the magentic level by allowing `UserMessage` have a list of content def _combine_messages(messages: Iterable[MessageParam]) -> list[MessageParam]: """Combine messages with the same role, to get alternating roles. 
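
The hunk above replaces the thread-pool approach from the previous patch: instead of draining the `AsyncStreamedResponse` from sync code via `ASYNC_RUNNER`, callers that are already async can simply await the conversion. A minimal usage sketch, not part of this patch; the `magentic.chat_model.message` import path and the `plus` helper are assumptions here, mirroring the tests added below:

import asyncio

from magentic._streamed_response import AsyncStreamedResponse
from magentic.chat_model.anthropic_chat_model import async_message_to_anthropic_message
from magentic.chat_model.message import AssistantMessage
from magentic.function_call import FunctionCall
from magentic.streaming import AsyncStreamedStr


def plus(a: int, b: int) -> int:
    return a + b


async def _text():
    yield "Hello"
    yield " World"


async def _items():
    # Simulates a streamed completion: text chunks followed by a tool call.
    yield AsyncStreamedStr(_text())
    yield FunctionCall(plus, 1, 2)


async def main() -> None:
    message = AssistantMessage(AsyncStreamedResponse(_items()))
    # Awaiting the conversion drains the stream and returns a plain Anthropic
    # MessageParam dict containing "text" and "tool_use" content blocks.
    print(await async_message_to_anthropic_message(message))
    # -> {"role": "assistant", "content": [{"type": "text", "text": "Hello World"}, ...]}


asyncio.run(main())

Because the conversion happens in the caller's event loop, the `_AsyncRunner` thread pool is no longer needed, and the rest of this patch deletes it.
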
@@ -515,7 +512,7 @@ async def acomplete( ] = await self._async_client.messages.stream( model=self.model, messages=_combine_messages( - [message_to_anthropic_message(m) for m in messages] + [await async_message_to_anthropic_message(m) for m in messages] ), max_tokens=self.max_tokens, stop_sequences=_if_given(stop), diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 970596d..985519e 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -23,7 +23,6 @@ from magentic.chat_model.function_schema import ( BaseFunctionSchema, FunctionCallFunctionSchema, - FunctionSchema, function_schema_for_type, get_async_function_schemas, get_function_schemas, @@ -174,31 +173,6 @@ def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam: ], } - if isinstance(message.content, AsyncStreamedResponse): - from magentic.utilities import ASYNC_RUNNER - - async def collect_content_and_function_calls(): - content: list[str] = [] - function_calls: list[FunctionCall[Any]] = [] - async for item in message.content: - if isinstance(item, AsyncStreamedStr): - content.append(await item.to_string()) - elif isinstance(item, FunctionCall): - function_calls.append(item) - return content, function_calls - - content, function_calls = ASYNC_RUNNER.run_coroutine( - collect_content_and_function_calls() - ) - return { - "role": OpenaiMessageRole.ASSISTANT.value, - "content": " ".join(content), - "tool_calls": [ - _function_call_to_tool_call_block(function_call) - for function_call in function_calls - ], - } - function_schema = function_schema_for_type(type(message.content)) return { "role": OpenaiMessageRole.ASSISTANT.value, @@ -230,6 +204,31 @@ def _(message: ToolResultMessage[Any]) -> ChatCompletionMessageParam: } +async def async_message_to_openai_message( + message: Message[Any], +) -> ChatCompletionMessageParam: + """Convert a Message to an OpenAI message (async version).""" + if isinstance(message.content, AsyncStreamedResponse): + content: list[str] = [] + function_calls: list[FunctionCall[Any]] = [] + async for item in message.content: + if isinstance(item, AsyncStreamedStr): + content.append(await item.to_string()) + elif isinstance(item, FunctionCall): + function_calls.append(item) + + return { + "role": OpenaiMessageRole.ASSISTANT.value, + "content": " ".join(content), + "tool_calls": [ + _function_call_to_tool_call_block(function_call) + for function_call in function_calls + ], + } + else: # noqa: RET505 + return message_to_openai_message(message) + + # TODO: Use ToolResultMessage to solve this at magentic level def _add_missing_tool_calls_responses( messages: list[ChatCompletionMessageParam], @@ -556,7 +555,7 @@ async def acomplete( ] = await self._async_client.chat.completions.create( model=self.model, messages=_add_missing_tool_calls_responses( - [message_to_openai_message(m) for m in messages] + [await async_message_to_openai_message(m) for m in messages] ), max_tokens=_if_given(self.max_tokens), seed=_if_given(self.seed), diff --git a/src/magentic/utilities.py b/src/magentic/utilities.py deleted file mode 100644 index 39d9c76..0000000 --- a/src/magentic/utilities.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Utilities for the magentic package.""" - -import asyncio -import atexit -import threading -from collections.abc import Coroutine -from concurrent.futures import ThreadPoolExecutor -from typing import Any - - -class _AsyncRunner: - """Manages thread pool and event loops for running async code in sync 
contexts.""" - - def __init__(self, max_workers: int = 2): - self._thread_pool = ThreadPoolExecutor(max_workers=max_workers) - self._thread_local = threading.local() - atexit.register(self._cleanup) - - def _cleanup(self) -> None: - """Cleanup the thread pool and event loops on exit.""" - self._thread_pool.shutdown(wait=False, cancel_futures=True) - - def _run_coro_in_thread(self, coro: Coroutine[Any, Any, Any]) -> Any: - """Run a coroutine in the thread pool's event loop.""" - if not hasattr(self._thread_local, "loop"): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self._thread_local.loop = loop - else: - loop = self._thread_local.loop - return loop.run_until_complete(coro) - - def run_coroutine(self, coro: Coroutine[Any, Any, Any]) -> Any: - """Run the coroutine in a separate thread.""" - return self._thread_pool.submit(self._run_coro_in_thread, coro).result() - - -# Global instance for the package -ASYNC_RUNNER = _AsyncRunner() diff --git a/tests/chat_model/test_anthropic_chat_model.py b/tests/chat_model/test_anthropic_chat_model.py index 12b3ef5..8c1673a 100644 --- a/tests/chat_model/test_anthropic_chat_model.py +++ b/tests/chat_model/test_anthropic_chat_model.py @@ -8,6 +8,7 @@ from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse from magentic.chat_model.anthropic_chat_model import ( AnthropicChatModel, + async_message_to_anthropic_message, message_to_anthropic_message, ) from magentic.chat_model.base import ToolSchemaParseError @@ -125,6 +126,31 @@ def test_message_to_anthropic_message(message, expected_anthropic_message): assert message_to_anthropic_message(message) == expected_anthropic_message +async def test_async_message_to_anthropic_message(): + async def generate_async_streamed_response(): + async def async_string_generator(): + yield "Hello" + yield "World" + + yield AsyncStreamedStr(async_string_generator()) + yield FunctionCall(plus, 1, 2) + + async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response()) + message = AssistantMessage(async_streamed_response) + assert await async_message_to_anthropic_message(message) == { + "role": "assistant", + "content": [ + {"type": "text", "text": "HelloWorld"}, + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 1, "b": 2}, + }, + ], + } + + def test_message_to_anthropic_message_user_image_document_bytes_pdf(document_bytes_pdf): image_message = UserMessage([DocumentBytes(document_bytes_pdf)]) assert message_to_anthropic_message(image_message) == snapshot( diff --git a/tests/chat_model/test_openai_chat_model.py b/tests/chat_model/test_openai_chat_model.py index b5156e4..86e4863 100644 --- a/tests/chat_model/test_openai_chat_model.py +++ b/tests/chat_model/test_openai_chat_model.py @@ -24,6 +24,7 @@ ) from magentic.chat_model.openai_chat_model import ( OpenaiChatModel, + async_message_to_openai_message, message_to_openai_message, ) from magentic.function_call import FunctionCall, ParallelFunctionCall @@ -134,6 +135,30 @@ def test_message_to_openai_message(message, expected_openai_message): assert message_to_openai_message(message) == expected_openai_message +async def test_async_message_to_openai_message(): + async def generate_async_streamed_response(): + async def async_string_generator(): + yield "Hello" + yield "World" + + yield AsyncStreamedStr(async_string_generator()) + yield FunctionCall(plus, 1, 2) + + async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response()) + message = AssistantMessage(async_streamed_response) + 
assert await async_message_to_openai_message(message) == { + "role": "assistant", + "content": "HelloWorld", + "tool_calls": [ + { + "id": ANY, + "type": "function", + "function": {"name": "plus", "arguments": '{"a":1,"b":2}'}, + }, + ], + } + + def test_message_to_openai_message_user_image_message_bytes_jpg(image_bytes_jpg): image_message = UserMessage([ImageBytes(image_bytes_jpg)]) assert message_to_openai_message(image_message) == snapshot( From c55332075d95dbef0ad26e78aa0800da647655e6 Mon Sep 17 00:00:00 2001 From: Ishan Anand Date: Fri, 24 Jan 2025 17:16:47 +0530 Subject: [PATCH 3/7] specify types --- src/magentic/chat_model/anthropic_chat_model.py | 4 +++- tests/chat_model/test_anthropic_chat_model.py | 4 ++-- tests/chat_model/test_openai_chat_model.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index b3db2ce..2a315bd 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -135,7 +135,9 @@ def _(message: UserImageMessage[Any]) -> MessageParam: } -def _function_call_to_tool_call_block(function_call: FunctionCall) -> ToolUseBlockParam: +def _function_call_to_tool_call_block( + function_call: FunctionCall[Any], +) -> ToolUseBlockParam: function_schema = FunctionCallFunctionSchema(function_call.function) return { "type": "tool_use", diff --git a/tests/chat_model/test_anthropic_chat_model.py b/tests/chat_model/test_anthropic_chat_model.py index 8c1673a..61646f6 100644 --- a/tests/chat_model/test_anthropic_chat_model.py +++ b/tests/chat_model/test_anthropic_chat_model.py @@ -132,10 +132,10 @@ async def async_string_generator(): yield "Hello" yield "World" - yield AsyncStreamedStr(async_string_generator()) + yield AsyncStreamedStr(async_string_generator()) # type: ignore # noqa: PGH003 yield FunctionCall(plus, 1, 2) - async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response()) + async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response()) # type: ignore # noqa: PGH003 message = AssistantMessage(async_streamed_response) assert await async_message_to_anthropic_message(message) == { "role": "assistant", diff --git a/tests/chat_model/test_openai_chat_model.py b/tests/chat_model/test_openai_chat_model.py index 86e4863..cd4d860 100644 --- a/tests/chat_model/test_openai_chat_model.py +++ b/tests/chat_model/test_openai_chat_model.py @@ -141,10 +141,10 @@ async def async_string_generator(): yield "Hello" yield "World" - yield AsyncStreamedStr(async_string_generator()) + yield AsyncStreamedStr(async_string_generator()) # type: ignore # noqa: PGH003 yield FunctionCall(plus, 1, 2) - async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response()) + async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response()) # type: ignore # noqa: PGH003 message = AssistantMessage(async_streamed_response) assert await async_message_to_openai_message(message) == { "role": "assistant", From 8e3c3ff5d79e22c80b921785dfaae553e262c734 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sun, 26 Jan 2025 22:57:05 -0800 Subject: [PATCH 4/7] Resolve RET505 --- src/magentic/chat_model/anthropic_chat_model.py | 3 +-- src/magentic/chat_model/openai_chat_model.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/magentic/chat_model/anthropic_chat_model.py 
b/src/magentic/chat_model/anthropic_chat_model.py index 2a315bd..488d3f1 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -231,8 +231,7 @@ async def async_message_to_anthropic_message(message: Message[Any]) -> MessagePa "role": AnthropicMessageRole.ASSISTANT.value, "content": content_blocks, } - else: # noqa: RET505 - return message_to_anthropic_message(message) + return message_to_anthropic_message(message) # TODO: Move this to the magentic level by allowing `UserMessage` have a list of content diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 985519e..010841d 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -225,8 +225,7 @@ async def async_message_to_openai_message( for function_call in function_calls ], } - else: # noqa: RET505 - return message_to_openai_message(message) + return message_to_openai_message(message) # TODO: Use ToolResultMessage to solve this at magentic level From 30bca9d4cdcef52204de1469f0e3092cbdba90d2 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sun, 26 Jan 2025 23:16:42 -0800 Subject: [PATCH 5/7] Implement async_message_to_X_message using singledispatch --- .../chat_model/anthropic_chat_model.py | 40 +++++++++-------- src/magentic/chat_model/openai_chat_model.py | 43 +++++++++++-------- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index 488d3f1..9e9b68f 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -68,6 +68,12 @@ def message_to_anthropic_message(message: Message[Any]) -> MessageParam: raise NotImplementedError(type(message)) +@singledispatch +async def async_message_to_anthropic_message(message: Message[Any]) -> MessageParam: + """Async version of `message_to_anthropic_message`.""" + return message_to_anthropic_message(message) + + @message_to_anthropic_message.register(_RawMessage) def _(message: _RawMessage[Any]) -> MessageParam: # TODO: Validate the message content @@ -177,7 +183,6 @@ def _(message: AssistantMessage[Any]) -> MessageParam: content_blocks.append({"type": "text", "text": item.to_string()}) elif isinstance(item, FunctionCall): content_blocks.append(_function_call_to_tool_call_block(item)) - return { "role": AnthropicMessageRole.ASSISTANT.value, "content": content_blocks, @@ -198,6 +203,22 @@ def _(message: AssistantMessage[Any]) -> MessageParam: } +@async_message_to_anthropic_message.register(AssistantMessage) +async def _(message: AssistantMessage[Any]) -> MessageParam: + if isinstance(message.content, AsyncStreamedResponse): + content_blocks: list[TextBlockParam | ToolUseBlockParam] = [] + async for item in message.content: + if isinstance(item, AsyncStreamedStr): + content_blocks.append({"type": "text", "text": await item.to_string()}) + elif isinstance(item, FunctionCall): + content_blocks.append(_function_call_to_tool_call_block(item)) + return { + "role": AnthropicMessageRole.ASSISTANT.value, + "content": content_blocks, + } + return message_to_anthropic_message(message) + + @message_to_anthropic_message.register(ToolResultMessage) def _(message: ToolResultMessage[Any]) -> MessageParam: if isinstance(message.content, str): @@ -217,23 +238,6 @@ def _(message: ToolResultMessage[Any]) -> MessageParam: } -async def 
async_message_to_anthropic_message(message: Message[Any]) -> MessageParam: - """Convert a Message to an Anthropic message (async version).""" - if isinstance(message.content, AsyncStreamedResponse): - content_blocks: list[TextBlockParam | ToolUseBlockParam] = [] - async for item in message.content: - if isinstance(item, AsyncStreamedStr): - content_blocks.append({"type": "text", "text": await item.to_string()}) - elif isinstance(item, FunctionCall): - content_blocks.append(_function_call_to_tool_call_block(item)) - - return { - "role": AnthropicMessageRole.ASSISTANT.value, - "content": content_blocks, - } - return message_to_anthropic_message(message) - - # TODO: Move this to the magentic level by allowing `UserMessage` have a list of content def _combine_messages(messages: Iterable[MessageParam]) -> list[MessageParam]: """Combine messages with the same role, to get alternating roles. diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 010841d..1cee0fb 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -64,6 +64,14 @@ def message_to_openai_message(message: Message[Any]) -> ChatCompletionMessagePar raise NotImplementedError(type(message)) +@singledispatch +async def async_message_to_openai_message( + message: Message[Any], +) -> ChatCompletionMessageParam: + """Async version of `message_to_openai_message`.""" + return message_to_openai_message(message) + + @message_to_openai_message.register(_RawMessage) def _(message: _RawMessage[Any]) -> ChatCompletionMessageParam: assert isinstance(message.content, dict) @@ -190,24 +198,8 @@ def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam: } -@message_to_openai_message.register(ToolResultMessage) -def _(message: ToolResultMessage[Any]) -> ChatCompletionMessageParam: - if isinstance(message.content, str): - content = message.content - else: - function_schema = function_schema_for_type(type(message.content)) - content = function_schema.serialize_args(message.content) - return { - "role": OpenaiMessageRole.TOOL.value, - "tool_call_id": message.tool_call_id, - "content": content, - } - - -async def async_message_to_openai_message( - message: Message[Any], -) -> ChatCompletionMessageParam: - """Convert a Message to an OpenAI message (async version).""" +@async_message_to_openai_message.register(AssistantMessage) +async def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam: if isinstance(message.content, AsyncStreamedResponse): content: list[str] = [] function_calls: list[FunctionCall[Any]] = [] @@ -216,7 +208,6 @@ async def async_message_to_openai_message( content.append(await item.to_string()) elif isinstance(item, FunctionCall): function_calls.append(item) - return { "role": OpenaiMessageRole.ASSISTANT.value, "content": " ".join(content), @@ -228,6 +219,20 @@ async def async_message_to_openai_message( return message_to_openai_message(message) +@message_to_openai_message.register(ToolResultMessage) +def _(message: ToolResultMessage[Any]) -> ChatCompletionMessageParam: + if isinstance(message.content, str): + content = message.content + else: + function_schema = function_schema_for_type(type(message.content)) + content = function_schema.serialize_args(message.content) + return { + "role": OpenaiMessageRole.TOOL.value, + "tool_call_id": message.tool_call_id, + "content": content, + } + + # TODO: Use ToolResultMessage to solve this at magentic level def _add_missing_tool_calls_responses( messages: 
list[ChatCompletionMessageParam], From b7c954d600046da21097a8c01210133af72066fa Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sun, 26 Jan 2025 23:27:13 -0800 Subject: [PATCH 6/7] Simplify test setup using async_iter --- tests/chat_model/test_anthropic_chat_model.py | 17 ++++++----------- tests/chat_model/test_openai_chat_model.py | 17 ++++++----------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/tests/chat_model/test_anthropic_chat_model.py b/tests/chat_model/test_anthropic_chat_model.py index 61646f6..f3f0d90 100644 --- a/tests/chat_model/test_anthropic_chat_model.py +++ b/tests/chat_model/test_anthropic_chat_model.py @@ -26,7 +26,7 @@ FunctionCall, ParallelFunctionCall, ) -from magentic.streaming import AsyncStreamedStr, StreamedStr +from magentic.streaming import AsyncStreamedStr, StreamedStr, async_iter def plus(a: int, b: int) -> int: @@ -127,20 +127,15 @@ def test_message_to_anthropic_message(message, expected_anthropic_message): async def test_async_message_to_anthropic_message(): - async def generate_async_streamed_response(): - async def async_string_generator(): - yield "Hello" - yield "World" - - yield AsyncStreamedStr(async_string_generator()) # type: ignore # noqa: PGH003 - yield FunctionCall(plus, 1, 2) - - async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response()) # type: ignore # noqa: PGH003 + async_streamed_str = AsyncStreamedStr(async_iter(["Hello", " World"])) + async_streamed_response = AsyncStreamedResponse( + async_iter([async_streamed_str, FunctionCall(plus, 1, 2)]) + ) message = AssistantMessage(async_streamed_response) assert await async_message_to_anthropic_message(message) == { "role": "assistant", "content": [ - {"type": "text", "text": "HelloWorld"}, + {"type": "text", "text": "Hello World"}, { "type": "tool_use", "id": ANY, diff --git a/tests/chat_model/test_openai_chat_model.py b/tests/chat_model/test_openai_chat_model.py index cd4d860..f8aa30a 100644 --- a/tests/chat_model/test_openai_chat_model.py +++ b/tests/chat_model/test_openai_chat_model.py @@ -28,7 +28,7 @@ message_to_openai_message, ) from magentic.function_call import FunctionCall, ParallelFunctionCall -from magentic.streaming import AsyncStreamedStr, StreamedStr +from magentic.streaming import AsyncStreamedStr, StreamedStr, async_iter def plus(a: int, b: int) -> int: @@ -136,19 +136,14 @@ def test_message_to_openai_message(message, expected_openai_message): async def test_async_message_to_openai_message(): - async def generate_async_streamed_response(): - async def async_string_generator(): - yield "Hello" - yield "World" - - yield AsyncStreamedStr(async_string_generator()) # type: ignore # noqa: PGH003 - yield FunctionCall(plus, 1, 2) - - async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response()) # type: ignore # noqa: PGH003 + async_streamed_str = AsyncStreamedStr(async_iter(["Hello", " World"])) + async_streamed_response = AsyncStreamedResponse( + async_iter([async_streamed_str, FunctionCall(plus, 1, 2)]) + ) message = AssistantMessage(async_streamed_response) assert await async_message_to_openai_message(message) == { "role": "assistant", - "content": "HelloWorld", + "content": "Hello World", "tool_calls": [ { "id": ANY, From 6b9ee8f22781c3af5dbac9136d34e5128717bbc7 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sun, 26 Jan 2025 23:38:12 -0800 Subject: [PATCH 7/7] Parameterize test. 
Add non-async test cases --- tests/chat_model/test_anthropic_chat_model.py | 225 +++++++++-------- tests/chat_model/test_openai_chat_model.py | 236 ++++++++++-------- 2 files changed, 248 insertions(+), 213 deletions(-) diff --git a/tests/chat_model/test_anthropic_chat_model.py b/tests/chat_model/test_anthropic_chat_model.py index f3f0d90..e1f2f3f 100644 --- a/tests/chat_model/test_anthropic_chat_model.py +++ b/tests/chat_model/test_anthropic_chat_model.py @@ -33,117 +33,136 @@ def plus(a: int, b: int) -> int: return a + b -@pytest.mark.parametrize( - ("message", "expected_anthropic_message"), - [ - (UserMessage("Hello"), {"role": "user", "content": "Hello"}), - (AssistantMessage("Hello"), {"role": "assistant", "content": "Hello"}), - ( - AssistantMessage(42), - { - "role": "assistant", - "content": [ - { - "type": "tool_use", - "id": ANY, - "name": "return_int", - "input": {"value": 42}, - } - ], - }, - ), - ( - AssistantMessage(FunctionCall(plus, 1, 2)), - { - "role": "assistant", - "content": [ - { - "type": "tool_use", - "id": ANY, - "name": "plus", - "input": {"a": 1, "b": 2}, - } - ], - }, - ), - ( - AssistantMessage( - ParallelFunctionCall( - [FunctionCall(plus, 1, 2), FunctionCall(plus, 3, 4)] - ) - ), - { - "role": "assistant", - "content": [ - { - "type": "tool_use", - "id": ANY, - "name": "plus", - "input": {"a": 1, "b": 2}, - }, - { - "type": "tool_use", - "id": ANY, - "name": "plus", - "input": {"a": 3, "b": 4}, - }, - ], - }, - ), - ( - AssistantMessage( - StreamedResponse([StreamedStr(["Hello"]), FunctionCall(plus, 1, 2)]) - ), - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Hello"}, - { - "type": "tool_use", - "id": ANY, - "name": "plus", - "input": {"a": 1, "b": 2}, - }, - ], - }, +message_to_anthropic_message_test_cases = [ + (UserMessage("Hello"), {"role": "user", "content": "Hello"}), + (AssistantMessage("Hello"), {"role": "assistant", "content": "Hello"}), + ( + AssistantMessage(42), + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": ANY, + "name": "return_int", + "input": {"value": 42}, + } + ], + }, + ), + ( + AssistantMessage(FunctionCall(plus, 1, 2)), + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 1, "b": 2}, + } + ], + }, + ), + ( + AssistantMessage( + ParallelFunctionCall([FunctionCall(plus, 1, 2), FunctionCall(plus, 3, 4)]) ), - ( - FunctionResultMessage(3, FunctionCall(plus, 1, 2)), - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": ANY, - "content": {"value": 3}, - } - ], - }, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 1, "b": 2}, + }, + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 3, "b": 4}, + }, + ], + }, + ), + ( + AssistantMessage( + StreamedResponse([StreamedStr(["Hello"]), FunctionCall(plus, 1, 2)]) ), - ], + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 1, "b": 2}, + }, + ], + }, + ), + ( + FunctionResultMessage(3, FunctionCall(plus, 1, 2)), + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": ANY, + "content": {"value": 3}, + } + ], + }, + ), +] + + +@pytest.mark.parametrize( + ("message", "expected_anthropic_message"), message_to_anthropic_message_test_cases ) def test_message_to_anthropic_message(message, expected_anthropic_message): assert 
message_to_anthropic_message(message) == expected_anthropic_message -async def test_async_message_to_anthropic_message(): - async_streamed_str = AsyncStreamedStr(async_iter(["Hello", " World"])) - async_streamed_response = AsyncStreamedResponse( - async_iter([async_streamed_str, FunctionCall(plus, 1, 2)]) +async_message_to_anthropic_message_test_cases = [ + *message_to_anthropic_message_test_cases, + ( + AssistantMessage( + AsyncStreamedResponse( + async_iter( + [ + AsyncStreamedStr(async_iter(["Hello", " World"])), + FunctionCall(plus, 1, 2), + ] + ) + ) + ), + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Hello World"}, + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 1, "b": 2}, + }, + ], + }, + ), +] + + +@pytest.mark.parametrize( + ("message", "expected_anthropic_message"), + async_message_to_anthropic_message_test_cases, +) +async def test_async_message_to_anthropic_message(message, expected_anthropic_message): + assert ( + await async_message_to_anthropic_message(message) == expected_anthropic_message ) - message = AssistantMessage(async_streamed_response) - assert await async_message_to_anthropic_message(message) == { - "role": "assistant", - "content": [ - {"type": "text", "text": "Hello World"}, - { - "type": "tool_use", - "id": ANY, - "name": "plus", - "input": {"a": 1, "b": 2}, - }, - ], - } def test_message_to_anthropic_message_user_image_document_bytes_pdf(document_bytes_pdf): diff --git a/tests/chat_model/test_openai_chat_model.py b/tests/chat_model/test_openai_chat_model.py index f8aa30a..67be73d 100644 --- a/tests/chat_model/test_openai_chat_model.py +++ b/tests/chat_model/test_openai_chat_model.py @@ -35,123 +35,139 @@ def plus(a: int, b: int) -> int: return a + b -@pytest.mark.parametrize( - ("message", "expected_openai_message"), - [ - ( - _RawMessage({"role": "user", "content": "Hello"}), - {"role": "user", "content": "Hello"}, - ), - (SystemMessage("Hello"), {"role": "system", "content": "Hello"}), - (UserMessage("Hello"), {"role": "user", "content": "Hello"}), - ( - UserMessage([ImageUrl("https://example.com/image.jpg")]), - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": {"url": "https://example.com/image.jpg"}, - } - ], - }, - ), - (AssistantMessage("Hello"), {"role": "assistant", "content": "Hello"}), - ( - AssistantMessage(42), - { - "role": "assistant", - "tool_calls": [ - { - "id": ANY, - "type": "function", - "function": {"name": "return_int", "arguments": '{"value":42}'}, - } - ], - }, - ), - ( - AssistantMessage(FunctionCall(plus, 1, 2)), - { - "role": "assistant", - "tool_calls": [ - { - "id": ANY, - "type": "function", - "function": {"name": "plus", "arguments": '{"a":1,"b":2}'}, - } - ], - }, - ), - ( - AssistantMessage( - ParallelFunctionCall( - [FunctionCall(plus, 1, 2), FunctionCall(plus, 3, 4)] - ) - ), - { - "role": "assistant", - "tool_calls": [ - { - "id": ANY, - "type": "function", - "function": {"name": "plus", "arguments": '{"a":1,"b":2}'}, - }, - { - "id": ANY, - "type": "function", - "function": {"name": "plus", "arguments": '{"a":3,"b":4}'}, - }, - ], - }, - ), - ( - AssistantMessage( - StreamedResponse([StreamedStr(["Hello"]), FunctionCall(plus, 1, 2)]) - ), - { - "role": "assistant", - "content": "Hello", - "tool_calls": [ - { - "id": ANY, - "type": "function", - "function": {"name": "plus", "arguments": '{"a":1,"b":2}'}, - }, - ], - }, +message_to_openai_message_test_cases = [ + ( + _RawMessage({"role": "user", "content": "Hello"}), + {"role": "user", "content": 
"Hello"}, + ), + (SystemMessage("Hello"), {"role": "system", "content": "Hello"}), + (UserMessage("Hello"), {"role": "user", "content": "Hello"}), + ( + UserMessage([ImageUrl("https://example.com/image.jpg")]), + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.jpg"}, + } + ], + }, + ), + (AssistantMessage("Hello"), {"role": "assistant", "content": "Hello"}), + ( + AssistantMessage(42), + { + "role": "assistant", + "tool_calls": [ + { + "id": ANY, + "type": "function", + "function": {"name": "return_int", "arguments": '{"value":42}'}, + } + ], + }, + ), + ( + AssistantMessage(FunctionCall(plus, 1, 2)), + { + "role": "assistant", + "tool_calls": [ + { + "id": ANY, + "type": "function", + "function": {"name": "plus", "arguments": '{"a":1,"b":2}'}, + } + ], + }, + ), + ( + AssistantMessage( + ParallelFunctionCall([FunctionCall(plus, 1, 2), FunctionCall(plus, 3, 4)]) ), - ( - FunctionResultMessage(3, FunctionCall(plus, 1, 2)), - { - "role": "tool", - "tool_call_id": ANY, - "content": '{"value":3}', - }, + { + "role": "assistant", + "tool_calls": [ + { + "id": ANY, + "type": "function", + "function": {"name": "plus", "arguments": '{"a":1,"b":2}'}, + }, + { + "id": ANY, + "type": "function", + "function": {"name": "plus", "arguments": '{"a":3,"b":4}'}, + }, + ], + }, + ), + ( + AssistantMessage( + StreamedResponse([StreamedStr(["Hello"]), FunctionCall(plus, 1, 2)]) ), - ], + { + "role": "assistant", + "content": "Hello", + "tool_calls": [ + { + "id": ANY, + "type": "function", + "function": {"name": "plus", "arguments": '{"a":1,"b":2}'}, + }, + ], + }, + ), + ( + FunctionResultMessage(3, FunctionCall(plus, 1, 2)), + { + "role": "tool", + "tool_call_id": ANY, + "content": '{"value":3}', + }, + ), +] + + +@pytest.mark.parametrize( + ("message", "expected_openai_message"), message_to_openai_message_test_cases ) def test_message_to_openai_message(message, expected_openai_message): assert message_to_openai_message(message) == expected_openai_message -async def test_async_message_to_openai_message(): - async_streamed_str = AsyncStreamedStr(async_iter(["Hello", " World"])) - async_streamed_response = AsyncStreamedResponse( - async_iter([async_streamed_str, FunctionCall(plus, 1, 2)]) - ) - message = AssistantMessage(async_streamed_response) - assert await async_message_to_openai_message(message) == { - "role": "assistant", - "content": "Hello World", - "tool_calls": [ - { - "id": ANY, - "type": "function", - "function": {"name": "plus", "arguments": '{"a":1,"b":2}'}, - }, - ], - } +async_message_to_openai_message_test_cases = [ + *message_to_openai_message_test_cases, + ( + AssistantMessage( + AsyncStreamedResponse( + async_iter( + [ + AsyncStreamedStr(async_iter(["Hello", " World"])), + FunctionCall(plus, 1, 2), + ] + ) + ) + ), + { + "role": "assistant", + "content": "Hello World", + "tool_calls": [ + { + "id": ANY, + "type": "function", + "function": {"name": "plus", "arguments": '{"a":1,"b":2}'}, + }, + ], + }, + ), +] + + +@pytest.mark.parametrize( + ("message", "expected_openai_message"), async_message_to_openai_message_test_cases +) +async def test_async_message_to_openai_message(message, expected_openai_message): + assert await async_message_to_openai_message(message) == expected_openai_message def test_message_to_openai_message_user_image_message_bytes_jpg(image_bytes_jpg):