diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 11924e3328d68..80d427c8be6c7 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -828,6 +828,7 @@ def pretty_print(self) -> None: Union[str, list[dict], list[object]], ], str, + dict, ] @@ -1461,7 +1462,15 @@ def _convert_to_message( _message = _create_template_from_message_type( "human", message, template_format=template_format ) - elif isinstance(message, tuple): + elif isinstance(message, (tuple, dict)): + if isinstance(message, dict): + if set(message.keys()) != {"content", "role"}: + msg = ( + "Expected dict to have exact keys 'role' and 'content'." + f" Got: {message}" + ) + raise ValueError(msg) + message = (message["role"], message["content"]) if len(message) != 2: msg = f"Expected 2-tuple of (role, template), got {message}" raise ValueError(msg) diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index abc0ac4f32628..32ad050077990 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -824,6 +824,41 @@ def test_chat_prompt_message_placeholder_tuple() -> None: assert optional_prompt.format_messages() == [] +def test_chat_prompt_message_placeholder_dict() -> None: + prompt = ChatPromptTemplate([{"role": "placeholder", "content": "{convo}"}]) + assert prompt.format_messages(convo=[("user", "foo")]) == [ + HumanMessage(content="foo") + ] + + assert prompt.format_messages() == [] + + # Is optional = True + optional_prompt = ChatPromptTemplate( + [{"role": "placeholder", "content": ["{convo}", False]}] + ) + assert optional_prompt.format_messages(convo=[("user", "foo")]) == [ + HumanMessage(content="foo") + ] + with pytest.raises(KeyError): + assert optional_prompt.format_messages() == [] + + +def test_chat_prompt_message_dict() -> None: + prompt = ChatPromptTemplate( + [{"role": "system", "content": "foo"}, {"role": "user", "content": "bar"}] + ) + assert prompt.format_messages() == [ + SystemMessage(content="foo"), + HumanMessage(content="bar"), + ] + + with pytest.raises(ValueError): + ChatPromptTemplate([{"role": "system", "content": False}]) + + with pytest.raises(ValueError): + ChatPromptTemplate([{"role": "foo", "content": "foo"}]) + + async def test_messages_prompt_accepts_list() -> None: prompt = ChatPromptTemplate([MessagesPlaceholder("history")]) value = prompt.invoke([("user", "Hi there")]) # type: ignore