Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[chat_bubble]: always place bubble at the end of the sentence #38

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions backend/app/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
user_repository,
)
from app.requests import chat_query
from app.utils import clean_text
from app.utils import clean_text, find_following_sentence_ending, find_sentence_endings
from app.vectorstore.chroma import ChromaDB
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
Expand Down Expand Up @@ -94,6 +94,7 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d
content = response["response"]
content_length = len(content)
clean_content = clean_text(content)
context_sentence_endings = find_sentence_endings(content)
text_references = []
not_exact_matched_refs = []

Expand Down Expand Up @@ -121,8 +122,12 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d
metadata = doc_metadata[best_match_index]
sent = doc_sent[best_match_index]

# find sentence start index of reference in the context
index = clean_content.find(clean_text(sentence))

# Find the following sentence end from the end index
reference_ending_index = find_following_sentence_ending(context_sentence_endings, index + len(sentence))

if index != -1:
text_reference = {
"asset_id": metadata["asset_id"],
Expand All @@ -131,7 +136,7 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d
"filename": original_filename,
"source": [sent],
"start": index,
"end": index + len(sentence),
"end": reference_ending_index,
}
text_references.append(text_reference)
else:
Expand Down
27 changes: 27 additions & 0 deletions backend/app/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import hashlib
from typing import List
from urllib.parse import urlparse
import uuid
import requests
import re
import string

from bisect import bisect_right


def generate_unique_filename(url, extension=".html"):
url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()
Expand Down Expand Up @@ -61,3 +64,27 @@ def fetch_html_and_save(url, file_path):
# Save the content to a file
with open(file_path, "wb") as file:
file.write(response.content)


def find_sentence_endings(text: str) -> List[int]:
    """Return the offsets just past every sentence terminator in ``text``.

    A terminator is ``.``, ``!`` or ``?`` followed by a whitespace character
    or by the end of the text. Each offset is the end of the regex match, so
    it includes the whitespace character when one follows the punctuation.

    ``len(text)`` is always appended as the last entry, so that value can
    appear twice when the text already finishes with a terminator.
    """
    endings: List[int] = []
    for terminator in re.finditer(r'[.!?](?:\s|$)', text):
        endings.append(terminator.end())

    # The end of the text always counts as a sentence boundary.
    endings.append(len(text))
    return endings

def find_following_sentence_ending(sentence_endings: List[int], index: int) -> int:
    """
    Return the first sentence ending located strictly after ``index``.

    Args:
        sentence_endings: Sorted list of sentence ending positions
        index: Current position in text

    Returns:
        The smallest ending greater than ``index``; ``index`` itself when no
        such ending exists (including when the list is empty).
    """
    # bisect_right skips entries equal to ``index``, so an ending sitting
    # exactly at ``index`` is never returned.
    insertion_point = bisect_right(sentence_endings, index)
    if insertion_point >= len(sentence_endings):
        return index
    return sentence_endings[insertion_point]
148 changes: 145 additions & 3 deletions backend/tests/api/v1/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def mock_db():

@pytest.fixture
def mock_vectorstore():
with patch("app.vectorstore.chroma.ChromaDB") as mock:
with patch("app.api.v1.chat.ChromaDB") as mock:
yield mock


Expand All @@ -24,15 +24,38 @@ def mock_chat_query():
with patch("app.api.v1.chat.chat_query") as mock:
yield mock


@pytest.fixture
def mock_user_repository():
    """Patch ``user_repository.get_user_api_key`` on the chat module and yield the mock."""
    api_key_patch = patch("app.api.v1.chat.user_repository.get_user_api_key")
    with api_key_patch as patched_get_api_key:
        yield patched_get_api_key


@pytest.fixture
def mock_get_users_repository():
    """Patch ``user_repository.get_users`` on the chat module and yield the mock."""
    get_users_patch = patch("app.api.v1.chat.user_repository.get_users")
    with get_users_patch as patched_get_users:
        yield patched_get_users

@pytest.fixture
def mock_conversation_repository():
    """Patch the ``conversation_repository`` name on the chat module and yield the mock."""
    repository_patch = patch("app.api.v1.chat.conversation_repository")
    with repository_patch as patched_repository:
        yield patched_repository

@pytest.fixture
def mock_get_assets_filename():
    """Patch ``project_repository.get_assets_filename`` on the chat module and yield the mock."""
    filenames_patch = patch("app.api.v1.chat.project_repository.get_assets_filename")
    with filenames_patch as patched_get_filenames:
        yield patched_get_filenames


def test_chat_status_endpoint(mock_db):
# Arrange
project_id = 1

with patch("app.repositories.project_repository.get_assets_without_content", return_value=[]):
with patch("app.repositories.project_repository.get_assets_content_pending", return_value=[]):
# Act
response = client.get(f"/v1/chat/project/{project_id}/status")

# Assert
assert response.status_code == 200
assert response.json()["status"] == "success"
Expand Down Expand Up @@ -281,3 +304,122 @@ def test_group_by_start_end_large_values():
assert len(result) == 1
assert result[0]["start"] == 1000000 and result[0]["end"] == 1000010
assert len(result[0]["references"]) == 2


def test_chat_endpoint_success(
    mock_db,
    mock_get_users_repository,
    mock_vectorstore,
    mock_chat_query,
    mock_user_repository,
    mock_conversation_repository,
    mock_get_assets_filename,
):
    """POST without a conversation id returns 200, the mocked answer and the new conversation's id."""
    # Arrange: stub every collaborator patched in by the fixtures.
    mock_get_users_repository.return_value = MagicMock(id=1)
    mock_user_repository.return_value.key = "test_api_key"
    mock_get_assets_filename.return_value = ["file1.pdf", "file2.pdf"]
    mock_vectorstore.return_value.get_relevant_segments.return_value = (["Quote 1", "Quote 2"], [1, 2], {})
    mock_chat_query.return_value = {"response": "Here's a response", "references": []}
    mock_conversation_repository.create_new_conversation.return_value = MagicMock(id=123)

    project_id = 1
    payload = {
        "query": "Tell me about sustainability.",
        "conversation_id": None,
    }

    # Act
    response = client.post(f"/v1/chat/project/{project_id}", json=payload)

    # Assert
    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "success"
    assert body["data"]["conversation_id"] == "123"
    assert body["data"]["response"] == "Here's a response"

def test_chat_endpoint_creates_conversation(
    mock_db,
    mock_get_users_repository,
    mock_vectorstore,
    mock_chat_query,
    mock_user_repository,
    mock_conversation_repository,
    mock_get_assets_filename,
):
    """POST without a conversation id calls ``create_new_conversation`` and returns its id."""
    # Arrange
    mock_get_users_repository.return_value = MagicMock(id=1)
    mock_user_repository.return_value.key = "test_api_key"
    mock_get_assets_filename.return_value = ["file1.pdf"]
    mock_vectorstore.return_value.get_relevant_segments.return_value = (["Quote 1"], [1], {})
    mock_chat_query.return_value = {"response": "Latest news on climate change", "references": []}
    # The created conversation carries id 456, which the response must echo back.
    mock_conversation_repository.create_new_conversation.return_value = MagicMock(id=456)

    project_id = 1
    payload = {
        "query": "What's the latest on climate change?",
        "conversation_id": None,
    }

    # Act
    response = client.post(f"/v1/chat/project/{project_id}", json=payload)

    # Assert
    assert response.status_code == 200
    assert response.json()["data"]["conversation_id"] == "456"
    assert mock_conversation_repository.create_new_conversation.called

def test_chat_endpoint_error_handling(mock_db, mock_vectorstore, mock_chat_query):
    """A vector-store failure is reported as HTTP 400 with the generic chat error detail."""
    # Arrange: make the segment lookup blow up.
    segments_lookup = mock_vectorstore.return_value.get_relevant_segments
    segments_lookup.side_effect = Exception("Database error")

    project_id = 1
    payload = {
        "query": "An error should occur.",
        "conversation_id": None,
    }

    # Act
    response = client.post(f"/v1/chat/project/{project_id}", json=payload)

    # Assert
    assert response.status_code == 400
    detail = response.json()["detail"]
    assert "Unable to process the chat query" in detail

def test_chat_endpoint_reference_processing(
    mock_db,
    mock_get_users_repository,
    mock_vectorstore,
    mock_chat_query,
    mock_user_repository,
    mock_conversation_repository,
    mock_get_assets_filename,
):
    """A chat answer that carries references yields a non-empty ``response_references`` list."""
    # Arrange
    project_id = 1
    payload = {
        "query": "Reference query.",
        "conversation_id": None,
    }

    segment_metadata = [
        {"asset_id": 1, "project_id": project_id, "filename": "test.pdf", "page_number": 1}
    ]
    mock_vectorstore.return_value.get_relevant_segments.return_value = (
        ["Reference Quote"],
        [1],
        segment_metadata,
    )
    mock_user_repository.return_value.key = "test_api_key"

    cited_reference = {
        "sentence": "Reference Quote",
        "references": [{"file": "file1.pdf", "sentence": "Original sentence"}],
    }
    mock_chat_query.return_value = {
        "response": "Response with references",
        "references": [cited_reference],
    }

    mock_conversation_repository.create_new_conversation.return_value.id = 789
    mock_get_users_repository.return_value = MagicMock(id=1)
    mock_get_assets_filename.return_value = ["file1.pdf"]

    # Act
    response = client.post(f"/v1/chat/project/{project_id}", json=payload)

    # Assert
    assert response.status_code == 200
    response_references = response.json()["data"]["response_references"]
    assert len(response_references) > 0

def test_chat_endpoint_with_conversation_id(
    mock_db,
    mock_vectorstore,
    mock_chat_query,
    mock_user_repository,
    mock_conversation_repository,
    mock_get_assets_filename,
):
    """A request that supplies a conversation id gets the same id back in the response."""
    # Arrange
    mock_user_repository.return_value.key = "test_api_key"
    mock_get_assets_filename.return_value = ["file1.pdf"]
    mock_vectorstore.return_value.get_relevant_segments.return_value = (["Quote"], [1], {})
    mock_chat_query.return_value = {"response": "Response with existing conversation", "references": []}

    project_id = 1
    payload = {
        "query": "Chat with conversation.",
        "conversation_id": "existing_convo_id",
    }

    # Act
    response = client.post(f"/v1/chat/project/{project_id}", json=payload)

    # Assert
    assert response.status_code == 200
    returned_id = response.json()["data"]["conversation_id"]
    assert returned_id == "existing_convo_id"
45 changes: 45 additions & 0 deletions backend/tests/utils/test_following_sentence_ending.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import unittest

from app.utils import find_following_sentence_ending


class TestFindFollowingSentenceEnding(unittest.TestCase):
    """Checks that the lookup returns the first ending strictly after an index."""

    def test_basic_case(self):
        # 20 is the first ending greater than 15.
        endings = [10, 20, 30, 40]
        self.assertEqual(find_following_sentence_ending(endings, 15), 20)

    def test_no_greater_ending(self):
        # Nothing lies beyond 35, so the index comes back unchanged.
        endings = [10, 20, 30]
        self.assertEqual(find_following_sentence_ending(endings, 35), 35)

    def test_at_ending_boundary(self):
        # An index sitting exactly on an ending moves on to the next one.
        endings = [10, 20, 30, 40]
        self.assertEqual(find_following_sentence_ending(endings, 30), 40)

    def test_first_sentence(self):
        # An index before every ending resolves to the first ending.
        endings = [10, 20, 30, 40]
        self.assertEqual(find_following_sentence_ending(endings, 5), 10)

    def test_empty_sentence_endings(self):
        # With no endings at all, the index comes back unchanged.
        self.assertEqual(find_following_sentence_ending([], 5), 5)

    def test_same_index_as_last_ending(self):
        # On the last ending there is nothing greater, so the index is returned.
        endings = [10, 20, 30]
        self.assertEqual(find_following_sentence_ending(endings, 30), 30)

# Run the tests
if __name__ == "__main__":
unittest.main()
39 changes: 39 additions & 0 deletions backend/tests/utils/test_sentence_endings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest

from app.utils import find_sentence_endings


class TestFindSentenceEndings(unittest.TestCase):
    """Checks the offsets reported for sentence terminators plus the final text length."""

    def test_basic_sentences(self):
        # Offsets after ". " and "!", followed by the appended text length.
        sample = "This is a sentence. This is another one!"
        self.assertEqual(find_sentence_endings(sample), [20, 40, len(sample)])

    def test_text_without_punctuation(self):
        # No terminator: only the text length is reported.
        sample = "This is a sentence without punctuation"
        self.assertEqual(find_sentence_endings(sample), [len(sample)])

    def test_multiple_punctuation(self):
        # Offsets after "? ", "! " and the closing ".", then the text length.
        sample = "Is this working? Yes! It seems so."
        self.assertEqual(find_sentence_endings(sample), [17, 22, 34, len(sample)])

    def test_trailing_whitespace(self):
        # The reported offset lies past the space that follows the period.
        sample = "Trailing whitespace should be ignored. "
        self.assertEqual(find_sentence_endings(sample), [39, len(sample)])

    def test_punctuation_in_middle_of_text(self):
        # The period closing the abbreviation "e.g." is counted as an ending too.
        sample = "Sentence. Followed by an abbreviation e.g. and another sentence."
        self.assertEqual(find_sentence_endings(sample), [10, 43, 64, len(sample)])

    def test_empty_string(self):
        # An empty text yields only its length, 0.
        self.assertEqual(find_sentence_endings(""), [0])

# Run the tests
if __name__ == "__main__":
unittest.main()
Loading