Skip to content

Commit

Permalink
refactor: Lint agents-api (CI)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Jan 9, 2025
1 parent de7621b commit f697394
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,3 @@ async def search_docs_by_embedding(
metadata_filter,
],
)

34 changes: 20 additions & 14 deletions agents-api/tests/test_docs_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,50 +22,53 @@

import math


def make_vector_with_similarity(n: int, d: float):
"""
Returns a list `v` of length `n` such that the cosine similarity
between `v` and the all-ones vector of length `n` is approximately d.
"""
if not -1.0 <= d <= 1.0:
raise ValueError("d must lie in [-1, 1].")

msg = "d must lie in [-1, 1]."
raise ValueError(msg)

# Handle special cases exactly:
if abs(d - 1.0) < 1e-12: # d ~ +1
return [1.0] * n
if abs(d + 1.0) < 1e-12: # d ~ -1
return [-1.0] * n
if abs(d) < 1e-12: # d ~ 0
v = [0.0]*n
if abs(d) < 1e-12: # d ~ 0
v = [0.0] * n
if n >= 2:
v[0] = 1.0
v[1] = -1.0
return v

sign_d = 1.0 if d >= 0 else -1.0

# Base part: sign(d)*[1,1,...,1]
base = [sign_d]*n
base = [sign_d] * n

# Orthogonal unit vector u with sum(u)=0; for simplicity:
# u = [1/sqrt(2), -1/sqrt(2), 0, 0, ..., 0]
u = [0.0]*n
u = [0.0] * n
if n >= 2:
u[0] = 1.0 / math.sqrt(2)
u[1] = -1.0 / math.sqrt(2)
# (if n=1, there's no truly orthogonal vector to [1], so skip)

# Solve for alpha:
# alpha^2 = n*(1 - d^2)/d^2
alpha = math.sqrt(n*(1 - d*d)) / abs(d)
alpha = math.sqrt(n * (1 - d * d)) / abs(d)

# Construct v
v = [0.0]*n
v = [0.0] * n
for i in range(n):
v[i] = base[i] + alpha * u[i]

return v


@test("query: create user doc")
async def _(dsn=pg_dsn, developer=test_developer, user=test_user):
pool = await create_db_pool(dsn=dsn)
Expand Down Expand Up @@ -329,7 +332,10 @@ async def _(

@test("query: search docs by embedding with confidence")
async def _(
dsn=pg_dsn, agent=test_agent, developer=test_developer, doc=test_doc_with_embedding,
dsn=pg_dsn,
agent=test_agent,
developer=test_developer,
doc=test_doc_with_embedding,
):
pool = await create_db_pool(dsn=dsn)

Expand All @@ -344,7 +350,7 @@ async def _(
developer_id=developer.id,
owners=[("agent", agent.id)],
embedding=query_embedding,
confidence=confidence*0.9,
confidence=confidence * 0.9,
k=3, # Add k parameter
metadata_filter={"test": "test"}, # Add metadata filter
connection_pool=pool,
Expand All @@ -358,7 +364,7 @@ async def _(
developer_id=developer.id,
owners=[("agent", agent.id)],
embedding=query_embedding,
confidence=confidence*1.1,
confidence=confidence * 1.1,
k=3, # Add k parameter
metadata_filter={"test": "test"}, # Add metadata filter
connection_pool=pool,
Expand Down

0 comments on commit f697394

Please sign in to comment.