diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index c742a3054..bef941d1f 100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -8,37 +8,37 @@ # Update the query to use DISTINCT ON to prevent duplicates doc_with_embedding_query = """ -WITH doc_data AS ( - SELECT - d.doc_id, - d.developer_id, - d.title, - array_agg(d.content ORDER BY d.index) as content, - array_agg(d.index ORDER BY d.index) as indices, - array_agg(e.embedding ORDER BY d.index) as embeddings, - d.modality, - d.embedding_model, - d.embedding_dimensions, - d.language, - d.metadata, - d.created_at - FROM docs d - LEFT JOIN docs_embeddings e - ON d.doc_id = e.doc_id - WHERE d.developer_id = $1 - AND d.doc_id = $2 - GROUP BY - d.doc_id, - d.developer_id, - d.title, - d.modality, - d.embedding_model, - d.embedding_dimensions, - d.language, - d.metadata, - d.created_at -) -SELECT * FROM doc_data; +SELECT + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(e.embedding ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +FROM docs d +LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id + AND e.embedding IS NOT NULL +WHERE d.developer_id = $1 + AND d.doc_id = $2 +GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +ORDER BY d.created_at DESC +LIMIT 1; """ diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index 60c0118a8..87893854a 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -15,41 +15,38 @@ # Base query for listing docs with aggregated content and embeddings base_docs_query = """ -WITH doc_data AS ( - SELECT - d.doc_id, - d.developer_id, - d.title, - array_agg(d.content ORDER BY d.index) as content, - array_agg(d.index ORDER BY d.index) as indices, - array_agg(CASE WHEN $2 THEN NULL ELSE e.embedding END ORDER BY d.index) as embeddings, - d.modality, - d.embedding_model, - d.embedding_dimensions, - d.language, - d.metadata, - d.created_at - FROM docs d - JOIN doc_owners doc_own - ON d.developer_id = doc_own.developer_id - AND d.doc_id = doc_own.doc_id - LEFT JOIN docs_embeddings e - ON d.doc_id = e.doc_id - WHERE d.developer_id = $1 - AND doc_own.owner_type = $3 - AND doc_own.owner_id = $4 - GROUP BY - d.doc_id, - d.developer_id, - d.title, - d.modality, - d.embedding_model, - d.embedding_dimensions, - d.language, - d.metadata, - d.created_at -) -SELECT * FROM doc_data +SELECT + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(CASE WHEN $2 THEN NULL ELSE e.embedding END ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +FROM docs d +JOIN doc_owners doc_own + ON d.developer_id = doc_own.developer_id + AND d.doc_id = doc_own.doc_id +LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id +WHERE d.developer_id = $1 + AND doc_own.owner_type = $3 + AND doc_own.owner_id = $4 +GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at """ @@ -58,13 +55,6 @@ def transform_list_docs(d: dict) -> dict: embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"] - # try: - # # Embeddings are retreived as a string, so we need to evaluate it - # embeddings = ast.literal_eval(embeddings) - # except Exception as e: - # msg = f"Error evaluating embeddings: {e}" - # raise ValueError(msg) - if embeddings and all((e is None) for e in embeddings): embeddings = None diff --git a/agents-api/agents_api/queries/docs/utils.py b/agents-api/agents_api/queries/docs/utils.py index 4d1cbaf45..423d14fcc 100644 --- a/agents-api/agents_api/queries/docs/utils.py +++ b/agents-api/agents_api/queries/docs/utils.py @@ -1,4 +1,4 @@ -import ast +import json def transform_to_doc_reference(d: dict) -> dict: @@ -6,14 +6,16 @@ def transform_to_doc_reference(d: dict) -> dict: content = d.pop("content") index = d.pop("index") - embedding = d.pop("embedding") + # Convert embedding array string to list of floats if present + if d["embedding"] is not None: + try: + embedding = json.loads(d["embedding"]) + except Exception as e: + msg = f"Error evaluating embeddings: {e}" + raise ValueError(msg) - try: - # Embeddings are retreived as a string, so we need to evaluate it - embedding = ast.literal_eval(embedding) - except Exception as e: - msg = f"Error evaluating embeddings: {e}" - raise ValueError(msg) + else: + embedding = None owner = { "id": d.pop("owner_id"), diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 72e8f4d7e..00e307cf4 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -143,7 +143,53 @@ async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent): owner_id=agent.id, connection_pool=pool, ) - return await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool) + + # Explicitly Refresh Indices: After inserting data, run a command to refresh the index, + # ensuring it's up-to-date before executing queries. + # This can be achieved by executing a REINDEX command + await pool.execute("REINDEX DATABASE") + + yield await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool) + + # await delete_doc( + # developer_id=developer.id, + # doc_id=resp.id, + # owner_type="agent", + # owner_id=agent.id, + # connection_pool=pool, + # ) + + +@fixture(scope="test") +async def test_user_doc(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + resp = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Hello", + content=["World", "World2", "World3"], + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Explicitly Refresh Indices: After inserting data, run a command to refresh the index, + # ensuring it's up-to-date before executing queries. + # This can be achieved by executing a REINDEX command + await pool.execute("REINDEX DATABASE") + + yield await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool) + + # await delete_doc( + # developer_id=developer.id, + # doc_id=resp.id, + # owner_type="user", + # owner_id=user.id, + # connection_pool=pool, + # ) @fixture(scope="test") @@ -209,46 +255,6 @@ async def test_session( ) -@fixture(scope="global") -async def test_user_doc( - dsn=pg_dsn, - developer_id=test_developer_id, - user=test_user, -): - pool = await create_db_pool(dsn=dsn) - doc = await create_doc( - developer_id=developer_id, - owner_type="user", - owner_id=user.id, - data=CreateDocRequest(title="Hello", content=["World"]), - connection_pool=pool, - ) - yield doc - - -# @fixture(scope="global") -# async def test_task( -# dsn=pg_dsn, -# developer_id=test_developer_id, -# agent=test_agent, -# ): -# async with get_pg_client(dsn=dsn) as client: -# task = await create_task( -# developer_id=developer_id, -# agent_id=agent.id, -# data=CreateTaskRequest( -# **{ -# "name": "test task", -# "description": "test task about", -# "input_schema": {"type": "object", "additionalProperties": True}, -# "main": [{"evaluate": {"hello": '"world"'}}], -# } -# ), -# client=client, -# ) -# yield task - - @fixture(scope="global") async def test_execution( dsn=pg_dsn, diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index 2c49de891..fa17e6861 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -1,5 +1,3 @@ -import asyncio - from agents_api.autogen.openapi_model import CreateDocRequest from agents_api.clients.pg import create_db_pool from agents_api.queries.docs.create_doc import create_doc @@ -11,7 +9,7 @@ from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid from ward import skip, test -from tests.fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user +from .fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user EMBEDDING_SIZE: int = 1024 @@ -215,7 +213,6 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): assert not any(d.id == doc_agent.id for d in docs_list) -@skip("text search: test container not vectorizing") @test("query: search docs by text") async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): pool = await create_db_pool(dsn=dsn) @@ -234,9 +231,6 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): connection_pool=pool, ) - # Add a longer delay to ensure the search index is updated - await asyncio.sleep(3) - # Search using simpler terms first result = await search_docs_by_text( developer_id=developer.id, @@ -262,7 +256,7 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): pool = await create_db_pool(dsn=dsn) # Create a test document - await create_doc( + doc = await create_doc( developer_id=developer.id, owner_type="agent", owner_id=agent.id, @@ -275,11 +269,23 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): connection_pool=pool, ) + # Create a test document + doc = await get_doc( + developer_id=developer.id, + doc_id=doc.id, + connection_pool=pool, + ) + + assert doc.embeddings is not None + + # Get query embedding by averaging the embeddings (list of floats) + query_embedding = [sum(k) / len(k) for k in zip(*doc.embeddings)] + # Search using the correct parameter types result = await search_docs_by_embedding( developer_id=developer.id, owners=[("agent", agent.id)], - query_embedding=[1.0] * 1024, + query_embedding=query_embedding, k=3, # Add k parameter metadata_filter={"test": "test"}, # Add metadata filter connection_pool=pool, @@ -289,7 +295,7 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): assert result[0].metadata is not None -@skip("hybrid search: test container not vectorizing") +@skip("embedding search: test container not vectorizing") @test("query: search docs by hybrid") async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): pool = await create_db_pool(dsn=dsn) diff --git a/agents-api/tests/test_docs_routes.py b/agents-api/tests/test_docs_routes.py index e62da6c42..4d9d0ba93 100644 --- a/agents-api/tests/test_docs_routes.py +++ b/agents-api/tests/test_docs_routes.py @@ -2,7 +2,7 @@ from ward import skip, test -from tests.fixtures import ( +from .fixtures import ( make_request, patch_embed_acompletion, test_agent, @@ -10,7 +10,6 @@ test_user, test_user_doc, ) - from .utils import patch_testing_temporal @@ -173,10 +172,9 @@ def _(make_request=make_request, agent=test_agent): assert isinstance(docs, list) -@skip("Fails due to FTS not working in Test Container") @test("route: search agent docs") async def _(make_request=make_request, agent=test_agent, doc=test_doc): - await asyncio.sleep(0.5) + await asyncio.sleep(1) search_params = { "text": doc.content[0], "limit": 1, @@ -196,10 +194,9 @@ async def _(make_request=make_request, agent=test_agent, doc=test_doc): assert len(docs) >= 1 -@skip("Fails due to FTS not working in Test Container") @test("route: search user docs") async def _(make_request=make_request, user=test_user, doc=test_user_doc): - await asyncio.sleep(0.5) + await asyncio.sleep(1) search_params = { "text": doc.content[0], "limit": 1, @@ -220,10 +217,10 @@ async def _(make_request=make_request, user=test_user, doc=test_user_doc): assert len(docs) >= 1 -@skip("Fails due to Vectorizer and FTS not working in Test Container") +@skip("embedding search: test container not vectorizing") @test("route: search agent docs hybrid with mmr") async def _(make_request=make_request, agent=test_agent, doc=test_doc): - await asyncio.sleep(0.5) + await asyncio.sleep(1) EMBEDDING_SIZE = 1024 search_params = { diff --git a/agents-api/tests/utils.py b/agents-api/tests/utils.py index 2049b4689..884d27632 100644 --- a/agents-api/tests/utils.py +++ b/agents-api/tests/utils.py @@ -1,5 +1,6 @@ import asyncio import logging +import os import subprocess from contextlib import asynccontextmanager, contextmanager from unittest.mock import patch @@ -9,6 +10,8 @@ from fastapi.testclient import TestClient from litellm.types.utils import ModelResponse from temporalio.testing import WorkflowEnvironment +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs from testcontainers.localstack import LocalStackContainer from testcontainers.postgres import PostgresContainer @@ -108,7 +111,7 @@ def patch_integration_service(output: dict = {"result": "ok"}): @contextmanager -def get_pg_dsn(): +def get_pg_dsn(start_vectorizer: bool = False): with PostgresContainer("timescale/timescaledb-ha:pg17") as postgres: test_psql_url = postgres.get_connection_url() pg_dsn = f"postgres://{test_psql_url[22:]}?sslmode=disable" @@ -116,6 +119,33 @@ def get_pg_dsn(): process = subprocess.Popen(command, shell=True) process.wait() + if not start_vectorizer: + yield pg_dsn + return + + # ELSE: + with ( + DockerContainer("timescale/pgai-vectorizer-worker:v0.3.0") + .with_network(postgres._network) # noqa: SLF001 + .with_env( + "PGAI_VECTORIZER_WORKER_DB_URL", + pg_dsn.replace("localhost", postgres.get_container_host_ip()), + ) + .with_env( + "VOYAGE_API_KEY", + os.environ.get("VOYAGE_API_KEY"), + ) + ) as vectorizer: + wait_for_logs( + vectorizer, + "finished processing vectorizer", + predicate_streams_and=True, + raise_on_exit=True, + timeout=10, + ) + + print("Vectorizer worker started") + yield pg_dsn diff --git a/memory-store/migrations/000018_doc_search.up.sql b/memory-store/migrations/000018_doc_search.up.sql index 8fde5e9bb..e120b78c6 100644 --- a/memory-store/migrations/000018_doc_search.up.sql +++ b/memory-store/migrations/000018_doc_search.up.sql @@ -271,11 +271,15 @@ BEGIN d.title, d.content, ts_rank_cd(d.search_tsv, $1, 32)::double precision as distance, - d.embedding, + e.embedding, d.metadata, doc_owners.owner_type, doc_owners.owner_id - FROM docs_embeddings d + FROM docs d + LEFT JOIN docs_embeddings e + ON e.developer_id = d.developer_id + AND e.doc_id = d.doc_id + AND e.index = d.index LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id WHERE d.developer_id = $6 AND d.search_tsv @@ $1