diff --git a/agents-api/agents_api/common/nlp.py b/agents-api/agents_api/common/nlp.py
index a2f2f17ea..89b06a2c1 100644
--- a/agents-api/agents_api/common/nlp.py
+++ b/agents-api/agents_api/common/nlp.py
@@ -1,221 +1,238 @@
 import re
 from collections import Counter, defaultdict
+from functools import lru_cache
+from typing import Dict, List, Set, Tuple
 
 import spacy
+from spacy.matcher import PhraseMatcher
+from spacy.tokens import Doc
+from spacy.util import filter_spans
 
-# Load spaCy English model
+# Precompile regex patterns
+WHITESPACE_RE = re.compile(r"\s+")
+NON_ALPHANUM_RE = re.compile(r"[^\w\s\-_]+")
+
+# Initialize spaCy with minimal pipeline
 spacy.prefer_gpu()
-nlp = spacy.load("en_core_web_sm")
+nlp = spacy.load(
+    "en_core_web_sm",
+    disable=["lemmatizer", "textcat", "vector"],  # Disable unused components
+)
+# Singleton PhraseMatcher for better performance
 
-def extract_keywords(text: str, top_n: int = 10, clean: bool = True) -> list[str]:
-    """
-    Extracts significant keywords and phrases from the text.
 
-    Args:
-        text (str): The input text to process.
-        top_n (int): Number of top keywords to extract based on frequency.
-        clean (bool): Strip non-alphanumeric characters from keywords.
+class KeywordMatcher:
+    _instance = None
 
-    Returns:
-        List[str]: A list of extracted keywords/phrases.
-    """
-    doc = nlp(text)
-
-    # Extract named entities
-    entities = [
-        ent.text.strip()
-        for ent in doc.ents
-        if ent.label_
-        not in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]
-    ]
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance.matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
+            cls._instance.batch_size = 1000
+            cls._instance.patterns_cache = {}
+        return cls._instance
 
-    # Extract nouns and proper nouns
-    nouns = [
-        chunk.text.strip().lower()
-        for chunk in doc.noun_chunks
-        if not chunk.root.is_stop
-    ]
+    @lru_cache(maxsize=10000)
+    def _create_pattern(self, text: str) -> Doc:
+        return nlp.make_doc(text)
 
-    # Combine entities and nouns
-    combined = entities + nouns
+    def find_matches(self, doc: Doc, keywords: List[str]) -> Dict[str, List[int]]:
+        """Batch process keywords for better performance."""
+        keyword_positions = defaultdict(list)
 
-    # Normalize and count frequency
-    normalized = [re.sub(r"\s+", " ", kw).strip() for kw in combined]
-    freq = Counter(normalized)
+        # Process keywords in batches to avoid memory issues
+        for i in range(0, len(keywords), self.batch_size):
+            batch = keywords[i : i + self.batch_size]
+            patterns = [self._create_pattern(kw) for kw in batch]
 
-    # Get top_n keywords
-    keywords = [item for item, count in freq.most_common(top_n)]
+            # Clear previous patterns and add new batch
+            if "KEYWORDS" in self.matcher:
+                self.matcher.remove("KEYWORDS")
+            self.matcher.add("KEYWORDS", patterns)
 
-    if clean:
-        keywords = [re.sub(r"[^\w\s\-_]+", "", kw) for kw in keywords]
+            # Find matches for this batch
+            matches = self.matcher(doc)
+            for match_id, start, end in matches:
+                span_text = doc[start:end].text
+                normalized = WHITESPACE_RE.sub(" ", span_text).lower().strip()
+                keyword_positions[normalized].append(start)
 
-    return keywords
+        return keyword_positions
 
 
-def find_keyword_positions(doc, keyword: str) -> list[int]:
-    """
-    Finds all start indices of the keyword in the tokenized doc.
+# Initialize global matcher
+keyword_matcher = KeywordMatcher()
 
-    Args:
-        doc (spacy.tokens.Doc): The tokenized document.
-        keyword (str): The keyword or phrase to search for.
 
-    Returns:
-        List[int]: List of starting token indices where the keyword appears.
-    """
-    keyword_tokens = keyword.split()
-    n = len(keyword_tokens)
-    positions = []
-    for i in range(len(doc) - n + 1):
-        window = doc[i : i + n]
-        window_text = " ".join([token.text.lower() for token in window])
-        if window_text == keyword:
-            positions.append(i)
-    return positions
+@lru_cache(maxsize=10000)
+def clean_keyword(kw: str) -> str:
+    """Cache cleaned keywords for reuse."""
+    return NON_ALPHANUM_RE.sub("", kw).strip()
 
 
-def find_proximity_groups(
-    text: str, keywords: list[str], n: int = 10
-) -> list[set[str]]:
-    """
-    Groups keywords that appear within n words of each other.
+def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> List[str]:
+    """Optimized keyword extraction with minimal behavior change."""
+    excluded_labels = {
+        "DATE",
+        "TIME",
+        "PERCENT",
+        "MONEY",
+        "QUANTITY",
+        "ORDINAL",
+        "CARDINAL",
+    }
 
-    Args:
-        text (str): The input text.
-        keywords (List[str]): List of keywords to consider.
-        n (int): The proximity window in words.
+    # Extract and filter spans in a single pass
+    ent_spans = [ent for ent in doc.ents if ent.label_ not in excluded_labels]
+    chunk_spans = [chunk for chunk in doc.noun_chunks if not chunk.root.is_stop]
+    all_spans = filter_spans(ent_spans + chunk_spans)
 
-    Returns:
-        List[Set[str]]: List of sets, each containing keywords that are proximate.
-    """
-    doc = nlp(text.lower())
-    keyword_positions = defaultdict(list)
+    # Process spans efficiently
+    keywords = []
+    seen_texts = set()
 
-    for kw in keywords:
-        positions = find_keyword_positions(doc, kw)
-        keyword_positions[kw].extend(positions)
-
-    # Initialize Union-Find structure
-    parent = {}
-
-    def find(u):
-        while parent[u] != u:
-            parent[u] = parent[parent[u]]
-            u = parent[u]
-        return u
-
-    def union(u, v):
-        u_root = find(u)
-        v_root = find(v)
-        if u_root == v_root:
-            return
-        parent[v_root] = u_root
-
-    # Initialize each keyword as its own parent
-    for kw in keywords:
-        parent[kw] = kw
-
-    # Compare all pairs of keywords
-    for i in range(len(keywords)):
-        for j in range(i + 1, len(keywords)):
-            kw1 = keywords[i]
-            kw2 = keywords[j]
-            positions1 = keyword_positions[kw1]
-            positions2 = keyword_positions[kw2]
-            # Check if any positions are within n words
-            for pos1 in positions1:
-                for pos2 in positions2:
-                    distance = abs(pos1 - pos2)
-                    if distance <= n:
-                        union(kw1, kw2)
-                        break
-                else:
-                    continue
-                break
-
-    # Group keywords by their root parent
+    for span in all_spans:
+        text = span.text.strip()
+        lower_text = text.lower()
+
+        # Skip empty or seen texts
+        if not text or lower_text in seen_texts:
+            continue
+
+        seen_texts.add(lower_text)
+        keywords.append(text)
+
+    # Normalize keywords by replacing multiple spaces with single space and stripping
+    normalized_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in keywords]
+
+    # Count frequencies efficiently
+    freq = Counter(normalized_keywords)
+    top_keywords = [kw for kw, _ in freq.most_common(top_n)]
+
+    if clean:
+        return [clean_keyword(kw) for kw in top_keywords]
+    return top_keywords
+
+
+def find_proximity_groups(
+    keywords: List[str], keyword_positions: Dict[str, List[int]], n: int = 10
+) -> List[Set[str]]:
+    """Optimized proximity grouping using sorted positions."""
+    # Early return for single or no keywords
+    if len(keywords) <= 1:
+        return [{kw} for kw in keywords]
+
+    # Create flat list of positions for efficient processing
+    positions: List[Tuple[int, str]] = [
+        (pos, kw) for kw in keywords for pos in keyword_positions[kw]
+    ]
+
+    # Sort positions once
+    positions.sort()
+
+    # Initialize Union-Find with path compression and union by rank
+    parent = {kw: kw for kw in keywords}
+    rank = {kw: 0 for kw in keywords}
+
+    def find(u: str) -> str:
+        if parent[u] != u:
+            parent[u] = find(parent[u])
+        return parent[u]
+
+    def union(u: str, v: str) -> None:
+        u_root, v_root = find(u), find(v)
+        if u_root != v_root:
+            if rank[u_root] < rank[v_root]:
+                u_root, v_root = v_root, u_root
+            parent[v_root] = u_root
+            if rank[u_root] == rank[v_root]:
+                rank[u_root] += 1
+
+    # Use sliding window for proximity checking
+    window = []
+    for pos, kw in positions:
+        # Remove positions outside window
+        while window and pos - window[0][0] > n:
+            window.pop(0)
+
+        # Union with all keywords in window
+        for _, w_kw in window:
+            union(kw, w_kw)
+
+        window.append((pos, kw))
+
+    # Group keywords efficiently
     groups = defaultdict(set)
     for kw in keywords:
         root = find(kw)
         groups[root].add(kw)
 
-    # Convert to list of sets
-    group_list = list(groups.values())
-
-    return group_list
+    return list(groups.values())
 
 
-def build_query(groups: list[set[str]], keywords: list[str], n: int = 10) -> str:
-    """
-    Builds a query string using the custom query language.
+@lru_cache(maxsize=100)
+def build_query_pattern(group_size: int, n: int) -> str:
+    """Cache query patterns for common group sizes."""
+    if group_size == 1:
+        return '"{}"'
+    return f"NEAR/{n}(" + " ".join('"{}"' for _ in range(group_size)) + ")"
 
-    Args:
-        groups (List[Set[str]]): List of keyword groups.
-        keywords (List[str]): Original list of keywords.
-        n (int): The proximity window for NEAR.
 
-    Returns:
-        str: The constructed query string.
-    """
-    grouped_keywords = set()
+def build_query(groups: List[Set[str]], n: int = 10) -> str:
+    """Build query with cached patterns."""
     clauses = []
 
     for group in groups:
         if len(group) == 1:
-            clauses.append(f'"{list(group)[0]}"')
+            clauses.append(f'"{next(iter(group))}"')
         else:
-            sorted_group = sorted(
-                group, key=lambda x: -len(x)
-            )  # Sort by length to prioritize phrases
-            escaped_keywords = [f'"{kw}"' for kw in sorted_group]
-            near_clause = f"NEAR/{n}(" + " ".join(escaped_keywords) + ")"
-            clauses.append(near_clause)
-            grouped_keywords.update(group)
-
-    # Identify keywords not in any group (if any)
-    remaining = set(keywords) - grouped_keywords
-    for kw in remaining:
-        clauses.append(f'"{kw}"')
-
-    # Combine all clauses with OR
-    query = " OR ".join(clauses)
+            # Sort by length descending to prioritize longer phrases
+            sorted_group = sorted(group, key=len, reverse=True)
+            # Get cached pattern and format with keywords
+            pattern = build_query_pattern(len(group), n)
+            clause = pattern.format(*sorted_group)
+            clauses.append(clause)
 
-    return query
+    return " OR ".join(clauses)
 
 
-def text_to_custom_query(text: str, top_n: int = 10, proximity_n: int = 10) -> str:
+def paragraph_to_custom_queries(
+    paragraph: str, top_n: int = 10, proximity_n: int = 10, min_keywords: int = 1
+) -> List[str]:
     """
-    Converts arbitrary text to the custom query language.
+    Optimized paragraph processing with minimal behavior changes.
+    Added min_keywords parameter to filter out low-value queries.
+    """
+    if not paragraph or not paragraph.strip():
+        return []
 
-    Args:
-        text (str): The input text to convert.
-        top_n (int): Number of top keywords to extract.
-        proximity_n (int): The proximity window for NEAR/n.
+    # Process entire paragraph once
+    doc = nlp(paragraph)
+    queries = []
 
-    Returns:
-        str: The custom query string.
- """ - keywords = extract_keywords(text, top_n) - if not keywords: - return "" - groups = find_proximity_groups(text, keywords, proximity_n) - query = build_query(groups, keywords, proximity_n) - return query + # Process sentences + for sent in doc.sents: + # Convert to doc for consistent API + sent_doc = sent.as_doc() + # Extract and clean keywords + keywords = extract_keywords(sent_doc, top_n) + if len(keywords) < min_keywords: + continue -def paragraph_to_custom_queries(paragraph: str) -> list[str]: - """ - Converts a paragraph to a list of custom query strings. + # Find keyword positions using matcher + keyword_positions = keyword_matcher.find_matches(sent_doc, keywords) - Args: - paragraph (str): The input paragraph to convert. + # Skip if no keywords found in positions + if not keyword_positions: + continue - Returns: - List[str]: The list of custom query strings. - """ + # Find proximity groups and build query + groups = find_proximity_groups(keywords, keyword_positions, proximity_n) + query = build_query(groups, proximity_n) - queries = [text_to_custom_query(sentence.text) for sentence in nlp(paragraph).sents] - queries = [q for q in queries if q] + if query: + queries.append(query) return queries