Commit 25a2e65

refactor: Lint agents-api (CI)

Vedantsahai18 authored and github-actions[bot] committed Jan 13, 2025
1 parent 1b02a79 commit 25a2e65

Showing 4 changed files with 49 additions and 67 deletions.
15 changes: 8 additions & 7 deletions agents-api/agents_api/common/nlp.py
@@ -96,7 +96,8 @@ def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]
     ent_spans = [ent for ent in doc.ents if ent.label_ not in excluded_labels]
     # Add more comprehensive stopword filtering for noun chunks
     chunk_spans = [
-        chunk for chunk in doc.noun_chunks
+        chunk
+        for chunk in doc.noun_chunks
         if not chunk.root.is_stop and not all(token.is_stop for token in chunk)
     ]
     all_spans = filter_spans(ent_spans + chunk_spans)
@@ -109,7 +110,7 @@ def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]
         # Skip if all tokens in span are stopwords
         if all(token.is_stop for token in span):
             continue
-
+
         text = span.text.strip()
         lower_text = text.lower()
 
@@ -194,7 +195,7 @@ def text_to_tsvector_query(
 ) -> str:
     """
     Extracts meaningful keywords/phrases from text and joins them with OR.
-
+
     Example:
     Input: "I like basketball especially Michael Jordan"
     Output: "basketball OR Michael Jordan"
@@ -216,7 +217,7 @@ def text_to_tsvector_query(
 
     for sent in doc.sents:
         sent_doc = sent.as_doc()
-
+
         # Extract keywords
         keywords = extract_keywords(sent_doc, top_n)
         if len(keywords) < min_keywords:
@@ -235,7 +236,7 @@ def text_to_tsvector_query(
         if len(group) > 1:
             # Sort by length descending to prioritize longer phrases
             sorted_group = sorted(group, key=len, reverse=True)
-                # For truly proximate multi-word groups, group words
+            # For truly proximate multi-word groups, group words
             queries.add(" OR ".join(sorted_group))
         else:
             # For non-proximate words or single words, add them separately
@@ -265,7 +266,7 @@ def batch_text_to_tsvector_queries(
     results = []
 
     for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process):
-        queries = set() # Use set to avoid duplicates
+        queries = set()  # Use set to avoid duplicates
        for sent in doc.sents:
             sent_doc = sent.as_doc()
             keywords = extract_keywords(sent_doc, top_n)
@@ -280,7 +281,7 @@ def batch_text_to_tsvector_queries(
         if len(group) > 1:
             # Sort by length descending to prioritize longer phrases
             sorted_group = sorted(group, key=len, reverse=True)
-                # For truly proximate multi-word groups, group words
+            # For truly proximate multi-word groups, group words
             queries.add(" OR ".join(sorted_group))
         else:
             # For non-proximate words or single words, add them separately
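For orientation, a minimal usage sketch of the two helpers touched above. The full signatures are not visible in this diff, so the positional argument to batch_text_to_tsvector_queries and its n_process keyword are assumptions inferred from the hunk bodies, not confirmed API:

    from agents_api.common.nlp import batch_text_to_tsvector_queries, text_to_tsvector_query

    # One text in -> one OR-joined tsquery string out (per the docstring above)
    print(text_to_tsvector_query("I like basketball especially Michael Jordan"))
    # expected: "basketball OR Michael Jordan"

    # Batch variant: spaCy's nlp.pipe spreads the documents across worker processes
    print(batch_text_to_tsvector_queries(
        [
            "Machine learning is great. Data science rocks.",
            "REST architecture defines API endpoints.",
        ],
        n_process=2,  # assumed keyword, mirroring the n_process seen in the hunk above
    ))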
1 change: 0 additions & 1 deletion agents-api/agents_api/queries/docs/search_docs_by_text.py
@@ -7,7 +7,6 @@
 from ...autogen.openapi_model import DocReference
 from ...common.utils.db_exceptions import common_db_exceptions
 from ..utils import pg_query, rewrap_exceptions, wrap_in_class
-from ...common.nlp import text_to_tsvector_query
 from .utils import transform_to_doc_reference
 
 # Raw query for text search
4 changes: 2 additions & 2 deletions agents-api/tests/fixtures.py
@@ -187,15 +187,15 @@ async def test_doc_with_embedding(dsn=pg_dsn, developer=test_developer, doc=test
f"[{', '.join([str(0.3 + 0.4 * (i % 3) / 2) for i in range(1024)])}]",
)

# Insert embedding with random values between -0.8 and 0.8
# Insert embedding with random values between -0.8 and 0.8
await pool.execute(
"""
INSERT INTO docs_embeddings_store (developer_id, doc_id, index, chunk_seq, chunk, embedding)
VALUES ($1, $2, 0, 2, $3, $4)
""",
developer.id,
doc.id,
"Test content 2",
"Test content 2",
f"[{', '.join([str(-0.8 + 1.6 * (i % 5) / 4) for i in range(1024)])}]",
)

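A side note on the fixture above: despite the comment, the embedding values are deterministic, not random. Each generator expression cycles through a small fixed set of values, which a short check makes plain:

    # The two fixture expressions cycle through fixed, evenly spaced values
    first = sorted({round(0.3 + 0.4 * (i % 3) / 2, 6) for i in range(1024)})
    second = sorted({round(-0.8 + 1.6 * (i % 5) / 4, 6) for i in range(1024)})
    print(first)   # [0.3, 0.5, 0.7]
    print(second)  # [-0.8, -0.4, 0.0, 0.4, 0.8]

So the second embedding spans [-0.8, 0.8] in steps of 0.4, matching the range the comment states.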
96 changes: 39 additions & 57 deletions agents-api/tests/test_docs_queries.py
@@ -1,5 +1,6 @@
 from agents_api.autogen.openapi_model import CreateDocRequest
 from agents_api.clients.pg import create_db_pool
+from agents_api.common.nlp import text_to_tsvector_query
 from agents_api.queries.docs.create_doc import create_doc
 from agents_api.queries.docs.delete_doc import delete_doc
 from agents_api.queries.docs.get_doc import get_doc
@@ -9,8 +10,6 @@
 from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid
 from ward import test
 
-from agents_api.common.nlp import text_to_tsvector_query
-
 from .fixtures import (
     pg_dsn,
     test_agent,
@@ -24,48 +23,50 @@
 
 import math
 
+
 def make_vector_with_similarity(n: int, d: float):
     """
     Returns a list `v` of length `n` such that the cosine similarity
     between `v` and the all-ones vector of length `n` is approximately d.
     """
     if not -1.0 <= d <= 1.0:
-        raise ValueError("d must lie in [-1, 1].")
+        msg = "d must lie in [-1, 1]."
+        raise ValueError(msg)
 
     # Handle special cases exactly:
     if abs(d - 1.0) < 1e-12:  # d ~ +1
         return [1.0] * n
     if abs(d + 1.0) < 1e-12:  # d ~ -1
         return [-1.0] * n
-    if abs(d) < 1e-12: # d ~ 0
-        v = [0.0]*n
+    if abs(d) < 1e-12:  # d ~ 0
+        v = [0.0] * n
         if n >= 2:
             v[0] = 1.0
             v[1] = -1.0
         return v
 
     sign_d = 1.0 if d >= 0 else -1.0
 
     # Base part: sign(d)*[1,1,...,1]
-    base = [sign_d]*n
+    base = [sign_d] * n
 
     # Orthogonal unit vector u with sum(u)=0; for simplicity:
     # u = [1/sqrt(2), -1/sqrt(2), 0, 0, ..., 0]
-    u = [0.0]*n
+    u = [0.0] * n
     if n >= 2:
         u[0] = 1.0 / math.sqrt(2)
         u[1] = -1.0 / math.sqrt(2)
     # (if n=1, there's no truly orthogonal vector to [1], so skip)
 
     # Solve for alpha:
     # alpha^2 = n*(1 - d^2)/d^2
-    alpha = math.sqrt(n*(1 - d*d)) / abs(d)
+    alpha = math.sqrt(n * (1 - d * d)) / abs(d)
 
     # Construct v
-    v = [0.0]*n
+    v = [0.0] * n
     for i in range(n):
         v[i] = base[i] + alpha * u[i]
 
     return v
 
 
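Why the alpha formula in this hunk works: v = sign(d) * ones + alpha * u, where u is a unit vector with sum(u) = 0, so v . ones = sign(d) * n and |v|^2 = n + alpha^2. Hence cos(v, ones) = sign(d) * sqrt(n / (n + alpha^2)); setting that equal to d and solving gives alpha^2 = n * (1 - d^2) / d^2, exactly the comment above. A quick numeric check, as a sketch that assumes make_vector_with_similarity from this hunk is in scope:

    import math

    def cosine(a, b):
        # Plain cosine similarity between two equal-length vectors
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(y * y for y in b))
        return dot / (norm_a * norm_b)

    for n, d in [(8, 0.5), (1024, -0.3), (16, 0.99)]:
        v = make_vector_with_similarity(n, d)
        assert abs(cosine(v, [1.0] * n) - d) < 1e-6, (n, d)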
@@ -304,6 +305,7 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
     assert any(d.id == doc.id for d in result), f"Should find document {doc.id}"
     assert result[0].metadata == {"test": "test"}, "Metadata should match"
 
+
 @test("query: search docs by text with technical terms and phrases")
 async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
     pool = await create_db_pool(dsn=dsn)
@@ -340,7 +342,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
         "API endpoints",
         "REST architecture",
         "database optimization",
-        "indexing"
+        "indexing",
     ]
 
     for query in technical_queries:
@@ -357,9 +359,14 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
 
         # Verify appropriate document is found based on query
         if "API" in query or "REST" in query:
-            assert any(doc.id == doc1.id for doc in results), f"Doc1 should be found with query '{query}'"
+            assert any(doc.id == doc1.id for doc in results), (
+                f"Doc1 should be found with query '{query}'"
+            )
         if "database" in query.lower() or "indexing" in query:
-            assert any(doc.id == doc2.id for doc in results), f"Doc2 should be found with query '{query}'"
+            assert any(doc.id == doc2.id for doc in results), (
+                f"Doc2 should be found with query '{query}'"
+            )
 
+
 @test("query: search docs by embedding")
 async def _(
@@ -409,84 +416,59 @@ async def _(
     assert len(result) >= 1
     assert result[0].metadata is not None
 
+
 @test("utility: test text_to_tsvector_query")
 async def _():
     test_cases = [
         # Single words
-        (
-            "test",
-            "test"
-        ),
-
+        ("test", "test"),
         # Multiple words in single sentence
         (
             "quick brown fox",
-            "quick brown fox" # Now kept as a single phrase due to proximity
+            "quick brown fox",  # Now kept as a single phrase due to proximity
         ),
-
         # Technical terms and phrases
         (
             "Machine Learning algorithm",
-            "machine learning algorithm" # Common technical phrase
+            "machine learning algorithm",  # Common technical phrase
         ),
         # Multiple sentences
         (
             "Machine learning is great. Data science rocks.",
-            "machine learning OR data science rocks"
+            "machine learning OR data science rocks",
         ),
-
         # Quoted phrases
         (
             '"quick brown fox"',
-            "quick brown fox" # Quotes removed, phrase kept together
+            "quick brown fox",  # Quotes removed, phrase kept together
         ),
-        (
-            'Find "machine learning" algorithms',
-            "machine learning"
-        ),
-
+        ('Find "machine learning" algorithms', "machine learning"),
         # Multiple quoted phrases
-        (
-            '"data science" and "machine learning"',
-            "machine learning OR data science"
-        ),
-
+        ('"data science" and "machine learning"', "machine learning OR data science"),
         # Edge cases
-        (
-            "",
-            ""
-        ),
+        ("", ""),
         (
             "the and or",
-            "" # All stop words should result in empty string
+            "",  # All stop words should result in empty string
         ),
         (
             "a",
-            "" # Single stop word should result in empty string
+            "",  # Single stop word should result in empty string
        ),
-        (
-            "X",
-            "X"
-        ),
-
+        ("X", "X"),
         # Empty quotes
-        (
-            '""',
-            ""
-        ),
-        (
-            'test "" phrase',
-            "phrase OR test"
-        ),
+        ('""', ""),
+        ('test "" phrase', "phrase OR test"),
     ]
 
     for input_text, expected_output in test_cases:
         print(f"Input: '{input_text}'")
         result = text_to_tsvector_query(input_text)
         print(f"Generated query: '{result}'")
         print(f"Expected: '{expected_output}'\n")
-        assert result.lower() == expected_output.lower(), \
+        assert result.lower() == expected_output.lower(), (
             f"Expected '{expected_output}' but got '{result}' for input '{input_text}'"
+        )
 
 
 # @test("query: search docs by embedding with different confidence levels")
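A quick spot-check of the expectations encoded in the test cases above, as a sketch that assumes the package is importable in the test environment:

    from agents_api.common.nlp import text_to_tsvector_query

    for text in ['"data science" and "machine learning"', "the and or", 'test "" phrase']:
        print(repr(text), "->", repr(text_to_tsvector_query(text)))
    # expected per the test: 'machine learning OR data science', '', 'phrase OR test'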
