feat(agents-api): Improve nlp performance
Signed-off-by: Diwank Singh Tomer <[email protected]>
creatorrr committed Oct 24, 2024
1 parent d803ebb commit 0f4c4e0
Showing 1 changed file with 64 additions and 16 deletions.
80 changes: 64 additions & 16 deletions agents-api/agents_api/common/nlp.py
@@ -1,7 +1,6 @@
import re
from collections import Counter, defaultdict
from functools import lru_cache
from typing import Dict, List, Set, Tuple

import spacy
from spacy.matcher import PhraseMatcher
@@ -13,31 +12,29 @@
NON_ALPHANUM_RE = re.compile(r"[^\w\s\-_]+")

# Initialize spaCy with minimal pipeline
spacy.prefer_gpu()
nlp = spacy.load(
"en_core_web_sm",
disable=["lemmatizer", "textcat", "vector"], # Disable unused components
)
nlp = spacy.load("en_core_web_sm", exclude=["lemmatizer", "textcat", "tok2vec"])

# Singleton PhraseMatcher for better performance
# Add sentencizer for faster sentence tokenization
sentencizer = nlp.add_pipe("sentencizer")


# Singleton PhraseMatcher for better performance
class KeywordMatcher:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
            cls._instance.batch_size = 1000
            cls._instance.batch_size = 1000  # Adjust based on memory constraints
            cls._instance.patterns_cache = {}
        return cls._instance

    @lru_cache(maxsize=10000)
    def _create_pattern(self, text: str) -> Doc:
        return nlp.make_doc(text)

    def find_matches(self, doc: Doc, keywords: List[str]) -> Dict[str, List[int]]:
    def find_matches(self, doc: Doc, keywords: list[str]) -> dict[str, list[int]]:
        """Batch process keywords for better performance."""
        keyword_positions = defaultdict(list)

@@ -71,7 +68,7 @@ def clean_keyword(kw: str) -> str:
    return NON_ALPHANUM_RE.sub("", kw).strip()
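
For reference, a quick sanity check of clean_keyword, derived from the NON_ALPHANUM_RE pattern defined above: everything outside word characters, whitespace, hyphens, and underscores is stripped.

# Illustrative only; behavior follows directly from NON_ALPHANUM_RE above.
print(clean_keyword("GPU-accelerated NLP!"))  # -> "GPU-accelerated NLP"
print(clean_keyword("  (spaCy)  "))           # -> "spaCy"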


def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> List[str]:
def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]:
"""Optimized keyword extraction with minimal behavior change."""
excluded_labels = {
"DATE",
@@ -116,15 +113,15 @@ def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> List[str]


def find_proximity_groups(
    keywords: List[str], keyword_positions: Dict[str, List[int]], n: int = 10
) -> List[Set[str]]:
    keywords: list[str], keyword_positions: dict[str, list[int]], n: int = 10
) -> list[set[str]]:
    """Optimized proximity grouping using sorted positions."""
    # Early return for single or no keywords
    if len(keywords) <= 1:
        return [{kw} for kw in keywords]

    # Create flat list of positions for efficient processing
    positions: List[Tuple[int, str]] = [
    positions: list[tuple[int, str]] = [
        (pos, kw) for kw in keywords for pos in keyword_positions[kw]
    ]

@@ -171,15 +168,14 @@ def union(u: str, v: str) -> None:
    return list(groups.values())


@lru_cache(maxsize=100)
def build_query_pattern(group_size: int, n: int) -> str:
"""Cache query patterns for common group sizes."""
if group_size == 1:
return '"{}"'
return f"NEAR/{n}(" + " ".join('"{}"' for _ in range(group_size)) + ")"


def build_query(groups: List[Set[str]], n: int = 10) -> str:
def build_query(groups: list[set[str]], n: int = 10) -> str:
"""Build query with cached patterns."""
clauses = []

@@ -197,12 +193,22 @@ def build_query(groups: List[Set[str]], n: int = 10) -> str:
return " OR ".join(clauses)


@lru_cache(maxsize=100)
def paragraph_to_custom_queries(
    paragraph: str, top_n: int = 10, proximity_n: int = 10, min_keywords: int = 1
) -> List[str]:
) -> list[str]:
    """
    Optimized paragraph processing with minimal behavior changes.
    Added min_keywords parameter to filter out low-value queries.
    Args:
        paragraph (str): The input paragraph to convert.
        top_n (int): Number of top keywords to extract per sentence.
        proximity_n (int): The proximity window for NEAR/n.
        min_keywords (int): Minimum number of keywords required to form a query.
    Returns:
        list[str]: The list of custom query strings.
    """
    if not paragraph or not paragraph.strip():
        return []
@@ -236,3 +242,45 @@ def paragraph_to_custom_queries(
            queries.append(query)

    return queries
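
A minimal usage sketch for the single-paragraph entry point. The import path is assumed from the file location above (agents_api.common.nlp), and the exact query strings depend on the loaded spaCy model:

# Hypothetical usage sketch; output varies with the en_core_web_sm version.
from agents_api.common.nlp import paragraph_to_custom_queries

text = (
    "Keyword extraction turns each sentence into an FTS clause. "
    "Groups of nearby keywords become NEAR/n expressions."
)
for query in paragraph_to_custom_queries(text, top_n=5, proximity_n=10):
    print(query)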


def batch_paragraphs_to_custom_queries(
    paragraphs: list[str],
    top_n: int = 10,
    proximity_n: int = 10,
    min_keywords: int = 1,
    n_process: int = 1,
) -> list[list[str]]:
    """
    Processes multiple paragraphs using nlp.pipe for better performance.
    Args:
        paragraphs (list[str]): list of paragraphs to process.
        top_n (int): Number of top keywords to extract per sentence.
        proximity_n (int): The proximity window for NEAR/n.
        min_keywords (int): Minimum number of keywords required to form a query.
        n_process (int): Number of processes to use for multiprocessing.
    Returns:
        list[list[str]]: A list where each element is a list of queries for a paragraph.
    """
    results = []
    for doc in nlp.pipe(
        paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process
    ):
        queries = []
        for sent in doc.sents:
            sent_doc = sent.as_doc()
            keywords = extract_keywords(sent_doc, top_n)
            if len(keywords) < min_keywords:
                continue
            keyword_positions = keyword_matcher.find_matches(sent_doc, keywords)
            if not keyword_positions:
                continue
            groups = find_proximity_groups(keywords, keyword_positions, proximity_n)
            query = build_query(groups, proximity_n)
            if query:
                queries.append(query)
        results.append(queries)

    return results
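
And a matching sketch for the batched variant, which funnels all paragraphs through nlp.pipe; n_process > 1 turns on spaCy multiprocessing and is mainly worthwhile for larger batches (import path assumed as above):

# Hypothetical usage sketch for the batched API.
from agents_api.common.nlp import batch_paragraphs_to_custom_queries

paragraphs = [
    "First paragraph about keyword extraction and proximity grouping.",
    "Second paragraph about building FTS queries from sentences.",
]
per_paragraph = batch_paragraphs_to_custom_queries(
    paragraphs, top_n=5, proximity_n=10, min_keywords=1, n_process=1
)
for queries in per_paragraph:
    print(queries)  # one list of query strings per input paragraph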
