Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support images directly in UserMessage #387

Merged
merged 48 commits into from
Jan 6, 2025
Merged
Changes from 1 commit
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
bc1e7cf
Move OpenaiChatModel UserImageMessage support into openai_chat_model.py
jackmpcollins Aug 13, 2024
27370b1
Implement message_to_anthropic_message for UserImageMessage
jackmpcollins Aug 13, 2024
6f4ae58
Get test passing using _combine_messages
jackmpcollins Aug 13, 2024
2fbd2b2
WIP: Add ImageBytes, ImageUrl. Expand UserMessage types.
jackmpcollins Aug 13, 2024
1eb89b5
Merge branch 'main' into allow-images-in-user-message
jackmpcollins Dec 3, 2024
186876d
Remove duplicate message_to_x_message for UserImageMessage
jackmpcollins Dec 3, 2024
f225d1a
Add make test-fix-snapshots
jackmpcollins Dec 3, 2024
925c20b
Move image_bytes fixtures into top-level conftest
jackmpcollins Dec 3, 2024
0c7f307
Fix typing for ImageBytes.mime_type. Add tests
jackmpcollins Dec 3, 2024
6c4d872
Fix mypy errors on UserMessage conversion typing
jackmpcollins Dec 3, 2024
eba51a9
Validate bytes are valid image
jackmpcollins Dec 3, 2024
3fa37c9
Use ImageBytes in UserImageMessage conversion functions
jackmpcollins Dec 3, 2024
a8dbb40
Add typevar UserMessageContentT
jackmpcollins Dec 3, 2024
b0c97c3
Fix mypy errors due to UserMessage now generic
jackmpcollins Dec 3, 2024
882b238
Attempt to coerce type in Placeholder.format
jackmpcollins Dec 4, 2024
59249bd
Fix: list -> Iterable in UserMessage serialization
jackmpcollins Dec 4, 2024
31a47f6
Fix: return message in message_to_openai_message
jackmpcollins Dec 4, 2024
e07f364
Merge branch 'main' into allow-images-in-user-message
jackmpcollins Jan 1, 2025
df575d4
Make Placeholder a BaseModel
jackmpcollins Jan 1, 2025
ce1334f
Make ContentT covariant
jackmpcollins Jan 1, 2025
9ddd4cd
Add covariant PlaceholderT
jackmpcollins Jan 1, 2025
dfc8a93
Fix type checking for UserMessage format
jackmpcollins Jan 1, 2025
87b81ca
Remove unused type ignores
jackmpcollins Jan 1, 2025
026e436
Require pydantic 2.10 to fix generic in BaseModel
jackmpcollins Jan 1, 2025
c8c7bfa
Revert "Remove unused type ignores"
jackmpcollins Jan 1, 2025
71e411e
Remove pydantic url from error messages in tests
jackmpcollins Jan 1, 2025
20e9146
Add TypeAlias UserMessageContentBlock
jackmpcollins Jan 1, 2025
19346e6
Add trailing .0 to pydantic version in pyproject
jackmpcollins Jan 1, 2025
d2e9b80
Use TypeAdapter for Placeholder coercion
jackmpcollins Jan 1, 2025
2cb7379
Add typing-extensions as dependency
jackmpcollins Jan 1, 2025
34cf578
Remove todo for testing AssistantMessage with FunctionCall
jackmpcollins Jan 1, 2025
17e4558
Ignore logfire not configured warnings
jackmpcollins Jan 1, 2025
fd636b0
Deprecate UserImageMessage
jackmpcollins Jan 1, 2025
0be8bf9
Add make test-snapshots-create and improve naming
jackmpcollins Jan 2, 2025
cf8a55c
Add tests for UserMessage with ImageBytes/Url
jackmpcollins Jan 2, 2025
f35fdf1
Upgrade mypy version
jackmpcollins Jan 2, 2025
6f42619
Add tests for ImageBytes with ChatModel
jackmpcollins Jan 2, 2025
75df0a8
Improve AssistantMessage typing, add tests
jackmpcollins Jan 2, 2025
26f3282
Handle Literal string in AssistantMessage.format typing
jackmpcollins Jan 2, 2025
5bf4412
Add NotPlaceholder Protocol
jackmpcollins Jan 2, 2025
ccb30c0
Fix type hints for UserMessage
jackmpcollins Jan 5, 2025
8e04974
Add github issue link for failing type tests
jackmpcollins Jan 5, 2025
a6ad13c
Improve handling of Literal in AssistantMessage typing
jackmpcollins Jan 6, 2025
0855b8a
Rename to PlaceholderTypeT
jackmpcollins Jan 6, 2025
f5477bd
Remove done todo re name not in kwargs error
jackmpcollins Jan 6, 2025
5012ad2
Add top-level imports for ImageBytes, ImageUrl
jackmpcollins Jan 6, 2025
225479d
Update docs for vision
jackmpcollins Jan 6, 2025
b443e09
Add note about Placeholder coercion
jackmpcollins Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix type hints for UserMessage
jackmpcollins committed Jan 5, 2025
commit ccb30c004f5ef3b73e19c28096d79646e04d4129
34 changes: 19 additions & 15 deletions src/magentic/chat_model/message.py
Original file line number Diff line number Diff line change
@@ -31,6 +31,9 @@

from magentic.function_call import FunctionCall

if TYPE_CHECKING:
from magentic.typing import NonStringSequence

PlaceholderT = TypeVar("PlaceholderT", covariant=True)


@@ -162,16 +165,22 @@ def format(self, **kwargs: Any) -> Self:
return self


UserMessageContentBlock: TypeAlias = str | ImageBytes | ImageUrl
UserMessageContentBlockT = TypeVar(
"UserMessageContentBlockT", bound=UserMessageContentBlock, covariant=True
)
UserMessageContentBlock: TypeAlias = ImageBytes | ImageUrl
UserMessageContentT = TypeVar(
"UserMessageContentT",
bound=str
| Sequence[UserMessageContentBlock | Placeholder[UserMessageContentBlock]],
| Sequence[str | UserMessageContentBlock | Placeholder[UserMessageContentBlock]],
covariant=True,
)
UserMessageContentBlockT = TypeVar(
"UserMessageContentBlockT", bound=UserMessageContentBlock
)
UserMessageContentBlockT2 = TypeVar(
"UserMessageContentBlockT2", bound=UserMessageContentBlock
)
# These allow type hinting that a `str` _might_ be part of the union
StrT = TypeVar("StrT", str, str)
StrT2 = TypeVar("StrT2", str, str)


class UserMessage(Message[UserMessageContentT], Generic[UserMessageContentT]):
@@ -187,20 +196,15 @@ def format(self: "UserMessage[str]", **kwargs: Any) -> "UserMessage[str]": ...

@overload
def format(
self: "UserMessage[Sequence[UserMessageContentBlockT]]", **kwargs: Any
) -> "UserMessage[Sequence[UserMessageContentBlockT]]": ...

@overload
def format(
self: "UserMessage[Sequence[Placeholder[UserMessageContentBlockT]]]",
self: "UserMessage[StrT | NonStringSequence[StrT2 | UserMessageContentBlockT | Placeholder[UserMessageContentBlockT2]]]",
**kwargs: Any,
) -> "UserMessage[Sequence[UserMessageContentBlockT]]": ...
) -> "UserMessage[StrT | Sequence[StrT2 | UserMessageContentBlockT | UserMessageContentBlockT2]]": ...

def format(
self: "UserMessage[str | Sequence[UserMessageContentBlockT | Placeholder[UserMessageContentBlockT]]]",
self: "UserMessage[StrT | NonStringSequence[StrT2 | UserMessageContentBlockT | Placeholder[UserMessageContentBlockT2]]]",
**kwargs: Any,
) -> "UserMessage[str | Sequence[UserMessageContentBlockT]]":
if isinstance(self.content, str | Placeholder):
) -> "UserMessage[StrT | Sequence[StrT2 | UserMessageContentBlockT | UserMessageContentBlockT2]]":
if isinstance(self.content, str):
return UserMessage(self.content.format(**kwargs))
if isinstance(self.content, Iterable):
return UserMessage([block.format(**kwargs) for block in self.content]) # type: ignore[misc]
21 changes: 20 additions & 1 deletion src/magentic/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
import inspect
import types
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, TypeGuard, TypeVar, Union, get_args, get_origin
from typing import (
TYPE_CHECKING,
Any,
Protocol,
TypeGuard,
TypeVar,
Union,
get_args,
get_origin,
)

T_co = TypeVar("T_co", covariant=True)

if TYPE_CHECKING:
# Cannot be defined at runtime because Protocol cannot inherit from non-Protocol
class NonStringSequence(Sequence[T_co], Protocol[T_co]): # type: ignore[misc]
"""Protocol that matches Sequences except for `str`."""

# HACK: Works because `__contains__` method of `str` does not match `Sequence`
# See: https://github.com/python/typing/issues/256#issuecomment-1442633430


def is_union_type(type_: type) -> bool:
82 changes: 82 additions & 0 deletions tests/chat_model/test_message.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
AssistantMessage,
FunctionResultMessage,
ImageBytes,
ImageUrl,
Placeholder,
SystemMessage,
ToolResultMessage,
@@ -19,6 +20,9 @@
)
from magentic.function_call import FunctionCall

if TYPE_CHECKING:
from collections.abc import Sequence


def test_placeholder_format():
class Country(BaseModel):
@@ -111,6 +115,82 @@ def test_user_message_format():
assert user_message_formatted == UserMessage("Hello world")


def test_user_message_format_type_hints():
if TYPE_CHECKING: # Avoid runtime error for None missing `format` method
assert_type(cast(UserMessage[Literal["x"]], None).format(), UserMessage[str])
# mypy does not convert `Literal` to `str` in these cases but pyright does
# assert_type(cast(UserMessage[Sequence[Literal["x"]]], None).format(), UserMessage[Sequence[str]]) # noqa: ERA001
# assert_type(cast(UserMessage[Sequence[Literal["x"] | ImageBytes]], None).format(), UserMessage[Sequence[str | ImageBytes]]) # noqa: ERA001

assert_type(cast(UserMessage[str], None).format(), UserMessage[str])

assert_type(
cast(UserMessage[Sequence[str]], None).format(), UserMessage[Sequence[str]]
)
assert_type(
cast(UserMessage[str | Sequence[str]], None).format(),
UserMessage[str | Sequence[str]],
)

assert_type(
cast(UserMessage[Sequence[ImageBytes]], None).format(),
UserMessage[Sequence[ImageBytes]],
)
assert_type(
cast(UserMessage[Sequence[str | ImageBytes]], None).format(),
UserMessage[Sequence[str | ImageBytes]],
)
assert_type(
cast(UserMessage[str | Sequence[str | ImageBytes]], None).format(),
UserMessage[str | Sequence[str | ImageBytes]],
)

assert_type(
cast(UserMessage[Sequence[Placeholder[ImageUrl]]], None).format(),
UserMessage[Sequence[ImageUrl]],
)
assert_type(
cast(UserMessage[Sequence[str | Placeholder[ImageUrl]]], None).format(),
UserMessage[Sequence[str | ImageUrl]],
)
assert_type(
cast(
UserMessage[Sequence[ImageBytes | Placeholder[ImageUrl]]], None
).format(),
UserMessage[Sequence[ImageBytes | ImageUrl]],
)
assert_type(
cast(
UserMessage[Sequence[str | ImageBytes | Placeholder[ImageUrl]]], None
).format(),
UserMessage[Sequence[str | ImageBytes | ImageUrl]],
)

assert_type(
cast(UserMessage[str | Sequence[Placeholder[ImageUrl]]], None).format(),
UserMessage[str | Sequence[ImageUrl]],
)
assert_type(
cast(
UserMessage[str | Sequence[str | Placeholder[ImageUrl]]], None
).format(),
UserMessage[str | Sequence[str | ImageUrl]],
)
assert_type(
cast(
UserMessage[str | Sequence[ImageBytes | Placeholder[ImageUrl]]], None
).format(),
UserMessage[str | Sequence[ImageBytes | ImageUrl]],
)
assert_type(
cast(
UserMessage[str | Sequence[str | ImageBytes | Placeholder[ImageUrl]]],
None,
).format(),
UserMessage[str | Sequence[str | ImageBytes | ImageUrl]],
)


def test_assistant_message_usage():
assistant_message = AssistantMessage("Hello")
assert assistant_message.usage is None
@@ -140,6 +220,8 @@ class Country(BaseModel):
assert_type(assistant_message_formatted.content, Country)
assert assistant_message_formatted == AssistantMessage(Country(name="USA"))


def test_assistant_message_format_type_hints():
if TYPE_CHECKING: # Avoid runtime error for None missing `format` method
assert_type(
cast(AssistantMessage[Literal["x"]], None).format(), AssistantMessage[str]