diff --git a/tests/chat_model/test_anthropic_chat_model.py b/tests/chat_model/test_anthropic_chat_model.py index 4a9e3db..12b3ef5 100644 --- a/tests/chat_model/test_anthropic_chat_model.py +++ b/tests/chat_model/test_anthropic_chat_model.py @@ -1,4 +1,5 @@ from typing import Annotated +from unittest.mock import ANY import pytest from inline_snapshot import snapshot @@ -11,7 +12,9 @@ ) from magentic.chat_model.base import ToolSchemaParseError from magentic.chat_model.message import ( + AssistantMessage, DocumentBytes, + FunctionResultMessage, ImageBytes, Message, Usage, @@ -25,6 +28,103 @@ from magentic.streaming import AsyncStreamedStr, StreamedStr +def plus(a: int, b: int) -> int: + return a + b + + +@pytest.mark.parametrize( + ("message", "expected_anthropic_message"), + [ + (UserMessage("Hello"), {"role": "user", "content": "Hello"}), + (AssistantMessage("Hello"), {"role": "assistant", "content": "Hello"}), + ( + AssistantMessage(42), + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": ANY, + "name": "return_int", + "input": {"value": 42}, + } + ], + }, + ), + ( + AssistantMessage(FunctionCall(plus, 1, 2)), + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 1, "b": 2}, + } + ], + }, + ), + ( + AssistantMessage( + ParallelFunctionCall( + [FunctionCall(plus, 1, 2), FunctionCall(plus, 3, 4)] + ) + ), + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 1, "b": 2}, + }, + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 3, "b": 4}, + }, + ], + }, + ), + ( + AssistantMessage( + StreamedResponse([StreamedStr(["Hello"]), FunctionCall(plus, 1, 2)]) + ), + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "tool_use", + "id": ANY, + "name": "plus", + "input": {"a": 1, "b": 2}, + }, + ], + }, + ), + ( + FunctionResultMessage(3, FunctionCall(plus, 1, 2)), + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": ANY, + "content": {"value": 3}, + } + ], + }, + ), + ], +) +def test_message_to_anthropic_message(message, expected_anthropic_message): + assert message_to_anthropic_message(message) == expected_anthropic_message + + def test_message_to_anthropic_message_user_image_document_bytes_pdf(document_bytes_pdf): image_message = UserMessage([DocumentBytes(document_bytes_pdf)]) assert message_to_anthropic_message(image_message) == snapshot(