diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py index d8bcce7d3..d3c2fe3c1 100644 --- a/agents-api/agents_api/queries/docs/create_doc.py +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -153,7 +153,7 @@ async def create_doc( if isinstance(data.content, list): final_params_doc = [] final_params_owner = [] - + for idx, content in enumerate(data.content): doc_params = [ developer_id, @@ -185,7 +185,6 @@ async def create_doc( queries.append((doc_owner_query, final_params_owner, "fetchmany")) else: - # Create the doc record doc_params = [ developer_id, diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index 3f071cf87..1cee8f354 100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -51,7 +51,9 @@ "id": d["doc_id"], "index": d["indices"][0], "content": d["content"][0] if len(d["content"]) == 1 else d["content"], - "embeddings": d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"], + "embeddings": d["embeddings"][0] + if len(d["embeddings"]) == 1 + else d["embeddings"], **d, }, ) @@ -64,7 +66,7 @@ async def get_doc( ) -> tuple[str, list]: """ Fetch a single doc with its embedding, grouping all content chunks and embeddings. - + Parameters: developer_id (UUID): The ID of the developer. doc_id (UUID): The ID of the document. diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index 2b31df250..9788b0daa 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -75,7 +75,9 @@ "id": d["doc_id"], "index": d["indices"][0], "content": d["content"][0] if len(d["content"]) == 1 else d["content"], - "embedding": d["embeddings"][0] if d.get("embeddings") and len(d["embeddings"]) == 1 else d.get("embeddings"), + "embedding": d["embeddings"][0] + if d.get("embeddings") and len(d["embeddings"]) == 1 + else d.get("embeddings"), **d, }, ) diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py index 79f9ac305..9c22a60ce 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_text.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -1,17 +1,16 @@ -from typing import Any, Literal, List +import json +from typing import Any, List, Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -import asyncpg -import json from ...autogen.openapi_model import DocReference -from ..utils import pg_query, wrap_in_class, rewrap_exceptions, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -search_docs_text_query = ( - """ +search_docs_text_query = """ SELECT * FROM search_by_text( $1, -- developer_id $2, -- query @@ -19,7 +18,6 @@ ( SELECT array_agg(*)::UUID[] FROM jsonb_array_elements($4) ) ) """ -) @rewrap_exceptions( @@ -74,10 +72,10 @@ async def search_docs_by_text( # Extract owner types and IDs owner_types = [owner[0] for owner in owners] owner_ids = [owner[1] for owner in owners] - + return ( search_docs_text_query, - [ + [ developer_id, query, owner_types, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index a34c7e2aa..2ad6bfeeb 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -63,6 +63,7 @@ def test_developer_id(): developer_id = uuid7() return developer_id + @fixture(scope="global") async def test_developer(dsn=pg_dsn, developer_id=test_developer_id): pool = await create_db_pool(dsn=dsn) diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index 71553ee83..82490cb77 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -9,6 +9,7 @@ # If you wish to test text/embedding/hybrid search, import them: from agents_api.queries.docs.search_docs_by_text import search_docs_by_text + # from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding # from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid # You can rename or remove these imports to match your actual fixtures @@ -81,6 +82,7 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): assert doc_test.title == doc.title assert doc_test.content == doc.content + @test("query: list user docs") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -212,17 +214,18 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): ) assert not any(d.id == doc_agent.id for d in docs_list) + @test("query: search docs by text") async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): pool = await create_db_pool(dsn=dsn) - + # Create a test document await create_doc( developer_id=developer.id, owner_type="agent", owner_id=agent.id, data=CreateDocRequest( - title="Hello", + title="Hello", content="The world is a funny little thing", metadata={"test": "test"}, embed_instruction="Embed the document", @@ -242,4 +245,4 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): ) assert len(result) >= 1 - assert result[0].metadata is not None \ No newline at end of file + assert result[0].metadata is not None