diff --git a/agents-api/agents_api/queries/docs/__init__.py b/agents-api/agents_api/queries/docs/__init__.py index f7c207bf2..75f9516a6 100644 --- a/agents-api/agents_api/queries/docs/__init__.py +++ b/agents-api/agents_api/queries/docs/__init__.py @@ -20,6 +20,7 @@ from .delete_doc import delete_doc from .get_doc import get_doc from .list_docs import list_docs + # from .search_docs_by_embedding import search_docs_by_embedding # from .search_docs_by_text import search_docs_by_text diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py index 4528e9fc5..bf789fad2 100644 --- a/agents-api/agents_api/queries/docs/create_doc.py +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -1,3 +1,4 @@ +import ast from typing import Literal from uuid import UUID @@ -7,9 +8,6 @@ from sqlglot import parse_one from uuid_extensions import uuid7 -import ast - - from ...autogen.openapi_model import CreateDocRequest, Doc from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class @@ -86,7 +84,9 @@ transform=lambda d: { **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), + "content": ast.literal_eval(d["content"])[0] + if len(ast.literal_eval(d["content"])) == 1 + else ast.literal_eval(d["content"]), }, ) @increase_counter("create_doc") diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index 9155f500a..b46563dbb 100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -1,9 +1,9 @@ +import ast from typing import Literal from uuid import UUID from beartype import beartype from sqlglot import parse_one -import ast from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class @@ -28,7 +28,9 @@ transform=lambda d: { **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), + "content": ast.literal_eval(d["content"])[0] + if len(ast.literal_eval(d["content"])) == 1 + else ast.literal_eval(d["content"]), # "embeddings": d["embeddings"], }, ) diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index a4df08e73..92cbacf7f 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -1,10 +1,10 @@ +import ast from typing import Any, Literal from uuid import UUID from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -import ast from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class @@ -24,7 +24,9 @@ transform=lambda d: { **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), + "content": ast.literal_eval(d["content"])[0] + if len(ast.literal_eval(d["content"])) == 1 + else ast.literal_eval(d["content"]), # "embeddings": d["embeddings"], }, ) diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py index 7c8b67887..2f36def4f 100644 --- a/agents-api/agents_api/queries/files/list_files.py +++ b/agents-api/agents_api/queries/files/list_files.py @@ -9,6 +9,7 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one + from ...autogen.openapi_model import File from ..utils import pg_query, wrap_in_class diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 6689137d7..2f7de580e 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -8,10 +8,10 @@ from agents_api.autogen.openapi_model import ( CreateAgentRequest, + CreateDocRequest, CreateFileRequest, CreateSessionRequest, CreateUserRequest, - CreateDocRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode @@ -20,7 +20,6 @@ # from agents_api.queries.agents.delete_agent import delete_agent from agents_api.queries.developers.get_developer import get_developer - from agents_api.queries.docs.create_doc import create_doc # from agents_api.queries.docs.delete_doc import delete_doc diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index d6af42e57..1410c88c9 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -11,9 +11,8 @@ # from agents_api.queries.docs.search_docs_by_text import search_docs_by_text # from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding # from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid - # You can rename or remove these imports to match your actual fixtures -from tests.fixtures import pg_dsn, test_agent, test_developer, test_user, test_doc +from tests.fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user @test("query: create doc") @@ -29,7 +28,7 @@ async def _(dsn=pg_dsn, developer=test_developer): ), connection_pool=pool, ) - + assert doc.title == "Hello Doc" assert doc.content == "This is sample doc content" assert doc.modality == "text" @@ -38,6 +37,7 @@ async def _(dsn=pg_dsn, developer=test_developer): assert doc.language == "english" assert doc.index == 0 + @test("query: create user doc") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -64,6 +64,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): ) assert any(d.id == doc.id for d in docs_list) + @test("query: create agent doc") async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) @@ -90,6 +91,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): ) assert any(d.id == doc.id for d in docs_list) + @test("model: get doc") async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): pool = await create_db_pool(dsn=dsn) @@ -101,6 +103,7 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): assert doc_test.id == doc.id assert doc_test.title == doc.title + @test("query: list docs") async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): pool = await create_db_pool(dsn=dsn) @@ -111,6 +114,7 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): assert len(docs_list) >= 1 assert any(d.id == doc.id for d in docs_list) + @test("query: list user docs") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -139,6 +143,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): assert len(docs_list) >= 1 assert any(d.id == doc_user.id for d in docs_list) + @test("query: list agent docs") async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) @@ -167,6 +172,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): assert len(docs_list) >= 1 assert any(d.id == doc_agent.id for d in docs_list) + @test("query: delete user doc") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -203,6 +209,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): ) assert not any(d.id == doc_user.id for d in docs_list) + @test("query: delete agent doc") async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) @@ -239,6 +246,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): ) assert not any(d.id == doc_agent.id for d in docs_list) + @test("query: delete doc") async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): pool = await create_db_pool(dsn=dsn) diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 2a9746ef1..ae825ed92 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,7 +3,6 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. """ - from fastapi import HTTPException from uuid_extensions import uuid7 from ward import raises, test