feat(agents-api): Tweak queries for search #685

Merged · 4 commits · Oct 17, 2024
216 changes: 216 additions & 0 deletions agents-api/agents_api/common/nlp.py
@@ -0,0 +1,216 @@
import re
from collections import Counter, defaultdict

import spacy

# Load spaCy English model
spacy.prefer_gpu()
nlp = spacy.load("en_core_web_sm")


def extract_keywords(text: str, top_n: int = 10) -> list[str]:
"""
Extracts significant keywords and phrases from the text.

Args:
text (str): The input text to process.
top_n (int): Number of top keywords to extract based on frequency.

Returns:
List[str]: A list of extracted keywords/phrases.
"""
doc = nlp(text)

# Extract named entities
entities = [
ent.text.strip()
for ent in doc.ents
if ent.label_
not in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]
]

# Extract nouns and proper nouns
nouns = [
chunk.text.strip().lower()
for chunk in doc.noun_chunks
if not chunk.root.is_stop
]

# Combine entities and nouns
combined = entities + nouns

# Normalize and count frequency
normalized = [re.sub(r"\s+", " ", kw).strip().lower() for kw in combined]
freq = Counter(normalized)

# Get top_n keywords
keywords = [item for item, count in freq.most_common(top_n)]

return keywords
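# Illustrative example (assumes the small "en_core_web_sm" model; the exact output
# depends on the spaCy release, since entity and noun-chunk boundaries can shift):
#
#     extract_keywords("Acme Corp launched a new search engine in Berlin.", top_n=3)
#     # -> e.g. ["acme corp", "berlin", "a new search engine"]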


def find_keyword_positions(doc, keyword: str) -> list[int]:
"""
Finds all start indices of the keyword in the tokenized doc.

Args:
doc (spacy.tokens.Doc): The tokenized document.
keyword (str): The keyword or phrase to search for.

Returns:
List[int]: List of starting token indices where the keyword appears.
"""
keyword_tokens = keyword.split()
n = len(keyword_tokens)
positions = []
for i in range(len(doc) - n + 1):
window = doc[i : i + n]
window_text = " ".join([token.text.lower() for token in window])
if window_text == keyword:
positions.append(i)
return positions
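# Illustrative example: for a doc tokenized from "acme corp hired acme corp" and
# keyword "acme corp", the returned positions are [0, 3]. Matching is token-based,
# so phrases that spaCy tokenizes differently from str.split() may not be found.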


def find_proximity_groups(
text: str, keywords: list[str], n: int = 10
) -> list[set[str]]:
"""
Groups keywords that appear within n words of each other.

Args:
text (str): The input text.
keywords (List[str]): List of keywords to consider.
n (int): The proximity window in words.

Returns:
List[Set[str]]: List of sets, each containing keywords that are proximate.
"""
doc = nlp(text.lower())
keyword_positions = defaultdict(list)

for kw in keywords:
positions = find_keyword_positions(doc, kw)
keyword_positions[kw].extend(positions)

# Initialize Union-Find structure
parent = {}

def find(u):
while parent[u] != u:
parent[u] = parent[parent[u]]
u = parent[u]
return u

def union(u, v):
u_root = find(u)
v_root = find(v)
if u_root == v_root:
return
parent[v_root] = u_root

# Initialize each keyword as its own parent
for kw in keywords:
parent[kw] = kw

# Compare all pairs of keywords
for i in range(len(keywords)):
for j in range(i + 1, len(keywords)):
kw1 = keywords[i]
kw2 = keywords[j]
positions1 = keyword_positions[kw1]
positions2 = keyword_positions[kw2]
# Check if any positions are within n words
for pos1 in positions1:
for pos2 in positions2:
distance = abs(pos1 - pos2)
if distance <= n:
union(kw1, kw2)
break
else:
continue
break

# Group keywords by their root parent
groups = defaultdict(set)
for kw in keywords:
root = find(kw)
groups[root].add(kw)

# Convert to list of sets
group_list = list(groups.values())

return group_list
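# Illustrative example: with keywords ["acme corp", "berlin", "search engine"] and a
# text where only "acme corp" and "berlin" occur within 10 tokens of each other, the
# result is roughly [{"acme corp", "berlin"}, {"search engine"}]; keywords that never
# co-occur inside the window remain in singleton sets.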


def build_query(groups: list[set[str]], keywords: list[str], n: int = 10) -> str:
"""
Builds a query string using the custom query language.

Args:
groups (List[Set[str]]): List of keyword groups.
keywords (List[str]): Original list of keywords.
n (int): The proximity window for NEAR.

Returns:
str: The constructed query string.
"""
grouped_keywords = set()
clauses = []

for group in groups:
if len(group) == 1:
clauses.append(f'"{list(group)[0]}"')
else:
sorted_group = sorted(
group, key=lambda x: -len(x)
) # Sort by length to prioritize phrases
escaped_keywords = [f'"{kw}"' for kw in sorted_group]
near_clause = f"NEAR/{n}(" + " ".join(escaped_keywords) + ")"
clauses.append(near_clause)
grouped_keywords.update(group)

# Identify keywords not in any group (if any)
remaining = set(keywords) - grouped_keywords
for kw in remaining:
clauses.append(f'"{kw}"')

# Combine all clauses with OR
query = " OR ".join(clauses)

return query
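# Illustrative example: for groups [{"acme corp", "berlin"}, {"search engine"}] and
# n=10, the constructed query is roughly
#
#     NEAR/10("acme corp" "berlin") OR "search engine"
#
# i.e. proximate keywords are grouped inside a NEAR/n clause and the clauses are
# OR-ed together, following the Cozo FTS syntax referenced in search_docs_by_text.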


def text_to_custom_query(text: str, top_n: int = 10, proximity_n: int = 10) -> str:
"""
Converts arbitrary text to the custom query language.

Args:
text (str): The input text to convert.
top_n (int): Number of top keywords to extract.
proximity_n (int): The proximity window for NEAR/n.

Returns:
str: The custom query string.
"""
keywords = extract_keywords(text, top_n)
if not keywords:
return ""
groups = find_proximity_groups(text, keywords, proximity_n)
query = build_query(groups, keywords, proximity_n)
return query


def paragraph_to_custom_queries(paragraph: str) -> list[str]:
"""
Converts a paragraph to a list of custom query strings.

Args:
paragraph (str): The input paragraph to convert.

Returns:
List[str]: The list of custom query strings.
"""

queries = [text_to_custom_query(sentence.text) for sentence in nlp(paragraph).sents]

return queries
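A minimal end-to-end sketch of how this module is meant to be used (the paragraph and the printed queries are illustrative; the actual keywords depend on the installed spaCy model):

    from agents_api.common.nlp import paragraph_to_custom_queries

    paragraph = (
        "Acme Corp launched a new search engine in Berlin. "
        "The product indexes documents with full-text search."
    )

    # One FTS query string per sentence, e.g.:
    #   NEAR/10("acme corp" "berlin") OR "a new search engine"
    #   "full-text search" OR "the product" OR "documents"
    for fts_query in paragraph_to_custom_queries(paragraph):
        print(fts_query)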
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/docs/search_docs_by_embedding.py
@@ -47,8 +47,8 @@ def search_docs_by_embedding(
owners: list[tuple[Literal["user", "agent"], UUID]],
query_embedding: list[float],
k: int = 3,
-confidence: float = 0.7,
-ef: int = 128,
+confidence: float = 0.5,
+ef: int = 32,
mmr_lambda: float = 0.25,
embedding_size: int = 1024,
) -> tuple[list[str], dict]:
9 changes: 5 additions & 4 deletions agents-api/agents_api/models/docs/search_docs_by_text.py
@@ -1,6 +1,5 @@
"""This module contains functions for searching documents in the CozoDB based on embedding queries."""

-import json
from typing import Any, Literal, TypeVar
from uuid import UUID

@@ -10,6 +9,7 @@
from pydantic import ValidationError

from ...autogen.openapi_model import DocReference
+from ...common.nlp import paragraph_to_custom_queries
from ..utils import (
cozo_query,
partialclass,
@@ -64,7 +64,7 @@ def search_docs_by_text(

# Convert the query text into per-sentence FTS queries that use NEAR/n to match keywords within n words of each other
# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
query = f"NEAR/3({json.dumps(query)})"
fts_queries = paragraph_to_custom_queries(query)

# Construct the datalog query for searching document snippets
search_query = f"""
@@ -112,11 +112,12 @@ def search_docs_by_text(
index,
content
|
-query: $query,
+query: query,
k: {k},
score_kind: 'tf_idf',
bind_score: score,
}},
+query in $fts_queries,
distance = -score,
snippet_data = [index, content]

@@ -183,5 +184,5 @@ def search_docs_by_text(

return (
queries,
{"owners": owners, "query": query},
{"owners": owners, "query": query, "fts_queries": fts_queries},
)
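To make the new flow concrete, a hedged sketch of what the function now hands to Cozo as parameters (the query text, the owner tuple, and the generated FTS string are illustrative; the real values come from the caller and from paragraph_to_custom_queries):

    from uuid import uuid4

    from agents_api.common.nlp import paragraph_to_custom_queries

    owners = [("user", uuid4())]  # hypothetical owner filter, mirroring the function signature
    query = "Tell me about the search engine Acme Corp launched in Berlin."
    fts_queries = paragraph_to_custom_queries(query)
    # e.g. ['NEAR/10("acme corp" "berlin") OR "the search engine"']

    # Returned as the parameter dict for the Cozo client; the datalog clause
    # `query in $fts_queries` binds one generated query at a time, so roughly one
    # ~snippets:fts lookup runs per sentence and the rule unions the results.
    params = {"owners": owners, "query": query, "fts_queries": fts_queries}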
@@ -0,0 +1,92 @@
#!/usr/bin/env python3

MIGRATION_ID = "tweak_proximity_indices"
CREATED_AT = 1729114011.022733


def run(client, *queries):
joiner = "}\n\n{"

query = joiner.join(queries)
query = f"{{\n{query}\n}}"
client.run(query)


drop_snippets_lsh_index = dict(
up="""
::lsh drop snippets:lsh
""",
down="""
::lsh create snippets:lsh {
extractor: content,
tokenizer: Simple,
filters: [Stopwords('en')],
n_perm: 200,
target_threshold: 0.9,
n_gram: 3,
false_positive_weight: 1.0,
false_negative_weight: 1.0,
}
""",
)

snippets_lsh_index = dict(
up="""
::lsh create snippets:lsh {
extractor: content,
tokenizer: Simple,
filters: [Lowercase, AsciiFolding, Stemmer('english'), Stopwords('en')],
n_perm: 200,
target_threshold: 0.5,
n_gram: 2,
false_positive_weight: 1.0,
false_negative_weight: 1.0,
}
""",
down="""
::lsh drop snippets:lsh
""",
)

# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
drop_snippets_fts_index = dict(
down="""
::fts create snippets:fts {
extractor: content,
tokenizer: Simple,
filters: [Lowercase, Stemmer('english'), Stopwords('en')],
}
""",
up="""
::fts drop snippets:fts
""",
)

# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
snippets_fts_index = dict(
up="""
::fts create snippets:fts {
extractor: content,
tokenizer: Simple,
filters: [Lowercase, AsciiFolding, Stemmer('english'), Stopwords('en')],
}
""",
down="""
::fts drop snippets:fts
""",
)

queries_to_run = [
drop_snippets_lsh_index,
drop_snippets_fts_index,
snippets_lsh_index,
snippets_fts_index,
]


def up(client):
run(client, *[q["up"] for q in queries_to_run])


def down(client):
run(client, *[q["down"] for q in reversed(queries_to_run)])