fix(agents-api,memory-store): Fix docs tests and related migration
Signed-off-by: Diwank Singh Tomer <[email protected]>
creatorrr committed Jan 2, 2025
1 parent 3fa2272 commit 38b4acb
Showing 8 changed files with 178 additions and 143 deletions.
62 changes: 31 additions & 31 deletions agents-api/agents_api/queries/docs/get_doc.py
@@ -8,37 +8,37 @@

# Update the query to use DISTINCT ON to prevent duplicates
doc_with_embedding_query = """
WITH doc_data AS (
SELECT
d.doc_id,
d.developer_id,
d.title,
array_agg(d.content ORDER BY d.index) as content,
array_agg(d.index ORDER BY d.index) as indices,
array_agg(e.embedding ORDER BY d.index) as embeddings,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
FROM docs d
LEFT JOIN docs_embeddings e
ON d.doc_id = e.doc_id
WHERE d.developer_id = $1
AND d.doc_id = $2
GROUP BY
d.doc_id,
d.developer_id,
d.title,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
)
SELECT * FROM doc_data;
SELECT
d.doc_id,
d.developer_id,
d.title,
array_agg(d.content ORDER BY d.index) as content,
array_agg(d.index ORDER BY d.index) as indices,
array_agg(e.embedding ORDER BY d.index) as embeddings,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
FROM docs d
LEFT JOIN docs_embeddings e
ON d.doc_id = e.doc_id
AND e.embedding IS NOT NULL
WHERE d.developer_id = $1
AND d.doc_id = $2
GROUP BY
d.doc_id,
d.developer_id,
d.title,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
ORDER BY d.created_at DESC
LIMIT 1;
"""


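For context, a minimal sketch of how the revised query might be invoked with asyncpg (the fetch_doc helper and its arguments are illustrative assumptions, not part of this commit; only doc_with_embedding_query and the $1/$2 positional parameters come from the code above):

import asyncpg

from agents_api.queries.docs.get_doc import doc_with_embedding_query


async def fetch_doc(pool: asyncpg.Pool, developer_id, doc_id):
    # $1 binds developer_id and $2 binds doc_id, matching the WHERE clause above
    row = await pool.fetchrow(doc_with_embedding_query, developer_id, doc_id)
    return dict(row) if row is not None else None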
74 changes: 32 additions & 42 deletions agents-api/agents_api/queries/docs/list_docs.py
@@ -15,41 +15,38 @@

# Base query for listing docs with aggregated content and embeddings
base_docs_query = """
WITH doc_data AS (
SELECT
d.doc_id,
d.developer_id,
d.title,
array_agg(d.content ORDER BY d.index) as content,
array_agg(d.index ORDER BY d.index) as indices,
array_agg(CASE WHEN $2 THEN NULL ELSE e.embedding END ORDER BY d.index) as embeddings,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
FROM docs d
JOIN doc_owners doc_own
ON d.developer_id = doc_own.developer_id
AND d.doc_id = doc_own.doc_id
LEFT JOIN docs_embeddings e
ON d.doc_id = e.doc_id
WHERE d.developer_id = $1
AND doc_own.owner_type = $3
AND doc_own.owner_id = $4
GROUP BY
d.doc_id,
d.developer_id,
d.title,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
)
SELECT * FROM doc_data
SELECT
d.doc_id,
d.developer_id,
d.title,
array_agg(d.content ORDER BY d.index) as content,
array_agg(d.index ORDER BY d.index) as indices,
array_agg(CASE WHEN $2 THEN NULL ELSE e.embedding END ORDER BY d.index) as embeddings,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
FROM docs d
JOIN doc_owners doc_own
ON d.developer_id = doc_own.developer_id
AND d.doc_id = doc_own.doc_id
LEFT JOIN docs_embeddings e
ON d.doc_id = e.doc_id
WHERE d.developer_id = $1
AND doc_own.owner_type = $3
AND doc_own.owner_id = $4
GROUP BY
d.doc_id,
d.developer_id,
d.title,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
"""


@@ -58,13 +55,6 @@ def transform_list_docs(d: dict) -> dict:

embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"]

# try:
# # Embeddings are retreived as a string, so we need to evaluate it
# embeddings = ast.literal_eval(embeddings)
# except Exception as e:
# msg = f"Error evaluating embeddings: {e}"
# raise ValueError(msg)

if embeddings and all((e is None) for e in embeddings):
embeddings = None

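As a worked illustration of the embeddings handling in transform_list_docs above (the input dict is hypothetical, not taken from the repository): a doc with no embeddings yet comes back from the LEFT JOIN as an all-NULL aggregate, which the function collapses to None.

d = {"embeddings": [None, None, None]}  # one NULL per content chunk
embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"]
if embeddings and all((e is None) for e in embeddings):
    embeddings = None  # an all-NULL aggregate means the doc has no embeddings yet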
18 changes: 10 additions & 8 deletions agents-api/agents_api/queries/docs/utils.py
@@ -1,19 +1,21 @@
import ast
import json


def transform_to_doc_reference(d: dict) -> dict:
id = d.pop("doc_id")
content = d.pop("content")
index = d.pop("index")

embedding = d.pop("embedding")
# Convert embedding array string to list of floats if present
if d["embedding"] is not None:
try:
embedding = json.loads(d["embedding"])
except Exception as e:
msg = f"Error evaluating embeddings: {e}"
raise ValueError(msg)

try:
# Embeddings are retreived as a string, so we need to evaluate it
embedding = ast.literal_eval(embedding)
except Exception as e:
msg = f"Error evaluating embeddings: {e}"
raise ValueError(msg)
else:
embedding = None

owner = {
"id": d.pop("owner_id"),
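A small illustration of the new json.loads path (the literal below is an assumed example of how an embedding might be rendered as text, not output captured from the repository):

import json

raw = "[0.1, 0.2, 0.3]"      # embedding retrieved from the database as a string
embedding = json.loads(raw)  # parsed into a list of floats: [0.1, 0.2, 0.3]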
88 changes: 47 additions & 41 deletions agents-api/tests/fixtures.py
@@ -143,7 +143,53 @@ async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent):
owner_id=agent.id,
connection_pool=pool,
)
return await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool)

# Explicitly Refresh Indices: After inserting data, run a command to refresh the index,
# ensuring it's up-to-date before executing queries.
# This can be achieved by executing a REINDEX command
await pool.execute("REINDEX DATABASE")

yield await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool)

# await delete_doc(
# developer_id=developer.id,
# doc_id=resp.id,
# owner_type="agent",
# owner_id=agent.id,
# connection_pool=pool,
# )


@fixture(scope="test")
async def test_user_doc(dsn=pg_dsn, developer=test_developer, user=test_user):
pool = await create_db_pool(dsn=dsn)
resp = await create_doc(
developer_id=developer.id,
data=CreateDocRequest(
title="Hello",
content=["World", "World2", "World3"],
metadata={"test": "test"},
embed_instruction="Embed the document",
),
owner_type="user",
owner_id=user.id,
connection_pool=pool,
)

# Explicitly Refresh Indices: After inserting data, run a command to refresh the index,
# ensuring it's up-to-date before executing queries.
# This can be achieved by executing a REINDEX command
await pool.execute("REINDEX DATABASE")

yield await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool)

# await delete_doc(
# developer_id=developer.id,
# doc_id=resp.id,
# owner_type="user",
# owner_id=user.id,
# connection_pool=pool,
# )


@fixture(scope="test")
@@ -209,46 +255,6 @@ async def test_session(
)


@fixture(scope="global")
async def test_user_doc(
dsn=pg_dsn,
developer_id=test_developer_id,
user=test_user,
):
pool = await create_db_pool(dsn=dsn)
doc = await create_doc(
developer_id=developer_id,
owner_type="user",
owner_id=user.id,
data=CreateDocRequest(title="Hello", content=["World"]),
connection_pool=pool,
)
yield doc


# @fixture(scope="global")
# async def test_task(
# dsn=pg_dsn,
# developer_id=test_developer_id,
# agent=test_agent,
# ):
# async with get_pg_client(dsn=dsn) as client:
# task = await create_task(
# developer_id=developer_id,
# agent_id=agent.id,
# data=CreateTaskRequest(
# **{
# "name": "test task",
# "description": "test task about",
# "input_schema": {"type": "object", "additionalProperties": True},
# "main": [{"evaluate": {"hello": '"world"'}}],
# }
# ),
# client=client,
# )
# yield task


@fixture(scope="global")
async def test_execution(
dsn=pg_dsn,
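A sketch of how a ward test could consume the reworked test_doc fixture; ward injects fixtures through default arguments, as the existing tests in this commit do. The test name and assertions are invented for illustration.

from ward import test

from .fixtures import pg_dsn, test_doc


@test("docs: test_doc fixture yields a doc with aggregated content")
async def _(dsn=pg_dsn, doc=test_doc):
    assert doc.id is not None
    assert isinstance(doc.content, list)  # content is aggregated via array_agg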
26 changes: 16 additions & 10 deletions agents-api/tests/test_docs_queries.py
@@ -1,5 +1,3 @@
import asyncio

from agents_api.autogen.openapi_model import CreateDocRequest
from agents_api.clients.pg import create_db_pool
from agents_api.queries.docs.create_doc import create_doc
@@ -11,7 +9,7 @@
from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid
from ward import skip, test

from tests.fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user
from .fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user

EMBEDDING_SIZE: int = 1024

@@ -215,7 +213,6 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
assert not any(d.id == doc_agent.id for d in docs_list)


@skip("text search: test container not vectorizing")
@test("query: search docs by text")
async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
pool = await create_db_pool(dsn=dsn)
Expand All @@ -234,9 +231,6 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
connection_pool=pool,
)

# Add a longer delay to ensure the search index is updated
await asyncio.sleep(3)

# Search using simpler terms first
result = await search_docs_by_text(
developer_id=developer.id,
@@ -262,7 +256,7 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
pool = await create_db_pool(dsn=dsn)

# Create a test document
await create_doc(
doc = await create_doc(
developer_id=developer.id,
owner_type="agent",
owner_id=agent.id,
@@ -275,11 +269,23 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
connection_pool=pool,
)

# Create a test document
doc = await get_doc(
developer_id=developer.id,
doc_id=doc.id,
connection_pool=pool,
)

assert doc.embeddings is not None

# Get query embedding by averaging the embeddings (list of floats)
query_embedding = [sum(k) / len(k) for k in zip(*doc.embeddings)]

# Search using the correct parameter types
result = await search_docs_by_embedding(
developer_id=developer.id,
owners=[("agent", agent.id)],
query_embedding=[1.0] * 1024,
query_embedding=query_embedding,
k=3, # Add k parameter
metadata_filter={"test": "test"}, # Add metadata filter
connection_pool=pool,
@@ -289,7 +295,7 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
assert result[0].metadata is not None


@skip("hybrid search: test container not vectorizing")
@skip("embedding search: test container not vectorizing")
@test("query: search docs by hybrid")
async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
pool = await create_db_pool(dsn=dsn)
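To make the query_embedding averaging above concrete, a small worked example with made-up numbers: zip(*doc.embeddings) iterates the per-chunk vectors column-wise, so the comprehension produces their element-wise mean.

chunk_embeddings = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]  # two per-chunk embeddings
query_embedding = [sum(k) / len(k) for k in zip(*chunk_embeddings)]
assert query_embedding == [2.0, 3.0, 4.0]  # element-wise mean across chunks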