fix(cli): Improve cli wrapper and fix init command behavior #1214

Draft · wants to merge 5 commits into base: dev
26 changes: 14 additions & 12 deletions agents-api/agents_api/common/nlp.py
@@ -43,14 +43,15 @@ def clean_keyword(kw: str) -> str:

def extract_keywords(doc: Doc, top_n: int = 25, split_chunks: bool = True) -> list[str]:
"""Optimized keyword extraction with minimal behavior change."""

excluded_labels = {
"DATE", # Absolute or relative dates or periods.
"TIME", # Times smaller than a day.
"PERCENT", # Percentage, including ”%“.
"MONEY", # Monetary values, including unit.
"QUANTITY", # Measurements, as of weight or distance.
"ORDINAL", # “first”, “second”, etc.
"CARDINAL", # Numerals that do not fall under another type.
# "DATE", # Absolute or relative dates or periods.
# "MONEY", # Monetary values, including unit.
# "PERSON", # People, including fictional.
# "NORP", # Nationalities or religious or political groups.
# "FAC", # Buildings, airports, highways, bridges, etc.
@@ -66,6 +67,7 @@ def extract_keywords(doc: Doc, top_n: int = 25, split_chunks: bool = True) -> list[str]:

# Extract and filter spans in a single pass
ent_spans = [ent for ent in doc.ents if ent.label_ not in excluded_labels]

# Add more comprehensive stopword filtering for noun chunks
chunk_spans = [
chunk
@@ -116,18 +118,18 @@ def extract_keywords(doc: Doc, top_n: int = 25, split_chunks: bool = True) -> list[str]:


@lru_cache(maxsize=1000)
def text_to_tsvector_query(
def text_to_keywords(
paragraph: str,
top_n: int = 25,
min_keywords: int = 1,
split_chunks: bool = True,
) -> str:
) -> set[str]:
"""
Extracts meaningful keywords/phrases from text and joins them with OR.
Extracts meaningful keywords/phrases from text.

Example:
Input: "I like basketball especially Michael Jordan"
Output: "basketball OR Michael Jordan"
Output: {"basketball", "Michael Jordan"}

Args:
paragraph (str): The input text to process
@@ -136,26 +138,26 @@ def text_to_tsvector_query(
split_chunks (bool): If True, breaks multi-word noun chunks into individual words

Returns:
str: Keywords/phrases joined by OR
set[str]: Set of keywords/phrases
"""
if not paragraph or not paragraph.strip():
return ""
return set()

doc = nlp(paragraph)
queries = set() # Use set to avoid duplicates
all_keywords = set() # Use set to avoid duplicates

for sent in doc.sents:
sent_doc = sent.as_doc()

# Extract keywords
keywords = extract_keywords(sent_doc, top_n, split_chunks=split_chunks)
keywords = [kw for kw in keywords if len(kw) > 1]
if len(keywords) < min_keywords:
continue
Comment on lines 153 to 156
Filtering keywords with len(kw) > 1 after extraction could return fewer keywords than min_keywords, making the minimum check ineffective. Move the length check before the minimum check.

Suggested change
keywords = extract_keywords(sent_doc, top_n, split_chunks=split_chunks)
keywords = [kw for kw in keywords if len(kw) > 1]
if len(keywords) < min_keywords:
continue
keywords = [kw for kw in extract_keywords(sent_doc, top_n, split_chunks=split_chunks) if len(kw) > 1]
if len(keywords) < min_keywords:
continue


queries.update(keywords)
all_keywords.update(keywords)

# Join all terms with " OR "
return " OR ".join(queries) if queries else ""
return all_keywords


# def batch_text_to_tsvector_queries(
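As a side note for reviewers, a minimal usage sketch (not part of the diff) of the renamed helper: text_to_keywords now returns a set of keywords rather than an OR-joined string, so callers that still need the old tsquery format rebuild it themselves, as the search queries below now do. Exact keywords depend on the loaded spaCy model.

from agents_api.common.nlp import text_to_keywords

# Returns a set, e.g. {"basketball", "Michael Jordan"} per the docstring example above.
keywords = text_to_keywords("I like basketball especially Michael Jordan", split_chunks=True)

# Callers that previously relied on text_to_tsvector_query() do the join themselves.
tsquery = " OR ".join(keywords) if keywords else ""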
4 changes: 4 additions & 0 deletions agents-api/agents_api/common/utils/get_doc_search.py
@@ -45,6 +45,8 @@ def get_language(lang: str) -> str:

def get_search_fn_and_params(
search_params,
*,
extract_keywords: bool = False,
) -> tuple[Any, dict[str, float | int | str | dict[str, float] | list[float]] | None]:
search_fn, params = None, None

@@ -63,6 +65,7 @@ def get_search_fn_and_params(
"k": k,
"metadata_filter": metadata_filter,
"search_language": search_language,
"extract_keywords": extract_keywords,
}

case VectorDocSearchRequest(
@@ -99,6 +102,7 @@
"alpha": alpha,
"metadata_filter": metadata_filter,
"search_language": search_language,
"extract_keywords": extract_keywords,
}

# Note: connection_pool will be passed separately by the caller
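For context, a hedged sketch (not from the PR) of how the new keyword-only flag is used; search_params stands for one of the text or hybrid request models already handled by the match arms above, and the variable names are illustrative.

# extract_keywords defaults to False and is forwarded unchanged into the params dict
# built for the text and hybrid search branches.
search_fn, params = get_search_fn_and_params(search_params, extract_keywords=True)
assert params["extract_keywords"] is True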
8 changes: 6 additions & 2 deletions agents-api/agents_api/queries/chat/gather_messages.py
@@ -145,8 +145,12 @@ async def gather_messages(
# Invalid mode, return early
return past_messages, []

# Execute search
search_fn, params = get_search_fn_and_params(search_params)
# Execute search (extract keywords for FTS because the query is a conversation snippet)
extract_keywords: bool = True
search_fn, params = get_search_fn_and_params(
search_params, extract_keywords=extract_keywords
)

doc_references: list[DocReference] = await search_fn(
developer_id=developer.id,
owners=owners,
36 changes: 22 additions & 14 deletions agents-api/agents_api/queries/docs/search_docs_by_text.py
@@ -5,22 +5,23 @@
from fastapi import HTTPException

from ...autogen.openapi_model import DocReference
from ...common.nlp import text_to_tsvector_query
from ...common.nlp import text_to_keywords
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
from .utils import transform_to_doc_reference

# Raw query for text search
# SQL query for searching docs by text
search_docs_text_query = """
SELECT * FROM search_by_text(
$1, -- developer_id
$2, -- query
$3, -- owner_types
$4, -- owner_ids
$5, -- search_language
$6, -- k
$7 -- metadata_filter
)
$1::uuid, -- developer_id
$2::text, -- query_text
$3::text[], -- owner_types
$4::uuid[], -- owner_ids
$5::text, -- search_language
$6::int, -- k
$7::jsonb, -- metadata_filter
$8::float -- similarity_threshold (default value)
);
"""


@@ -38,7 +39,9 @@ async def search_docs_by_text(
query: str,
k: int = 3,
metadata_filter: dict[str, Any] = {},
search_language: str | None = "english",
search_language: str | None = "english_unaccent",
trigram_similarity_threshold: float = 0.3,
extract_keywords: bool = False,
) -> tuple[str, list]:
"""
Full-text search on docs using the search_tsv column.
@@ -55,14 +58,18 @@
Returns:
tuple[str, list]: SQL query and parameters for searching the documents.
"""

if k < 1:
raise HTTPException(status_code=400, detail="k must be >= 1")

# Extract owner types and IDs
owner_types: list[str] = [owner[0] for owner in owners]
owner_ids: list[str] = [str(owner[1]) for owner in owners]
# Pre-process rawtext query
query = text_to_tsvector_query(query, split_chunks=True)
owner_ids: list[UUID] = [owner[1] for owner in owners]

# Pre-process rawtext query if extract_keywords is True
if extract_keywords:
keywords = text_to_keywords(query, split_chunks=True)
query = " OR ".join(keywords)

return (
search_docs_text_query,
Expand All @@ -74,5 +81,6 @@ async def search_docs_by_text(
search_language,
k,
metadata_filter,
trigram_similarity_threshold,
],
)
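For readability, an illustrative binding of the casted placeholders to the parameter list this helper returns, assuming default arguments; developer_id and agent_id are hypothetical values, and the query string is only OR-joined when extract_keywords is True.

params = [
    developer_id,                       # $1::uuid
    "Michael Jordan OR basketball",     # $2::text   query (raw string if extract_keywords=False)
    ["agent"],                          # $3::text[] owner_types
    [agent_id],                         # $4::uuid[] owner_ids
    "english_unaccent",                 # $5::text   new default search_language
    3,                                  # $6::int    k
    {},                                 # $7::jsonb  metadata_filter
    0.3,                                # $8::float  trigram_similarity_threshold
]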
14 changes: 10 additions & 4 deletions agents-api/agents_api/queries/docs/search_docs_hybrid.py
@@ -5,7 +5,7 @@
from fastapi import HTTPException

from ...autogen.openapi_model import DocReference
from ...common.nlp import text_to_tsvector_query
from ...common.nlp import text_to_keywords
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import (
pg_query,
@@ -26,7 +26,8 @@
$7, -- alpha
$8, -- confidence
$9, -- metadata_filter
$10 -- search_language
$10, -- search_language
$11 -- trigram_similarity_threshold
)
"""

@@ -48,6 +49,8 @@ async def search_docs_hybrid(
metadata_filter: dict[str, Any] = {},
search_language: str = "english",
confidence: int | float = 0.5,
trigram_similarity_threshold: float = 0.3,
extract_keywords: bool = False,
) -> tuple[str, list]:
"""
Hybrid text-and-embedding doc search. We get top-K from each approach,
@@ -82,8 +85,10 @@
owner_types: list[str] = [owner[0] for owner in owners]
owner_ids: list[str] = [str(owner[1]) for owner in owners]

# Pre-process rawtext query
text_query = text_to_tsvector_query(text_query, split_chunks=True)
# Pre-process rawtext query if extract_keywords is True
if extract_keywords:
keywords = text_to_keywords(text_query, split_chunks=True)
text_query = " OR ".join(keywords)

return (
search_docs_hybrid_query,
Expand All @@ -98,5 +103,6 @@ async def search_docs_hybrid(
confidence,
metadata_filter,
search_language,
trigram_similarity_threshold,
],
)
3 changes: 3 additions & 0 deletions agents-api/tests/test_get_doc_search.py
@@ -63,6 +63,7 @@ def _():
"k": 10,
"metadata_filter": {"field": "value"},
"search_language": "english",
"extract_keywords": False,
}


@@ -132,6 +133,7 @@ def _():
"alpha": 0.5,
"metadata_filter": {"field": "value"},
"search_language": "english",
"extract_keywords": False,
}


@@ -159,6 +161,7 @@ def _():
"alpha": 0.5,
"metadata_filter": {"field": "value"},
"search_language": "english",
"extract_keywords": False,
}

