Skip to content

Commit

Permalink
fix: embedding search confidence fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Jan 9, 2025
1 parent 131f594 commit 15e9097
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ async def search_docs_by_embedding(
owner_types,
owner_ids,
k,
1.0 - confidence,
confidence,
metadata_filter,
],
)

2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/docs/search_docs_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def search_docs_hybrid(
owner_ids,
k,
alpha,
1.0 - confidence,
confidence,
metadata_filter,
search_language,
],
Expand Down
87 changes: 86 additions & 1 deletion agents-api/tests/test_docs_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,51 @@

EMBEDDING_SIZE: int = 1024

import math

def make_vector_with_similarity(n: int, d: float) -> list[float]:
    """
    Build a vector whose cosine similarity to the all-ones vector is about `d`.

    The result is `sign(d) * [1, 1, ..., 1] + alpha * u`, where
    `u = [1/sqrt(2), -1/sqrt(2), 0, ..., 0]` is a unit vector orthogonal to
    the all-ones vector and `alpha**2 = n * (1 - d**2) / d**2`. Only the
    first two components therefore differ from `sign(d)`.

    Args:
        n: Length of the returned vector.
        d: Target cosine similarity, in [-1, 1].

    Returns:
        A list of `n` floats.

    Raises:
        ValueError: If `d` lies outside [-1, 1].

    Note:
        For `n < 2` there is no direction orthogonal to the all-ones vector,
        so the orthogonal component is dropped: the result is then
        `[sign(d)] * n` (or all zeros when `d` is ~0) and its similarity is
        NOT `d`.
    """
    if not -1.0 <= d <= 1.0:
        raise ValueError("d must lie in [-1, 1].")

    # Handle special cases exactly:
    if abs(d - 1.0) < 1e-12:  # d ~ +1
        return [1.0] * n
    if abs(d + 1.0) < 1e-12:  # d ~ -1
        return [-1.0] * n
    if abs(d) < 1e-12:  # d ~ 0: any vector whose components sum to zero
        v = [0.0] * n
        if n >= 2:
            v[0] = 1.0
            v[1] = -1.0
        return v

    sign_d = 1.0 if d >= 0 else -1.0

    # Scale of the orthogonal component: alpha^2 = n*(1 - d^2)/d^2
    alpha = math.sqrt(n * (1 - d * d)) / abs(d)

    # Start from sign(d) * ones; u is zero outside its first two entries,
    # so only those two components need adjusting.
    v = [sign_d] * n
    if n >= 2:
        offset = alpha * (1.0 / math.sqrt(2))
        v[0] = sign_d + offset
        v[1] = sign_d - offset

    return v

@test("query: create user doc")
async def _(dsn=pg_dsn, developer=test_developer, user=test_user):
Expand Down Expand Up @@ -257,7 +302,7 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
assert result[0].metadata == {"test": "test"}, "Metadata should match"


@test("query: search docs by embedding")
@test("query: search docs by embedding without confidence")
async def _(
dsn=pg_dsn, agent=test_agent, developer=test_developer, doc=test_doc_with_embedding
):
Expand All @@ -282,6 +327,46 @@ async def _(
assert result[0].metadata is not None


@test("query: search docs by embedding with confidence")
async def _(
    dsn=pg_dsn,
    agent=test_agent,
    developer=test_developer,
    doc=test_doc_with_embedding,
):
    # Checks the `confidence` cut-off of search_docs_by_embedding: a threshold
    # below the query's target similarity must return the doc, a threshold
    # above it must return nothing.
    pool = await create_db_pool(dsn=dsn)

    assert doc.embeddings is not None

    # Build a query vector with cosine similarity ~0.9 to the all-ones vector.
    # NOTE(review): presumably the fixture doc's embedding is proportional to
    # the all-ones vector, so this is also its similarity to the doc - confirm
    # against test_doc_with_embedding.
    target_similarity = 0.9
    probe = make_vector_with_similarity(len(doc.embeddings[0]), target_similarity)

    async def search_with(threshold):
        # Both searches share every argument except the confidence threshold.
        return await search_docs_by_embedding(
            developer_id=developer.id,
            owners=[("agent", agent.id)],
            embedding=probe,
            confidence=threshold,
            k=3,
            metadata_filter={"test": "test"},
            connection_pool=pool,
        )

    # Threshold below the target similarity: the doc should match.
    loose_hits = await search_with(target_similarity * 0.9)

    assert len(loose_hits) >= 1
    assert loose_hits[0].metadata is not None

    # Threshold above the target similarity: nothing should match.
    strict_hits = await search_with(target_similarity * 1.1)

    assert len(strict_hits) == 0


@test("query: search docs by hybrid")
async def _(
dsn=pg_dsn, agent=test_agent, developer=test_developer, doc=test_doc_with_embedding
Expand Down
14 changes: 7 additions & 7 deletions memory-store/migrations/000018_doc_search.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ OR REPLACE FUNCTION search_by_vector (
metadata_filter jsonb DEFAULT NULL
) RETURNS SETOF doc_search_result LANGUAGE plpgsql AS $$
DECLARE
search_threshold float;
distance_threshold float;
owner_filter_sql text;
metadata_filter_sql text;
BEGIN
Expand All @@ -114,7 +114,7 @@ BEGIN
END IF;

-- Calculate search threshold from confidence
search_threshold := 1.0 - confidence;
distance_threshold := 1.0 - confidence;

-- Build owner filter SQL
owner_filter_sql := '
Expand All @@ -138,30 +138,30 @@ BEGIN
d.index,
d.title,
d.content,
(1 - (d.embedding <=> $1)) as distance,
(d.embedding <=> $1) as distance,
d.embedding,
d.metadata,
doc_owners.owner_type,
doc_owners.owner_id
FROM docs_embeddings d
LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id
WHERE d.developer_id = $7
AND 1 - (d.embedding <=> $1) >= $2
AND distance >= $2
%s
%s
ORDER BY (1 - (d.embedding <=> $1)) DESC
ORDER BY distance ASC
LIMIT ($3 * 4) -- Get more candidates than needed
)
SELECT DISTINCT ON (doc_id) *
FROM ranked_docs
ORDER BY doc_id, distance DESC
ORDER BY doc_id, distance ASC
LIMIT $3',
owner_filter_sql,
metadata_filter_sql
)
USING
query_embedding,
search_threshold,
distance_threshold,
k,
owner_types,
owner_ids,
Expand Down

0 comments on commit 15e9097

Please sign in to comment.