From 46bba753685848a4fdca0405db115661bbbc68a5 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 7 Jan 2025 16:38:46 +0100 Subject: [PATCH] make chatmessage retriever and writer compatible with the new API --- .../test_in_memory_chat_message_store.py | 30 ++++++------- .../retrievers/test_chat_message_retriever.py | 44 +++++++++---------- .../writers/test_chat_message_writer.py | 10 ++--- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/test/chat_message_stores/test_in_memory_chat_message_store.py b/test/chat_message_stores/test_in_memory_chat_message_store.py index 47ef9e85..e89a3be2 100644 --- a/test/chat_message_stores/test_in_memory_chat_message_store.py +++ b/test/chat_message_stores/test_in_memory_chat_message_store.py @@ -42,11 +42,11 @@ def test_count_messages(self): """ store = InMemoryChatMessageStore() assert store.count_messages() == 0 - store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")]) + store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")]) assert store.count_messages() == 1 - store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")]) + store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")]) assert store.count_messages() == 2 - store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")]) + store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")]) assert store.count_messages() == 3 def test_retrieve(self): @@ -55,18 +55,18 @@ def test_retrieve(self): """ store = InMemoryChatMessageStore() assert store.retrieve() == [] - store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")]) - assert store.retrieve() == [ChatMessage.from_user(content="Hello, how can I help you?")] - store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")]) + store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")]) + assert store.retrieve() == [ChatMessage.from_user("Hello, how can I help you?")] + store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")]) assert store.retrieve() == [ - ChatMessage.from_user(content="Hello, how can I help you?"), - ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"), + ChatMessage.from_user("Hello, how can I help you?"), + ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"), ] - store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")]) + store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")]) assert store.retrieve() == [ - ChatMessage.from_user(content="Hello, how can I help you?"), - ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"), - ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?"), + ChatMessage.from_user("Hello, how can I help you?"), + ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"), + ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?"), ] def test_delete_messages(self): @@ -75,12 +75,12 @@ def test_delete_messages(self): """ store = InMemoryChatMessageStore() assert store.count_messages() == 0 - store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")]) + store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")]) assert store.count_messages() == 1 store.delete_messages() assert store.count_messages() == 0 - store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")]) - store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")]) + store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")]) + store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")]) assert store.count_messages() == 2 store.delete_messages() assert store.count_messages() == 0 diff --git a/test/components/retrievers/test_chat_message_retriever.py b/test/components/retrievers/test_chat_message_retriever.py index a378c762..1368ef25 100644 --- a/test/components/retrievers/test_chat_message_retriever.py +++ b/test/components/retrievers/test_chat_message_retriever.py @@ -24,8 +24,8 @@ def test_retrieve_messages(self): Test that the ChatMessageRetriever component can retrieve messages from the message store. """ messages = [ - ChatMessage.from_user(content="Hello, how can I help you?"), - ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?") + ChatMessage.from_user("Hello, how can I help you?"), + ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?") ] message_store = InMemoryChatMessageStore() @@ -40,10 +40,10 @@ def test_retrieve_messages_last_k(self): Test that the ChatMessageRetriever component can retrieve last_k messages from the message store. """ messages = [ - ChatMessage.from_user(content="Hello, how can I help you?"), - ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"), - ChatMessage.from_user(content="Hola, como puedo ayudarte?"), - ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?") + ChatMessage.from_user("Hello, how can I help you?"), + ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"), + ChatMessage.from_user("Hola, como puedo ayudarte?"), + ChatMessage.from_user("Bonjour, comment puis-je vous aider?") ] message_store = InMemoryChatMessageStore() @@ -52,19 +52,19 @@ def test_retrieve_messages_last_k(self): assert retriever.message_store == message_store assert retriever.run(last_k=1) == { - "messages": [ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")]} + "messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")]} assert retriever.run(last_k=2) == { - "messages": [ChatMessage.from_user(content="Hola, como puedo ayudarte?"), - ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?") + "messages": [ChatMessage.from_user("Hola, como puedo ayudarte?"), + ChatMessage.from_user("Bonjour, comment puis-je vous aider?") ]} # outliers assert retriever.run(last_k=10) == { - "messages": [ChatMessage.from_user(content="Hello, how can I help you?"), - ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"), - ChatMessage.from_user(content="Hola, como puedo ayudarte?"), - ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?") + "messages": [ChatMessage.from_user("Hello, how can I help you?"), + ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"), + ChatMessage.from_user("Hola, como puedo ayudarte?"), + ChatMessage.from_user("Bonjour, comment puis-je vous aider?") ]} with pytest.raises(ValueError): @@ -79,10 +79,10 @@ def test_retrieve_messages_last_k_init(self): by testing the init last_k parameter and the run last_k parameter logic """ messages = [ - ChatMessage.from_user(content="Hello, how can I help you?"), - ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"), - ChatMessage.from_user(content="Hola, como puedo ayudarte?"), - ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?") + ChatMessage.from_user("Hello, how can I help you?"), + ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"), + ChatMessage.from_user("Hola, como puedo ayudarte?"), + ChatMessage.from_user("Bonjour, comment puis-je vous aider?") ] message_store = InMemoryChatMessageStore() @@ -93,12 +93,12 @@ def test_retrieve_messages_last_k_init(self): # last_k is 1 here from run parameter, overrides init of 2 assert retriever.run(last_k=1) == { - "messages": [ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")]} + "messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")]} # last_k is 2 here from init assert retriever.run() == { - "messages": [ChatMessage.from_user(content="Hola, como puedo ayudarte?"), - ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?") + "messages": [ChatMessage.from_user("Hola, como puedo ayudarte?"), + ChatMessage.from_user("Bonjour, comment puis-je vous aider?") ]} def test_to_dict(self): @@ -157,7 +157,7 @@ def test_chat_message_retriever_pipeline(self): Context: {% for memory in memories %} - {{ memory.content }} + {{ memory.text }} {% endfor %} Question: {{ query }} @@ -166,7 +166,7 @@ def test_chat_message_retriever_pipeline(self): question = "What is the capital of France?" res = pipe.run(data={"prompt_builder": {"template": [ChatMessage.from_user(user_prompt)], "query": question}}) - resulting_prompt = res["prompt_builder"]["prompt"][0].content + resulting_prompt = res["prompt_builder"]["prompt"][0].text assert "France" in resulting_prompt assert "how can I help you" in resulting_prompt diff --git a/test/components/writers/test_chat_message_writer.py b/test/components/writers/test_chat_message_writer.py index 27d4652f..5bf9ff24 100644 --- a/test/components/writers/test_chat_message_writer.py +++ b/test/components/writers/test_chat_message_writer.py @@ -12,8 +12,8 @@ def test_init(self): Test that the ChatMessageWriter component can be initialized with a valid message store. """ messages = [ - ChatMessage.from_user(content="Hello, how can I help you?"), - ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?") + ChatMessage.from_user("Hello, how can I help you?"), + ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?") ] message_store = InMemoryChatMessageStore() @@ -42,7 +42,7 @@ def test_to_dict(self): } # write again and serialize - writer.run(messages=[ChatMessage.from_user(content="Hello, how can I help you?")]) + writer.run(messages=[ChatMessage.from_user("Hello, how can I help you?")]) data = writer.to_dict() assert data == { "type": "haystack_experimental.components.writers.chat_message_writer.ChatMessageWriter", @@ -74,7 +74,7 @@ def test_from_dict(self): } # write to verify that everything is still working - results = writer.run(messages=[ChatMessage.from_user(content="Hello, how can I help you?")]) + results = writer.run(messages=[ChatMessage.from_user("Hello, how can I help you?")]) assert results["messages_written"] == 1 def test_chat_message_writer_pipeline(self): @@ -97,7 +97,7 @@ def test_chat_message_writer_pipeline(self): res = pipe.run(data={"prompt_builder": {"template": [ChatMessage.from_user(user_prompt)], "query": question}}) assert res["writer"]["messages_written"] == 1 # only one message is written assert len(store.retrieve()) == 1 # only one message is written - assert store.retrieve()[0].content == """ + assert store.retrieve()[0].text == """ Given the following information, answer the question. Question: What is the capital of France? Answer: