Commit

fix(agents-api): add split chunks option + nlp tests

Ahmad-mtos committed Jan 15, 2025
1 parent 890880b commit 68a7a05
Showing 3 changed files with 165 additions and 62 deletions.
23 changes: 18 additions & 5 deletions agents-api/agents_api/common/nlp.py
@@ -9,6 +9,7 @@
 # Precompile regex patterns
 WHITESPACE_RE = re.compile(r"\s+")
 NON_ALPHANUM_RE = re.compile(r"[^\w\s\-_]+")
+LONE_HYPHEN_RE = re.compile(r'\s*-\s*(?!\w)|(?<!\w)\s*-\s*')

 # Initialize spaCy with minimal pipeline
 nlp = spacy.load("en_core_web_sm", exclude=["lemmatizer", "textcat"])
@@ -32,10 +33,16 @@
 @lru_cache(maxsize=10000)
 def clean_keyword(kw: str) -> str:
     """Cache cleaned keywords for reuse."""
-    return NON_ALPHANUM_RE.sub("", kw).strip()
+    # First remove non-alphanumeric chars (except whitespace, hyphens, underscores)
+    cleaned = NON_ALPHANUM_RE.sub("", kw).strip()
+    # Replace lone hyphens with spaces
+    cleaned = LONE_HYPHEN_RE.sub(" ", cleaned)
+    # Clean up any resulting multiple spaces
+    cleaned = WHITESPACE_RE.sub(" ", cleaned).strip()
+    return cleaned


-def extract_keywords(doc: Doc, top_n: int = 25, clean: bool = True) -> list[str]:
+def extract_keywords(doc: Doc, top_n: int = 25, clean: bool = True, split_chunks: bool = False) -> list[str]:
     """Optimized keyword extraction with minimal behavior change."""
     excluded_labels = {
         "DATE",  # Absolute or relative dates or periods.
@@ -95,6 +102,9 @@ def extract_keywords(doc: Doc, top_n: int = 25, clean: bool = True) -> list[str]
     normalized_ent_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in ent_keywords]
     normalized_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in keywords]

+    if split_chunks:
+        normalized_keywords = [word for kw in normalized_keywords for word in kw.split()]

     # Count frequencies efficiently
     ent_freq = Counter(normalized_ent_keywords)
     freq = Counter(normalized_keywords)
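When split_chunks is set, the comprehension in the hunk above flattens each multi-word noun chunk into individual tokens before frequency counting; entity keywords (normalized_ent_keywords) are left intact. A rough illustration with hypothetical values:

normalized_keywords = ["quick brown fox", "machine learning"]
# after the split_chunks branch:
# ["quick", "brown", "fox", "machine", "learning"]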
@@ -109,7 +119,9 @@ def extract_keywords(doc: Doc, top_n: int = 25, clean: bool = True) -> list[str]


 @lru_cache(maxsize=1000)
-def text_to_tsvector_query(paragraph: str, top_n: int = 25, min_keywords: int = 1) -> str:
+def text_to_tsvector_query(
+    paragraph: str, top_n: int = 25, min_keywords: int = 1, split_chunks: bool = False
+) -> str:
     """
     Extracts meaningful keywords/phrases from text and joins them with OR.
@@ -121,6 +133,7 @@ def text_to_tsvector_query(paragraph: str, top_n: int = 25, min_keywords: int =
         paragraph (str): The input text to process
         top_n (int): Number of top keywords to extract per sentence
         min_keywords (int): Minimum number of keywords required
+        split_chunks (bool): If True, breaks multi-word noun chunks into individual words
     Returns:
         str: Keywords/phrases joined by OR
@@ -135,11 +148,11 @@ def text_to_tsvector_query(paragraph: str, top_n: int = 25, min_keywords: int =
         sent_doc = sent.as_doc()

         # Extract keywords
-        keywords = extract_keywords(sent_doc, top_n)
+        keywords = extract_keywords(sent_doc, top_n, split_chunks=split_chunks)
         if len(keywords) < min_keywords:
             continue

-        queries.add(" OR ".join(keywords))
+        queries.update(keywords)

     # Join all terms with " OR "
     return " OR ".join(queries) if queries else ""
57 changes: 0 additions & 57 deletions agents-api/tests/test_docs_queries.py
@@ -369,63 +369,6 @@ async def _(
assert result[0].metadata is not None


@test("utility: test for text_to_tsvector_query")
async def _():
test_cases = [
# Single words
("test", "test"),
# Multiple words in single sentence
(
"quick brown fox",
"quick brown fox", # Now kept as a single phrase due to proximity
),
# Technical terms and phrases
(
"Machine Learning algorithm",
"machine learning algorithm", # Common technical phrase
),
# Multiple sentences
(
"I love basketball especially Michael Jordan. LeBron James is also great.",
"basketball OR lebron james OR michael jordan",
),
# Quoted phrases
(
'"quick brown fox"',
"quick brown fox", # Quotes removed, phrase kept together
),
('Find "machine learning" algorithms', "machine learning"),
# Multiple quoted phrases
('"data science" and "machine learning"', "machine learning OR data science"),
# Edge cases
("", ""),
(
"the and or",
"", # All stop words should result in empty string
),
(
"a",
"", # Single stop word should result in empty string
),
("X", "X"),
# Empty quotes
('""', ""),
('test "" phrase', "phrase OR test"),
]

for input_text, expected_output in test_cases:
print(f"Input: '{input_text}'")
result = text_to_tsvector_query(input_text)
print(f"Generated query: '{result}'")
print(f"Expected: '{expected_output}'\n")

result_terms = {term.lower() for term in result.split(" OR ") if term}
expected_terms = {term.lower() for term in expected_output.split(" OR ") if term}
assert result_terms == expected_terms, (
f"Expected terms {expected_terms} but got {result_terms} for input '{input_text}'"
)


# @test("query: search docs by embedding with different confidence levels")
# async def _(
# dsn=pg_dsn, agent=test_agent, developer=test_developer, doc=test_doc_with_embedding
147 changes: 147 additions & 0 deletions agents-api/tests/test_nlp_utilities.py
@@ -0,0 +1,147 @@
import spacy
from ward import test

from agents_api.common.nlp import clean_keyword, extract_keywords, text_to_tsvector_query

@test("utility: clean_keyword")
async def _():
assert clean_keyword("Hello, World!") == "Hello World"

# Basic cleaning
# assert clean_keyword("[email protected]") == "test example com"
assert clean_keyword("user-name_123") == "user-name_123"
assert clean_keyword(" spaces ") == "spaces"

# Special characters
assert clean_keyword("$price: 100%") == "price 100"
assert clean_keyword("#hashtag!") == "hashtag"

# Multiple spaces and punctuation
assert clean_keyword("multiple, spaces...") == "multiple spaces"

# Empty and whitespace
assert clean_keyword("") == ""
assert clean_keyword(" ") == ""

assert clean_keyword("- try") == "try"

@test("utility: extract_keywords")
async def _():
nlp = spacy.load("en_core_web_sm", exclude=["lemmatizer", "textcat"])
doc = nlp("John Doe is a software engineer at Google.")
assert set(extract_keywords(doc)) == {"John Doe", "a software engineer", "Google"}
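    # For contrast, a hypothetical call with the new flag. Based on the
    # split_chunks=True expectations further down, named entities should stay
    # intact while the noun chunk "a software engineer" is split:
    #
    #   extract_keywords(doc, split_chunks=True)
    #   # e.g. ["John Doe", "Google", "a", "software", "engineer"] (order may vary)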

@test("utility: text_to_tsvector_query - split_chunks=False")
async def _():
test_cases = [
# Single words
("test", "test"),
# Multiple words in single sentence
(
"quick brown fox",
"quick brown fox", # Now kept as a single phrase due to proximity
),
# Technical terms and phrases
(
"Machine Learning algorithm",
"machine learning algorithm", # Common technical phrase
),
# Multiple sentences
(
"I love basketball especially Michael Jordan. LeBron James is also great.",
"basketball OR lebron james OR michael jordan",
),
# Quoted phrases
(
'"quick brown fox"',
"quick brown fox", # Quotes removed, phrase kept together
),
('Find "machine learning" algorithms', "machine learning"),
# Multiple quoted phrases
('"data science" and "machine learning"', "machine learning OR data science"),
# Edge cases
("", ""),
(
"the and or",
"", # All stop words should result in empty string
),
(
"a",
"", # Single stop word should result in empty string
),
("X", "X"),
# Empty quotes
('""', ""),
('test "" phrase', "phrase OR test"),
("John Doe is a software engineer at Google.", "google OR john doe OR a software engineer"),
("- google", "google"),
]

for input_text, expected_output in test_cases:
print(f"Input: '{input_text}'")
result = text_to_tsvector_query(input_text, split_chunks=False)
print(f"Generated query: '{result}'")
print(f"Expected: '{expected_output}'\n")

result_terms = set(term.lower() for term in result.split(" OR ") if term)
expected_terms = set(term.lower() for term in expected_output.split(" OR ") if term)
assert result_terms == expected_terms, (
f"Expected terms {expected_terms} but got {result_terms} for input '{input_text}'"
)

@test("utility: text_to_tsvector_query - split_chunks=True")
async def _():
test_cases = [
# Single words
("test", "test"),
# Multiple words in single sentence
(
"quick brown fox",
"quick OR brown OR fox", # Now kept as a single phrase due to proximity
),
# Technical terms and phrases
(
"Machine Learning algorithm",
"machine OR learning OR algorithm", # Common technical phrase
),
# Multiple sentences
(
"I love basketball especially Michael Jordan. LeBron James is also great.",
"basketball OR lebron james OR michael jordan",
),
# Quoted phrases
(
'"quick brown fox"',
"quick OR brown OR fox", # Quotes removed, phrase kept together
),
('Find "machine learning" algorithms', "machine OR learning"),
# Multiple quoted phrases
('"data science" and "machine learning"', "machine OR learning OR data OR science"),
# Edge cases
("", ""),
(
"the and or",
"", # All stop words should result in empty string
),
(
"a",
"", # Single stop word should result in empty string
),
("X", "X"),
# Empty quotes
('""', ""),
('test "" phrase', "phrase OR test"),
("John Doe is a software engineer at Google.", "google OR john doe OR a OR software OR engineer"),
]

for input_text, expected_output in test_cases:
print(f"Input: '{input_text}'")
result = text_to_tsvector_query(input_text, split_chunks=True)
print(f"Generated query: '{result}'")
print(f"Expected: '{expected_output}'\n")

result_terms = set(term.lower() for term in result.split(" OR ") if term)
expected_terms = set(term.lower() for term in expected_output.split(" OR ") if term)
assert result_terms == expected_terms, (
f"Expected terms {expected_terms} but got {result_terms} for input '{input_text}'"
)
