Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: make tests for ChatMessage retriever and writer compatible with the new API #163

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions test/chat_message_stores/test_in_memory_chat_message_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def test_count_messages(self):
"""
store = InMemoryChatMessageStore()
assert store.count_messages() == 0
store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")])
assert store.count_messages() == 1
store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")])
store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")])
assert store.count_messages() == 2
store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")])
store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")])
assert store.count_messages() == 3

def test_retrieve(self):
Expand All @@ -55,18 +55,18 @@ def test_retrieve(self):
"""
store = InMemoryChatMessageStore()
assert store.retrieve() == []
store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
assert store.retrieve() == [ChatMessage.from_user(content="Hello, how can I help you?")]
store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")])
store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")])
assert store.retrieve() == [ChatMessage.from_user("Hello, how can I help you?")]
store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")])
assert store.retrieve() == [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
]
store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")])
store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")])
assert store.retrieve() == [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?"),
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?"),
]

def test_delete_messages(self):
Expand All @@ -75,12 +75,12 @@ def test_delete_messages(self):
"""
store = InMemoryChatMessageStore()
assert store.count_messages() == 0
store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")])
assert store.count_messages() == 1
store.delete_messages()
assert store.count_messages() == 0
store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")])
store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")])
store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")])
store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")])
assert store.count_messages() == 2
store.delete_messages()
assert store.count_messages() == 0
44 changes: 22 additions & 22 deletions test/components/retrievers/test_chat_message_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def test_retrieve_messages(self):
Test that the ChatMessageRetriever component can retrieve messages from the message store.
"""
messages = [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")
]

message_store = InMemoryChatMessageStore()
Expand All @@ -40,10 +40,10 @@ def test_retrieve_messages_last_k(self):
Test that the ChatMessageRetriever component can retrieve last_k messages from the message store.
"""
messages = [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]

message_store = InMemoryChatMessageStore()
Expand All @@ -52,19 +52,19 @@ def test_retrieve_messages_last_k(self):

assert retriever.message_store == message_store
assert retriever.run(last_k=1) == {
"messages": [ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")]}
"messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")]}

assert retriever.run(last_k=2) == {
"messages": [ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
"messages": [ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]}

# outliers
assert retriever.run(last_k=10) == {
"messages": [ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
"messages": [ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]}

with pytest.raises(ValueError):
Expand All @@ -79,10 +79,10 @@ def test_retrieve_messages_last_k_init(self):
by testing the init last_k parameter and the run last_k parameter logic
"""
messages = [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]

message_store = InMemoryChatMessageStore()
Expand All @@ -93,12 +93,12 @@ def test_retrieve_messages_last_k_init(self):

# last_k is 1 here from run parameter, overrides init of 2
assert retriever.run(last_k=1) == {
"messages": [ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")]}
"messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")]}

# last_k is 2 here from init
assert retriever.run() == {
"messages": [ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
"messages": [ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]}

def test_to_dict(self):
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_chat_message_retriever_pipeline(self):

Context:
{% for memory in memories %}
{{ memory.content }}
{{ memory.text }}
{% endfor %}

Question: {{ query }}
Expand All @@ -166,7 +166,7 @@ def test_chat_message_retriever_pipeline(self):
question = "What is the capital of France?"

res = pipe.run(data={"prompt_builder": {"template": [ChatMessage.from_user(user_prompt)], "query": question}})
resulting_prompt = res["prompt_builder"]["prompt"][0].content
resulting_prompt = res["prompt_builder"]["prompt"][0].text
assert "France" in resulting_prompt
assert "how can I help you" in resulting_prompt

Expand Down
10 changes: 5 additions & 5 deletions test/components/writers/test_chat_message_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def test_init(self):
Test that the ChatMessageWriter component can be initialized with a valid message store.
"""
messages = [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")
]

message_store = InMemoryChatMessageStore()
Expand Down Expand Up @@ -42,7 +42,7 @@ def test_to_dict(self):
}

# write again and serialize
writer.run(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
writer.run(messages=[ChatMessage.from_user("Hello, how can I help you?")])
data = writer.to_dict()
assert data == {
"type": "haystack_experimental.components.writers.chat_message_writer.ChatMessageWriter",
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_from_dict(self):
}

# write to verify that everything is still working
results = writer.run(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
results = writer.run(messages=[ChatMessage.from_user("Hello, how can I help you?")])
assert results["messages_written"] == 1

def test_chat_message_writer_pipeline(self):
Expand All @@ -97,7 +97,7 @@ def test_chat_message_writer_pipeline(self):
res = pipe.run(data={"prompt_builder": {"template": [ChatMessage.from_user(user_prompt)], "query": question}})
assert res["writer"]["messages_written"] == 1 # only one message is written
assert len(store.retrieve()) == 1 # only one message is written
assert store.retrieve()[0].content == """
assert store.retrieve()[0].text == """
Given the following information, answer the question.
Question: What is the capital of France?
Answer:
Expand Down
Loading