diff --git a/agents-api/agents_api/autogen/Docs.py b/agents-api/agents_api/autogen/Docs.py index ffed27c1d..af5f60d6a 100644 --- a/agents-api/agents_api/autogen/Docs.py +++ b/agents-api/agents_api/autogen/Docs.py @@ -73,6 +73,30 @@ class Doc(BaseModel): """ Embeddings for the document """ + modality: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Modality of the document + """ + language: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Language of the document + """ + index: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Index of the document + """ + embedding_model: Annotated[ + str | None, Field(json_schema_extra={"readOnly": True}) + ] = None + """ + Embedding model to use for the document + """ + embedding_dimensions: Annotated[ + int | None, Field(json_schema_extra={"readOnly": True}) + ] = None + """ + Dimensions of the embedding model + """ class DocOwner(BaseModel): diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 76c96f46b..5294cfa6d 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -5,7 +5,9 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from uuid_extensions import uuid7 @@ -13,7 +15,9 @@ from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) @@ -45,35 +49,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ), -# psycopg_errors.UniqueViolation: partialclass( -# HTTPException, -# status_code=409, -# detail="An agent with this canonical name already exists for this developer.", -# ), -# psycopg_errors.CheckViolation: partialclass( -# HTTPException, -# status_code=400, -# detail="The provided data violates one or more constraints. Please check the input values.", -# ), -# ValidationError: partialclass( -# HTTPException, -# status_code=400, -# detail="Input validation failed. Please check the provided data.", -# ), -# TypeError: partialclass( -# HTTPException, -# status_code=400, -# detail="A type mismatch occurred. Please review the input.", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class( Agent, one=True, diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index ef3a0abe5..fcef53fd6 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -5,14 +5,18 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) @@ -44,15 +48,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class( Agent, one=True, diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index c0ca3919f..2fd1f1406 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -5,13 +5,17 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow from ..utils import ( + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) @@ -59,17 +63,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class( ResourceDeletedResponse, one=True, diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index a731300fa..79fa1c4fc 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -5,12 +5,16 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import Agent from ..utils import ( + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) @@ -35,16 +39,20 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class(Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d}) @pg_query @beartype diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 87a0c942d..11b9dc283 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -6,12 +6,15 @@ from typing import Any, Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from ...autogen.openapi_model import Agent from ..utils import ( + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) @@ -40,16 +43,20 @@ """ -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class(Agent, transform=lambda d: {"id": d["agent_id"], **d}) @pg_query @beartype diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 69a5a6ca5..06f0b9253 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -5,13 +5,17 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter from ..utils import ( + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) @@ -44,16 +48,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class( ResourceUpdatedResponse, one=True, diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index f28e28264..4d19229d8 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -5,13 +5,17 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) @@ -29,16 +33,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class( ResourceUpdatedResponse, one=True, diff --git a/agents-api/agents_api/queries/developers/create_developer.py b/agents-api/agents_api/queries/developers/create_developer.py index bed6371c4..51011a63b 100644 --- a/agents-api/agents_api/queries/developers/create_developer.py +++ b/agents-api/agents_api/queries/developers/create_developer.py @@ -38,8 +38,8 @@ { asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified developer does not exist.", + status_code=409, + detail="A developer with this email already exists.", ) } ) diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py index 373a2fb36..79b6e6067 100644 --- a/agents-api/agents_api/queries/developers/get_developer.py +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -24,9 +24,6 @@ SELECT * FROM developers WHERE developer_id = $1 -- developer_id """).sql(pretty=True) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - @rewrap_exceptions( { @@ -37,7 +34,11 @@ ) } ) -@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) +@wrap_in_class( + Developer, + one=True, + transform=lambda d: {**d, "id": d["developer_id"]}, +) @pg_query @beartype async def get_developer( diff --git a/agents-api/agents_api/queries/developers/patch_developer.py b/agents-api/agents_api/queries/developers/patch_developer.py index af2ddb1f8..e14c8bbd0 100644 --- a/agents-api/agents_api/queries/developers/patch_developer.py +++ b/agents-api/agents_api/queries/developers/patch_developer.py @@ -26,8 +26,8 @@ { asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified developer does not exist.", + status_code=409, + detail="A developer with this email already exists.", ) } ) diff --git a/agents-api/agents_api/queries/developers/update_developer.py b/agents-api/agents_api/queries/developers/update_developer.py index d41b333d5..8f3e7cd87 100644 --- a/agents-api/agents_api/queries/developers/update_developer.py +++ b/agents-api/agents_api/queries/developers/update_developer.py @@ -28,7 +28,12 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="A developer with this email already exists.", + ), } ) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) diff --git a/agents-api/agents_api/queries/docs/__init__.py b/agents-api/agents_api/queries/docs/__init__.py new file mode 100644 index 000000000..51bab2555 --- /dev/null +++ b/agents-api/agents_api/queries/docs/__init__.py @@ -0,0 +1,35 @@ +""" +Module: agents_api/models/docs + +This module is responsible for managing document-related operations within the application, particularly for agents and possibly other entities. It serves as a core component of the document management system, enabling features such as document creation, listing, deletion, and embedding of snippets for enhanced search and retrieval capabilities. + +Main functionalities include: +- Creating new documents and associating them with agents or users. +- Listing documents based on various criteria, including ownership and metadata filters. +- Deleting documents by their unique identifiers. +- Embedding document snippets for retrieval purposes. +- Searching documents by text. + +The module interacts with other parts of the application, such as the agents and users modules, to provide a comprehensive document management system. Its role is crucial in enabling document search, retrieval, and management features within the context of agents and users. + +This documentation aims to provide clear, concise, and sufficient context for new developers or contributors to understand the module's role without needing to dive deep into the code immediately. +""" + +# ruff: noqa: F401, F403, F405 + +from .create_doc import create_doc +from .delete_doc import delete_doc +from .get_doc import get_doc +from .list_docs import list_docs + +# from .search_docs_by_embedding import search_docs_by_embedding +from .search_docs_by_text import search_docs_by_text + +__all__ = [ + "create_doc", + "delete_doc", + "get_doc", + "list_docs", + # "search_docs_by_embedding", + "search_docs_by_text", +] diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py new file mode 100644 index 000000000..d3c2fe3c1 --- /dev/null +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -0,0 +1,216 @@ +import ast +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateDocRequest, Doc +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Base INSERT for docs +doc_query = parse_one(""" +INSERT INTO docs ( + developer_id, + doc_id, + title, + content, + index, + modality, + embedding_model, + embedding_dimensions, + language, + metadata +) +VALUES ( + $1, -- developer_id + $2, -- doc_id + $3, -- title + $4, -- content + $5, -- index + $6, -- modality + $7, -- embedding_model + $8, -- embedding_dimensions + $9, -- language + $10 -- metadata (JSONB) +) +RETURNING *; +""").sql(pretty=True) + +# Owner association query for doc_owners +doc_owner_query = parse_one(""" +WITH inserted_owner AS ( + INSERT INTO doc_owners ( + developer_id, + doc_id, + index, + owner_type, + owner_id + ) + VALUES ($1, $2, $3, $4, $5) + RETURNING doc_id +) +SELECT DISTINCT ON (docs.doc_id) + docs.doc_id, + docs.developer_id, + docs.title, + array_agg(docs.content ORDER BY docs.index) as content, + array_agg(docs.index ORDER BY docs.index) as indices, + docs.modality, + docs.embedding_model, + docs.embedding_dimensions, + docs.language, + docs.metadata, + docs.created_at + +FROM inserted_owner io +JOIN docs ON docs.doc_id = io.doc_id +GROUP BY + docs.doc_id, + docs.developer_id, + docs.title, + docs.modality, + docs.embedding_model, + docs.embedding_dimensions, + docs.language, + docs.metadata, + docs.created_at; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="A document with this ID already exists for this developer", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="The specified owner does not exist", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="Developer or doc owner not found", + ), + } +) +@wrap_in_class( + Doc, + one=True, + transform=lambda d: { + "id": d["doc_id"], + "index": d["indices"][0], + "content": d["content"][0] if len(d["content"]) == 1 else d["content"], + **d, + }, +) +@increase_counter("create_doc") +@pg_query +@beartype +async def create_doc( + *, + developer_id: UUID, + doc_id: UUID | None = None, + data: CreateDocRequest, + owner_type: Literal["user", "agent"], + owner_id: UUID, + modality: Literal["text", "image", "mixed"] | None = "text", + embedding_model: str | None = "voyage-3", + embedding_dimensions: int | None = 1024, + language: str | None = "english", + index: int | None = 0, +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """ + Insert a new doc record into Timescale and associate it with an owner. + + Parameters: + developer_id (UUID): The ID of the developer. + doc_id (UUID | None): Optional custom UUID for the document. If not provided, one will be generated. + data (CreateDocRequest): The data for the document. + owner_type (Literal["user", "agent"]): The type of the owner (required). + owner_id (UUID): The ID of the owner (required). + modality (Literal["text", "image", "mixed"]): The modality of the documents. + embedding_model (str): The model used for embedding. + embedding_dimensions (int): The dimensions of the embedding. + language (str): The language of the documents. + index (int): The index of the documents. + + Returns: + list[tuple[str, list] | tuple[str, list, str]]: SQL query and parameters for creating the document. + """ + queries = [] + # Generate a UUID if not provided + current_doc_id = uuid7() if doc_id is None else doc_id + + # Check if content is a list + if isinstance(data.content, list): + final_params_doc = [] + final_params_owner = [] + + for idx, content in enumerate(data.content): + doc_params = [ + developer_id, + current_doc_id, + data.title, + content, + idx, + modality, + embedding_model, + embedding_dimensions, + language, + data.metadata or {}, + ] + final_params_doc.append(doc_params) + + owner_params = [ + developer_id, + current_doc_id, + idx, + owner_type, + owner_id, + ] + final_params_owner.append(owner_params) + + # Add the doc query for each content + queries.append((doc_query, final_params_doc, "fetchmany")) + + # Add the owner query + queries.append((doc_owner_query, final_params_owner, "fetchmany")) + + else: + # Create the doc record + doc_params = [ + developer_id, + current_doc_id, + data.title, + data.content, + index, + modality, + embedding_model, + embedding_dimensions, + language, + data.metadata or {}, + ] + + owner_params = [ + developer_id, + current_doc_id, + index, + owner_type, + owner_id, + ] + + # Add the doc query for single content + queries.append((doc_query, doc_params, "fetch")) + + # Add the owner query + queries.append((doc_owner_query, owner_params, "fetch")) + + return queries diff --git a/agents-api/agents_api/queries/docs/delete_doc.py b/agents-api/agents_api/queries/docs/delete_doc.py new file mode 100644 index 000000000..b0a9ea1a1 --- /dev/null +++ b/agents-api/agents_api/queries/docs/delete_doc.py @@ -0,0 +1,79 @@ +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Delete doc query + ownership check +delete_doc_query = parse_one(""" +WITH deleted_owners AS ( + DELETE FROM doc_owners + WHERE developer_id = $1 + AND doc_id = $2 + AND owner_type = $3 + AND owner_id = $4 +) +DELETE FROM docs +WHERE developer_id = $1 + AND doc_id = $2 + AND EXISTS ( + SELECT 1 FROM doc_owners + WHERE developer_id = $1 + AND doc_id = $2 + AND owner_type = $3 + AND owner_id = $4 + ) +RETURNING doc_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Doc not found", + ) + } +) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["doc_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@pg_query +@beartype +async def delete_doc( + *, + developer_id: UUID, + doc_id: UUID, + owner_type: Literal["user", "agent"], + owner_id: UUID, +) -> tuple[str, list]: + """ + Deletes a doc (and associated doc_owners) for the given developer and doc_id. + If owner_type/owner_id is specified, only remove doc if that matches. + + Parameters: + developer_id (UUID): The ID of the developer. + doc_id (UUID): The ID of the document. + owner_type (Literal["user", "agent"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents. + + Returns: + tuple[str, list]: SQL query and parameters for deleting the document. + """ + return ( + delete_doc_query, + [developer_id, doc_id, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py new file mode 100644 index 000000000..1cee8f354 --- /dev/null +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -0,0 +1,80 @@ +import ast +from typing import Literal +from uuid import UUID + +from beartype import beartype +from sqlglot import parse_one + +from ...autogen.openapi_model import Doc +from ..utils import pg_query, wrap_in_class + +# Update the query to use DISTINCT ON to prevent duplicates +doc_with_embedding_query = parse_one(""" +WITH doc_data AS ( + SELECT DISTINCT ON (d.doc_id) + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(e.embedding ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at + FROM docs d + LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id + WHERE d.developer_id = $1 + AND d.doc_id = $2 + GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +) +SELECT * FROM doc_data; +""").sql(pretty=True) + + +@wrap_in_class( + Doc, + one=True, # Changed to True since we're now returning one grouped record + transform=lambda d: { + "id": d["doc_id"], + "index": d["indices"][0], + "content": d["content"][0] if len(d["content"]) == 1 else d["content"], + "embeddings": d["embeddings"][0] + if len(d["embeddings"]) == 1 + else d["embeddings"], + **d, + }, +) +@pg_query +@beartype +async def get_doc( + *, + developer_id: UUID, + doc_id: UUID, +) -> tuple[str, list]: + """ + Fetch a single doc with its embedding, grouping all content chunks and embeddings. + + Parameters: + developer_id (UUID): The ID of the developer. + doc_id (UUID): The ID of the document. + + Returns: + tuple[str, list]: SQL query and parameters for fetching the document. + """ + return ( + doc_with_embedding_query, + [developer_id, doc_id], + ) diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py new file mode 100644 index 000000000..9788b0daa --- /dev/null +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -0,0 +1,144 @@ +""" +This module contains the functionality for listing documents from the PostgreSQL database. +It constructs and executes SQL queries to fetch document details based on various filters. +""" + +from typing import Any, Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import Doc +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Base query for listing docs with aggregated content and embeddings +base_docs_query = parse_one(""" +WITH doc_data AS ( + SELECT DISTINCT ON (d.doc_id) + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(CASE WHEN $2 THEN NULL ELSE e.embedding END ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at + FROM docs d + JOIN doc_owners doc_own + ON d.developer_id = doc_own.developer_id + AND d.doc_id = doc_own.doc_id + LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id + WHERE d.developer_id = $1 + AND doc_own.owner_type = $3 + AND doc_own.owner_id = $4 + GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +) +SELECT * FROM doc_data +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No documents found", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), + } +) +@wrap_in_class( + Doc, + one=False, + transform=lambda d: { + "id": d["doc_id"], + "index": d["indices"][0], + "content": d["content"][0] if len(d["content"]) == 1 else d["content"], + "embedding": d["embeddings"][0] + if d.get("embeddings") and len(d["embeddings"]) == 1 + else d.get("embeddings"), + **d, + }, +) +@pg_query +@beartype +async def list_docs( + *, + developer_id: UUID, + owner_id: UUID, + owner_type: Literal["user", "agent"], + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", + metadata_filter: dict[str, Any] = {}, + include_without_embeddings: bool = False, +) -> tuple[str, list]: + """ + Lists docs with pagination and sorting, aggregating content chunks and embeddings. + + Parameters: + developer_id (UUID): The ID of the developer. + owner_id (UUID): The ID of the owner of the documents (required). + owner_type (Literal["user", "agent"]): The type of the owner of the documents (required). + limit (int): The number of documents to return. + offset (int): The number of documents to skip. + sort_by (Literal["created_at", "updated_at"]): The field to sort by. + direction (Literal["asc", "desc"]): The direction to sort by. + metadata_filter (dict[str, Any]): The metadata filter to apply. + include_without_embeddings (bool): Whether to include documents without embeddings. + + Returns: + tuple[str, list]: SQL query and parameters for listing the documents. + + Raises: + HTTPException: If invalid parameters are provided. + """ + if direction.lower() not in ["asc", "desc"]: + raise HTTPException(status_code=400, detail="Invalid sort direction") + + if sort_by not in ["created_at", "updated_at"]: + raise HTTPException(status_code=400, detail="Invalid sort field") + + if limit > 100 or limit < 1: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") + + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be >= 0") + + # Start with the base query + query = base_docs_query + params = [developer_id, include_without_embeddings, owner_type, owner_id] + + # Add metadata filtering + if metadata_filter: + for key, value in metadata_filter.items(): + query += f" AND metadata->>'{key}' = ${len(params) + 1}" + params.append(value) + + # Add sorting and pagination + query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}" + params.extend([limit, offset]) + + return query, params diff --git a/agents-api/agents_api/queries/docs/mmr.py b/agents-api/agents_api/queries/docs/mmr.py new file mode 100644 index 000000000..d214e8c04 --- /dev/null +++ b/agents-api/agents_api/queries/docs/mmr.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import logging +from typing import Union + +import numpy as np + +Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray] + +logger = logging.getLogger(__name__) + + +def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: + """Row-wise cosine similarity between two equal-width matrices. + + Args: + x: A matrix of shape (n, m). + y: A matrix of shape (k, m). + + Returns: + A matrix of shape (n, k) where each element (i, j) is the cosine similarity + between the ith row of X and the jth row of Y. + + Raises: + ValueError: If the number of columns in X and Y are not the same. + ImportError: If numpy is not installed. + """ + + if len(x) == 0 or len(y) == 0: + return np.array([]) + + x = [xx for xx in x if xx is not None] + y = [yy for yy in y if yy is not None] + + x = np.array(x) + y = np.array(y) + if x.shape[1] != y.shape[1]: + msg = ( + f"Number of columns in X and Y must be the same. X has shape {x.shape} " + f"and Y has shape {y.shape}." + ) + raise ValueError(msg) + try: + import simsimd as simd # type: ignore + + x = np.array(x, dtype=np.float32) + y = np.array(y, dtype=np.float32) + z = 1 - np.array(simd.cdist(x, y, metric="cosine")) + return z + except ImportError: + logger.debug( + "Unable to import simsimd, defaulting to NumPy implementation. If you want " + "to use simsimd please install with `pip install simsimd`." + ) + x_norm = np.linalg.norm(x, axis=1) + y_norm = np.linalg.norm(y, axis=1) + # Ignore divide by zero errors run time warnings as those are handled below. + with np.errstate(divide="ignore", invalid="ignore"): + similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm) + similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 + return similarity + + +def maximal_marginal_relevance( + query_embedding: np.ndarray, + embedding_list: list, + lambda_mult: float = 0.5, + k: int = 4, +) -> list[int]: + """Calculate maximal marginal relevance. + + Args: + query_embedding: The query embedding. + embedding_list: A list of embeddings. + lambda_mult: The lambda parameter for MMR. Default is 0.5. + k: The number of embeddings to return. Default is 4. + + Returns: + A list of indices of the embeddings to return. + + Raises: + ImportError: If numpy is not installed. + """ + + if min(k, len(embedding_list)) <= 0: + return [] + if query_embedding.ndim == 1: + query_embedding = np.expand_dims(query_embedding, axis=0) + similarity_to_query = _cosine_similarity(query_embedding, embedding_list)[0] + most_similar = int(np.argmax(similarity_to_query)) + idxs = [most_similar] + selected = np.array([embedding_list[most_similar]]) + while len(idxs) < min(k, len(embedding_list)): + best_score = -np.inf + idx_to_add = -1 + similarity_to_selected = _cosine_similarity(embedding_list, selected) + for i, query_score in enumerate(similarity_to_query): + if i in idxs: + continue + redundant_score = max(similarity_to_selected[i]) + equation_score = ( + lambda_mult * query_score - (1 - lambda_mult) * redundant_score + ) + if equation_score > best_score: + best_score = equation_score + idx_to_add = i + idxs.append(idx_to_add) + selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) + return idxs diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py new file mode 100644 index 000000000..5a89803ee --- /dev/null +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -0,0 +1,75 @@ +from typing import List, Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import DocReference +from ..utils import pg_query, wrap_in_class + +# If you're doing approximate ANN (DiskANN) or IVF, you might use a special function or hint. +# For a basic vector distance search, you can do something like: +search_docs_by_embedding_query = parse_one(""" +SELECT d.*, + (d.embedding <-> $3) AS distance +FROM docs d +LEFT JOIN doc_owners do + ON d.developer_id = do.developer_id + AND d.doc_id = do.doc_id +WHERE d.developer_id = $1 + AND ( + ($4::text IS NULL AND $5::uuid IS NULL) + OR (do.owner_type = $4 AND do.owner_id = $5) + ) + AND d.embedding IS NOT NULL +ORDER BY d.embedding <-> $3 +LIMIT $2; +""").sql(pretty=True) + + +@wrap_in_class( + DocReference, + transform=lambda d: { + "owner": { + "id": d["owner_id"], + "role": d["owner_type"], + }, + "metadata": d.get("metadata", {}), + **d, + }, +) +@pg_query +@beartype +async def search_docs_by_embedding( + *, + developer_id: UUID, + query_embedding: List[float], + k: int = 10, + owner_type: Literal["user", "agent", "org"] | None = None, + owner_id: UUID | None = None, +) -> tuple[str, list]: + """ + Vector-based doc search: + + Parameters: + developer_id (UUID): The ID of the developer. + query_embedding (List[float]): The vector to query. + k (int): The number of results to return. + owner_type (Literal["user", "agent", "org"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents. + + Returns: + tuple[str, list]: SQL query and parameters for searching the documents. + """ + if k < 1: + raise HTTPException(status_code=400, detail="k must be >= 1") + + # Validate embedding length if needed; e.g. 1024 floats + if not query_embedding: + raise HTTPException(status_code=400, detail="Empty embedding provided") + + return ( + search_docs_by_embedding_query, + [developer_id, k, query_embedding, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py new file mode 100644 index 000000000..9c22a60ce --- /dev/null +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -0,0 +1,87 @@ +import json +from typing import Any, List, Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import DocReference +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +search_docs_text_query = """ + SELECT * FROM search_by_text( + $1, -- developer_id + $2, -- query + $3, -- owner_types + ( SELECT array_agg(*)::UUID[] FROM jsonb_array_elements($4) ) + ) + """ + + +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) +@wrap_in_class( + DocReference, + transform=lambda d: { + "owner": { + "id": d["owner_id"], + "role": d["owner_type"], + }, + "metadata": d.get("metadata", {}), + **d, + }, +) +@pg_query(debug=True) +@beartype +async def search_docs_by_text( + *, + developer_id: UUID, + owners: list[tuple[Literal["user", "agent"], UUID]], + query: str, + k: int = 3, + metadata_filter: dict[str, Any] = {}, + search_language: str | None = "english", +) -> tuple[str, list]: + """ + Full-text search on docs using the search_tsv column. + + Parameters: + developer_id (UUID): The ID of the developer. + query (str): The text to search for. + owners (list[tuple[Literal["user", "agent"], UUID]]): List of (owner_type, owner_id) tuples. + k (int): Maximum number of results to return. + search_language (str): Language for text search (default: "english"). + metadata_filter (dict): Metadata filter criteria. + connection_pool (asyncpg.Pool): Database connection pool. + + Returns: + tuple[str, list]: SQL query and parameters for searching the documents. + """ + if k < 1: + raise HTTPException(status_code=400, detail="k must be >= 1") + + # Extract owner types and IDs + owner_types = [owner[0] for owner in owners] + owner_ids = [owner[1] for owner in owners] + + return ( + search_docs_text_query, + [ + developer_id, + query, + owner_types, + owner_ids, + search_language, + k, + metadata_filter, + ], + ) diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py new file mode 100644 index 000000000..184ba7e8e --- /dev/null +++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py @@ -0,0 +1,158 @@ +from typing import List, Literal +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import Doc +from .search_docs_by_embedding import search_docs_by_embedding +from .search_docs_by_text import search_docs_by_text + + +def dbsf_normalize(scores: List[float]) -> List[float]: + """ + Example distribution-based normalization: clamp each score + from (mean - 3*stddev) to (mean + 3*stddev) and scale to 0..1 + """ + import statistics + + if len(scores) < 2: + return scores + m = statistics.mean(scores) + sd = statistics.pstdev(scores) # population std + if sd == 0: + return scores + upper = m + 3 * sd + lower = m - 3 * sd + + def clamp_scale(v): + c = min(upper, max(lower, v)) + return (c - lower) / (upper - lower) + + return [clamp_scale(s) for s in scores] + + +@beartype +def fuse_results( + text_docs: List[Doc], embedding_docs: List[Doc], alpha: float +) -> List[Doc]: + """ + Merges text search results (descending by text rank) with + embedding results (descending by closeness or inverse distance). + alpha ~ how much to weigh the embedding score + """ + # Suppose we stored each doc's "distance" from the embedding query, and + # for text search we store a rank or negative distance. We'll unify them: + # Make up a dictionary of doc_id -> text_score, doc_id -> embed_score + # For example, text_score = -distance if you want bigger = better + text_scores = {} + embed_scores = {} + for doc in text_docs: + # If you had "rank", you might store doc.distance = rank + # For demo, let's assume doc.distance is negative... up to you + text_scores[doc.id] = float(-doc.distance if doc.distance else 0) + + for doc in embedding_docs: + # Lower distance => better, so we do embed_score = -distance + embed_scores[doc.id] = float(-doc.distance if doc.distance else 0) + + # Normalize them + text_vals = list(text_scores.values()) + embed_vals = list(embed_scores.values()) + text_vals_norm = dbsf_normalize(text_vals) + embed_vals_norm = dbsf_normalize(embed_vals) + + # Map them back + t_keys = list(text_scores.keys()) + for i, key in enumerate(t_keys): + text_scores[key] = text_vals_norm[i] + e_keys = list(embed_scores.keys()) + for i, key in enumerate(e_keys): + embed_scores[key] = embed_vals_norm[i] + + # Gather all doc IDs + all_ids = set(text_scores.keys()) | set(embed_scores.keys()) + + # Weighted sum => combined + out = [] + for doc_id in all_ids: + # text and embed might be missing doc_id => 0 + t_score = text_scores.get(doc_id, 0) + e_score = embed_scores.get(doc_id, 0) + combined = alpha * e_score + (1 - alpha) * t_score + # We'll store final "distance" as -(combined) so bigger combined => smaller distance + out.append((doc_id, combined)) + + # Sort descending by combined + out.sort(key=lambda x: x[1], reverse=True) + + # Convert to doc objects. We can pick from text_docs or embedding_docs or whichever is found. + # If present in both, we can merge fields. For simplicity, just pick from text_docs then fallback embedding_docs. + + # Create a quick ID->doc map + text_map = {d.id: d for d in text_docs} + embed_map = {d.id: d for d in embedding_docs} + + final_docs = [] + for doc_id, score in out: + doc = text_map.get(doc_id) or embed_map.get(doc_id) + doc = doc.model_copy() # or copy if you are using Pydantic + doc.distance = float(-score) # so a higher combined => smaller distance + final_docs.append(doc) + return final_docs + + +@beartype +async def search_docs_hybrid( + developer_id: UUID, + text_query: str = "", + embedding: List[float] = None, + k: int = 10, + alpha: float = 0.5, + owner_type: Literal["user", "agent", "org"] | None = None, + owner_id: UUID | None = None, +) -> List[Doc]: + """ + Hybrid text-and-embedding doc search. We get top-K from each approach, + then fuse them client-side. Adjust concurrency or approach as you like. + """ + # We'll dispatch two queries in parallel + # (One full-text, one embedding-based) each limited to K + tasks = [] + if text_query.strip(): + tasks.append( + search_docs_by_text( + developer_id=developer_id, + query=text_query, + k=k, + owner_type=owner_type, + owner_id=owner_id, + ) + ) + else: + tasks.append([]) # no text results if query is empty + + if embedding and any(embedding): + tasks.append( + search_docs_by_embedding( + developer_id=developer_id, + query_embedding=embedding, + k=k, + owner_type=owner_type, + owner_id=owner_id, + ) + ) + else: + tasks.append([]) + + # Run concurrently (or sequentially, if you prefer) + # If you have a 'run_concurrently' from your old code, you can do: + # text_results, embed_results = await run_concurrently([task1, task2]) + # Otherwise just do them in parallel with e.g. asyncio.gather: + from asyncio import gather + + text_results, embed_results = await gather(*tasks) + + # fuse them + fused = fuse_results(text_results, embed_results, alpha) + # Then pick top K overall + return fused[:k] diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index 95973ad0b..d8439fa21 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -94,6 +94,17 @@ async def create_entries( session_id: UUID, data: list[CreateEntryRequest], ) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """ + Create entries in a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + data (list[CreateEntryRequest]): The list of entries to create. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for creating the entries. + """ # Convert the data to a list of dictionaries data_dicts = [item.model_dump(mode="json") for item in data] @@ -163,6 +174,17 @@ async def add_entry_relations( session_id: UUID, data: list[Relation], ) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """ + Add relations between entries in a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + data (list[Relation]): The list of relations to add. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for adding the relations. + """ # Convert the data to a list of dictionaries data_dicts = [item.model_dump(mode="json") for item in data] diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py index 47b7379a4..14a9648e5 100644 --- a/agents-api/agents_api/queries/entries/delete_entries.py +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -134,7 +134,16 @@ async def delete_entries_for_session( async def delete_entries( *, developer_id: UUID, session_id: UUID, entry_ids: list[UUID] ) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: - """Delete specific entries by their IDs.""" + """Delete specific entries by their IDs. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + entry_ids (list[UUID]): The IDs of the entries to delete. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for deleting the entries. + """ return [ ( session_exists_query, diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index e6967a6cc..6a734d4c5 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -1,5 +1,4 @@ import json -from typing import Any, List, Tuple from uuid import UUID import asyncpg @@ -96,6 +95,16 @@ async def get_history( session_id: UUID, allowed_sources: list[str] = ["api_request", "api_response"], ) -> tuple[str, list] | tuple[str, list, str]: + """Get the history of a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + allowed_sources (list[str]): The sources to include in the history. + + Returns: + tuple[str, list] | tuple[str, list, str]: SQL query and parameters for getting the history. + """ return ( history_query, [session_id, allowed_sources, developer_id], diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index 89f432734..0153fe778 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -88,6 +88,21 @@ async def list_entries( direction: Literal["asc", "desc"] = "asc", exclude_relations: list[str] = [], ) -> list[tuple[str, list] | tuple[str, list, str]]: + """List entries in a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + allowed_sources (list[str]): The sources to include in the history. + limit (int): The number of entries to return. + offset (int): The number of entries to skip. + sort_by (Literal["created_at", "timestamp"]): The field to sort by. + direction (Literal["asc", "desc"]): The direction to sort by. + exclude_relations (list[str]): The relations to exclude. + + Returns: + tuple[str, list] | tuple[str, list, str]: SQL query and parameters for listing the entries. + """ if limit < 1 or limit > 1000: raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") if offset < 0: diff --git a/agents-api/agents_api/queries/files/create_file.py b/agents-api/agents_api/queries/files/create_file.py index 48251fa5e..f2e35a6f4 100644 --- a/agents-api/agents_api/queries/files/create_file.py +++ b/agents-api/agents_api/queries/files/create_file.py @@ -60,25 +60,25 @@ # Add error handling decorator -# @rewrap_exceptions( -# { -# asyncpg.UniqueViolationError: partialclass( -# HTTPException, -# status_code=409, -# detail="A file with this name already exists for this developer", -# ), -# asyncpg.NoDataFoundError: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified owner does not exist", -# ), -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="A file with this name already exists for this developer", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="File size must be positive and name must be between 1 and 255 characters", + ), + } +) @wrap_in_class( File, one=True, diff --git a/agents-api/agents_api/queries/files/delete_file.py b/agents-api/agents_api/queries/files/delete_file.py index 31cb43404..4cf0142ae 100644 --- a/agents-api/agents_api/queries/files/delete_file.py +++ b/agents-api/agents_api/queries/files/delete_file.py @@ -48,6 +48,11 @@ status_code=404, detail="File not found", ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/files/get_file.py b/agents-api/agents_api/queries/files/get_file.py index 4d5dca4c0..04ba8ea71 100644 --- a/agents-api/agents_api/queries/files/get_file.py +++ b/agents-api/agents_api/queries/files/get_file.py @@ -12,7 +12,12 @@ from sqlglot import parse_one from ...autogen.openapi_model import File -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ..utils import ( + partialclass, + pg_query, + rewrap_exceptions, + wrap_in_class, +) # Define the raw SQL query file_query = parse_one(""" @@ -29,26 +34,26 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# asyncpg.NoDataFoundError: partialclass( -# HTTPException, -# status_code=404, -# detail="File not found", -# ), -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="Developer not found", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="File not found", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), + } +) @wrap_in_class( File, one=True, transform=lambda d: { - "id": d["file_id"], **d, + "id": d["file_id"], "hash": d["hash"].hex(), "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", }, diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py index 2bc42f842..d3866dacc 100644 --- a/agents-api/agents_api/queries/files/list_files.py +++ b/agents-api/agents_api/queries/files/list_files.py @@ -3,7 +3,7 @@ It constructs and executes SQL queries to fetch a list of files based on developer ID with pagination. """ -from typing import Any, Literal +from typing import Literal from uuid import UUID import asyncpg @@ -14,43 +14,24 @@ from ...autogen.openapi_model import File from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Query to list all files for a developer (uses developer_id index) -developer_files_query = parse_one(""" +# Base query for listing files +base_files_query = parse_one(""" SELECT f.* FROM files f LEFT JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id WHERE f.developer_id = $1 -ORDER BY - CASE - WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at - WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at - WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at - WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at - END DESC NULLS LAST -LIMIT $2 -OFFSET $3; -""").sql(pretty=True) - -# Query to list files for a specific owner (uses composite indexes) -owner_files_query = parse_one(""" -SELECT f.* -FROM files f -JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id -WHERE fo.developer_id = $1 -AND fo.owner_id = $6 -AND fo.owner_type = $7 -ORDER BY - CASE - WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at - WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at - WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at - WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at - END DESC NULLS LAST -LIMIT $2 -OFFSET $3; """).sql(pretty=True) +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), + } +) @wrap_in_class( File, one=False, @@ -74,49 +55,32 @@ async def list_files( direction: Literal["asc", "desc"] = "desc", ) -> tuple[str, list]: """ - Lists files with optimized queries for two cases: - 1. Owner specified: Returns files associated with that owner - 2. No owner: Returns all files for the developer - - Args: - developer_id: UUID of the developer - owner_id: Optional UUID of the owner (user or agent) - owner_type: Optional type of owner ("user" or "agent") - limit: Maximum number of records to return (1-100) - offset: Number of records to skip - sort_by: Field to sort by - direction: Sort direction ('asc' or 'desc') - - Returns: - Tuple of (query, params) - - Raises: - HTTPException: If parameters are invalid + Lists files with optional owner filtering, pagination, and sorting. """ # Validate parameters if direction.lower() not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Invalid sort direction") + if sort_by not in ["created_at", "updated_at"]: + raise HTTPException(status_code=400, detail="Invalid sort field") + if limit > 100 or limit < 1: raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") if offset < 0: raise HTTPException(status_code=400, detail="Offset must be non-negative") - # Base parameters used in all queries - params = [ - developer_id, - limit, - offset, - sort_by, - direction, - ] + # Start with the base query + query = base_files_query + params = [developer_id] + + # Add owner filtering + if owner_type and owner_id: + query += " AND fo.owner_type = $2 AND fo.owner_id = $3" + params.extend([owner_type, owner_id]) - # Choose appropriate query based on owner details - if owner_id and owner_type: - params.extend([owner_id, owner_type]) # Add owner_id as $6 and owner_type as $7 - query = owner_files_query # Use single query with owner_type parameter - else: - query = developer_files_query + # Add sorting and pagination + query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}" + params.extend([limit, offset]) - return (query, params) + return query, params diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py index 3c4dbf66e..b6c280b01 100644 --- a/agents-api/agents_api/queries/sessions/create_or_update_session.py +++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py @@ -70,13 +70,18 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer or participant does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.UniqueViolationError: partialclass( HTTPException, status_code=409, detail="A session with this ID already exists.", ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index 63fbdc940..0bb967ce5 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -8,10 +8,8 @@ from ...autogen.openapi_model import ( CreateSessionRequest, - ResourceCreatedResponse, Session, ) -from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class @@ -60,13 +58,18 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer or participant does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.UniqueViolationError: partialclass( HTTPException, status_code=409, detail="A session with this ID already exists.", ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/sessions/delete_session.py b/agents-api/agents_api/queries/sessions/delete_session.py index 2e3234fe2..ff5317f58 100644 --- a/agents-api/agents_api/queries/sessions/delete_session.py +++ b/agents-api/agents_api/queries/sessions/delete_session.py @@ -30,7 +30,7 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer does not exist.", + detail="The specified developer or session does not exist.", ), } ) diff --git a/agents-api/agents_api/queries/sessions/get_session.py b/agents-api/agents_api/queries/sessions/get_session.py index 1f704539e..cc12d0f88 100644 --- a/agents-api/agents_api/queries/sessions/get_session.py +++ b/agents-api/agents_api/queries/sessions/get_session.py @@ -51,7 +51,7 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.NoDataFoundError: partialclass( HTTPException, status_code=404, detail="Session not found" diff --git a/agents-api/agents_api/queries/sessions/list_sessions.py b/agents-api/agents_api/queries/sessions/list_sessions.py index 3aabaf32d..ac3573e61 100644 --- a/agents-api/agents_api/queries/sessions/list_sessions.py +++ b/agents-api/agents_api/queries/sessions/list_sessions.py @@ -12,7 +12,7 @@ from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query -raw_query = """ +session_query = """ WITH session_participants AS ( SELECT sl.session_id, @@ -49,10 +49,6 @@ LIMIT $2 OFFSET $6; """ -# Parse and optimize the query -# query = parse_one(raw_query).sql(pretty=True) -query = raw_query - @rewrap_exceptions( { @@ -62,7 +58,14 @@ detail="The specified developer does not exist.", ), asyncpg.NoDataFoundError: partialclass( - HTTPException, status_code=404, detail="No sessions found" + HTTPException, + status_code=404, + detail="No sessions found", + ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", ), } ) @@ -94,7 +97,7 @@ async def list_sessions( tuple[str, list]: SQL query and parameters """ return ( - query, + session_query, [ developer_id, # $1 limit, # $2 diff --git a/agents-api/agents_api/queries/sessions/patch_session.py b/agents-api/agents_api/queries/sessions/patch_session.py index 7d526ae1a..d7533e124 100644 --- a/agents-api/agents_api/queries/sessions/patch_session.py +++ b/agents-api/agents_api/queries/sessions/patch_session.py @@ -37,13 +37,18 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer or participant does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.NoDataFoundError: partialclass( HTTPException, status_code=404, detail="Session not found", ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/sessions/update_session.py b/agents-api/agents_api/queries/sessions/update_session.py index 7c58d10e6..e3f46c0af 100644 --- a/agents-api/agents_api/queries/sessions/update_session.py +++ b/agents-api/agents_api/queries/sessions/update_session.py @@ -33,13 +33,18 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer or participant does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.NoDataFoundError: partialclass( HTTPException, status_code=404, detail="Session not found", ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/users/create_or_update_user.py b/agents-api/agents_api/queries/users/create_or_update_user.py index 965ae4ce4..0a2936a9b 100644 --- a/agents-api/agents_api/queries/users/create_or_update_user.py +++ b/agents-api/agents_api/queries/users/create_or_update_user.py @@ -40,10 +40,10 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.UniqueViolationError: partialclass( # Add handling for potential race conditions + asyncpg.UniqueViolationError: partialclass( HTTPException, status_code=409, - detail="A user with this ID already exists.", + detail="A user with this ID already exists for the specified developer.", ), } ) diff --git a/agents-api/agents_api/queries/users/create_user.py b/agents-api/agents_api/queries/users/create_user.py index 8f35a646c..e246c7255 100644 --- a/agents-api/agents_api/queries/users/create_user.py +++ b/agents-api/agents_api/queries/users/create_user.py @@ -37,10 +37,10 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.NullValueNoIndicatorParameterError: partialclass( + asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified developer does not exist.", + status_code=409, + detail="A user with this ID already exists for the specified developer.", ), } ) diff --git a/agents-api/agents_api/queries/users/delete_user.py b/agents-api/agents_api/queries/users/delete_user.py index ad5befd73..6b8497980 100644 --- a/agents-api/agents_api/queries/users/delete_user.py +++ b/agents-api/agents_api/queries/users/delete_user.py @@ -56,7 +56,7 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.UniqueViolationError: partialclass( + asyncpg.DataError: partialclass( HTTPException, status_code=404, detail="The specified user does not exist.", diff --git a/agents-api/agents_api/queries/users/get_user.py b/agents-api/agents_api/queries/users/get_user.py index 2b71f9192..07a840621 100644 --- a/agents-api/agents_api/queries/users/get_user.py +++ b/agents-api/agents_api/queries/users/get_user.py @@ -31,11 +31,6 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.UniqueViolationError: partialclass( - HTTPException, - status_code=404, - detail="The specified user does not exist.", - ), } ) @wrap_in_class(User, one=True) diff --git a/agents-api/agents_api/queries/users/list_users.py b/agents-api/agents_api/queries/users/list_users.py index 0f0818135..75fd62b4b 100644 --- a/agents-api/agents_api/queries/users/list_users.py +++ b/agents-api/agents_api/queries/users/list_users.py @@ -42,11 +42,6 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.UniqueViolationError: partialclass( - HTTPException, - status_code=404, - detail="The specified user does not exist.", - ), } ) @wrap_in_class(User) diff --git a/agents-api/agents_api/queries/users/patch_user.py b/agents-api/agents_api/queries/users/patch_user.py index c55ee31b7..fb2d8bfad 100644 --- a/agents-api/agents_api/queries/users/patch_user.py +++ b/agents-api/agents_api/queries/users/patch_user.py @@ -47,8 +47,8 @@ ), asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified user does not exist.", + status_code=409, + detail="A user with this ID already exists for the specified developer.", ), } ) diff --git a/agents-api/agents_api/queries/users/update_user.py b/agents-api/agents_api/queries/users/update_user.py index 91572e15d..975dc57c7 100644 --- a/agents-api/agents_api/queries/users/update_user.py +++ b/agents-api/agents_api/queries/users/update_user.py @@ -31,8 +31,8 @@ ), asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified user does not exist.", + status_code=409, + detail="A user with this ID already exists for the specified developer.", ), } ) diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 286fd10fb..2ad6bfeeb 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,6 +1,5 @@ import random import string -import time from uuid import UUID from fastapi.testclient import TestClient @@ -9,6 +8,7 @@ from agents_api.autogen.openapi_model import ( CreateAgentRequest, + CreateDocRequest, CreateFileRequest, CreateSessionRequest, CreateUserRequest, @@ -20,8 +20,8 @@ # from agents_api.queries.agents.delete_agent import delete_agent from agents_api.queries.developers.get_developer import get_developer +from agents_api.queries.docs.create_doc import create_doc -# from agents_api.queries.docs.create_doc import create_doc # from agents_api.queries.docs.delete_doc import delete_doc # from agents_api.queries.execution.create_execution import create_execution # from agents_api.queries.execution.create_execution_transition import create_execution_transition @@ -64,22 +64,6 @@ def test_developer_id(): return developer_id -# @fixture(scope="global") -# async def test_file(dsn=pg_dsn, developer_id=test_developer_id): -# async with get_pg_client(dsn=dsn) as client: -# file = await create_file( -# developer_id=developer_id, -# data=CreateFileRequest( -# name="Hello", -# description="World", -# mime_type="text/plain", -# content="eyJzYW1wbGUiOiAidGVzdCJ9", -# ), -# client=client, -# ) -# yield file - - @fixture(scope="global") async def test_developer(dsn=pg_dsn, developer_id=test_developer_id): pool = await create_db_pool(dsn=dsn) @@ -149,6 +133,24 @@ async def test_file(dsn=pg_dsn, developer=test_developer, user=test_user): return file +@fixture(scope="test") +async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + doc = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Hello", + content=["World", "World2", "World3"], + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + return doc + + @fixture(scope="test") async def random_email(): return f"{"".join([random.choice(string.ascii_lowercase) for _ in range(10)])}@mail.com" diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index f2ff2c786..82490cb77 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -1,163 +1,248 @@ -# # Tests for entry queries +from ward import test -# import asyncio +from agents_api.autogen.openapi_model import CreateDocRequest +from agents_api.clients.pg import create_db_pool +from agents_api.queries.docs.create_doc import create_doc +from agents_api.queries.docs.delete_doc import delete_doc +from agents_api.queries.docs.get_doc import get_doc +from agents_api.queries.docs.list_docs import list_docs -# from ward import test +# If you wish to test text/embedding/hybrid search, import them: +from agents_api.queries.docs.search_docs_by_text import search_docs_by_text -# from agents_api.autogen.openapi_model import CreateDocRequest -# from agents_api.queries.docs.create_doc import create_doc -# from agents_api.queries.docs.delete_doc import delete_doc -# from agents_api.queries.docs.embed_snippets import embed_snippets -# from agents_api.queries.docs.get_doc import get_doc -# from agents_api.queries.docs.list_docs import list_docs # from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding -# from agents_api.queries.docs.search_docs_by_text import search_docs_by_text -# from tests.fixtures import ( -# EMBEDDING_SIZE, -# cozo_client, -# test_agent, -# test_developer_id, -# test_doc, -# test_user, -# ) - - -# @test("query: create docs") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user -# ): -# create_doc( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# data=CreateDocRequest(title="Hello", content=["World"]), -# client=client, -# ) - -# create_doc( -# developer_id=developer_id, -# owner_type="user", -# owner_id=user.id, -# data=CreateDocRequest(title="Hello", content=["World"]), -# client=client, -# ) - - -# @test("query: get docs") -# def _(client=cozo_client, doc=test_doc, developer_id=test_developer_id): -# get_doc( -# developer_id=developer_id, -# doc_id=doc.id, -# client=client, -# ) - - -# @test("query: delete doc") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# doc = create_doc( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# data=CreateDocRequest(title="Hello", content=["World"]), -# client=client, -# ) - -# delete_doc( -# developer_id=developer_id, -# doc_id=doc.id, -# owner_type="agent", -# owner_id=agent.id, -# client=client, -# ) - - -# @test("query: list docs") -# def _( -# client=cozo_client, developer_id=test_developer_id, doc=test_doc, agent=test_agent -# ): -# result = list_docs( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# client=client, -# include_without_embeddings=True, -# ) - -# assert len(result) >= 1 - - -# @test("query: search docs by text") -# async def _(client=cozo_client, agent=test_agent, developer_id=test_developer_id): -# create_doc( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# data=CreateDocRequest( -# title="Hello", content=["The world is a funny little thing"] -# ), -# client=client, -# ) - -# await asyncio.sleep(1) - -# result = search_docs_by_text( -# developer_id=developer_id, -# owners=[("agent", agent.id)], -# query="funny", -# client=client, -# ) - -# assert len(result) >= 1 -# assert result[0].metadata is not None - - -# @test("query: search docs by embedding") -# async def _(client=cozo_client, agent=test_agent, developer_id=test_developer_id): -# doc = create_doc( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# data=CreateDocRequest(title="Hello", content=["World"]), -# client=client, -# ) - -# ### Add embedding to the snippet -# embed_snippets( -# developer_id=developer_id, -# doc_id=doc.id, -# snippet_indices=[0], -# embeddings=[[1.0] * EMBEDDING_SIZE], -# client=client, -# ) - -# await asyncio.sleep(1) - -# ### Search -# query_embedding = [0.99] * EMBEDDING_SIZE - -# result = search_docs_by_embedding( -# developer_id=developer_id, -# owners=[("agent", agent.id)], -# query_embedding=query_embedding, -# client=client, -# ) - -# assert len(result) >= 1 -# assert result[0].metadata is not None - - -# @test("query: embed snippets") -# def _(client=cozo_client, developer_id=test_developer_id, doc=test_doc): -# snippet_indices = [0] -# embeddings = [[1.0] * EMBEDDING_SIZE] - -# result = embed_snippets( -# developer_id=developer_id, -# doc_id=doc.id, -# snippet_indices=snippet_indices, -# embeddings=embeddings, -# client=client, -# ) - -# assert result is not None -# assert result.id == doc.id +# from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid +# You can rename or remove these imports to match your actual fixtures +from tests.fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user + + +@test("query: create user doc") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + doc = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="User Doc", + content="Docs for user testing", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert doc.title == "User Doc" + + # Verify doc appears in user's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert any(d.id == doc.id for d in docs_list) + + +@test("query: create agent doc") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + doc = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Agent Doc", + content="Docs for agent testing", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert doc.title == "Agent Doc" + + # Verify doc appears in agent's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert any(d.id == doc.id for d in docs_list) + + +@test("query: get doc") +async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): + pool = await create_db_pool(dsn=dsn) + doc_test = await get_doc( + developer_id=developer.id, + doc_id=doc.id, + connection_pool=pool, + ) + assert doc_test.id == doc.id + assert doc_test.title == doc.title + assert doc_test.content == doc.content + + +@test("query: list user docs") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the user + doc_user = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="User List Test", + content="Some user doc content", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # List user's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert len(docs_list) >= 1 + assert any(d.id == doc_user.id for d in docs_list) + + +@test("query: list agent docs") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the agent + doc_agent = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Agent List Test", + content="Some agent doc content", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # List agent's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert len(docs_list) >= 1 + assert any(d.id == doc_agent.id for d in docs_list) + + +@test("query: delete user doc") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the user + doc_user = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="User Delete Test", + content="Doc for user deletion test", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Delete the doc + await delete_doc( + developer_id=developer.id, + doc_id=doc_user.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Verify doc is no longer in user's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert not any(d.id == doc_user.id for d in docs_list) + + +@test("query: delete agent doc") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the agent + doc_agent = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Agent Delete Test", + content="Doc for agent deletion test", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Delete the doc + await delete_doc( + developer_id=developer.id, + doc_id=doc_agent.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Verify doc is no longer in agent's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert not any(d.id == doc_agent.id for d in docs_list) + + +@test("query: search docs by text") +async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + + # Create a test document + await create_doc( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + data=CreateDocRequest( + title="Hello", + content="The world is a funny little thing", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + connection_pool=pool, + ) + + # Search using the correct parameter types + result = await search_docs_by_text( + developer_id=developer.id, + owners=[("agent", agent.id)], + query="funny", + k=3, # Add k parameter + search_language="english", # Add language parameter + metadata_filter={}, # Add metadata filter + connection_pool=pool, + ) + + assert len(result) >= 1 + assert result[0].metadata is not None diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 706185c7b..ae825ed92 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,8 +3,6 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. """ -from uuid import UUID - from fastapi import HTTPException from uuid_extensions import uuid7 from ward import raises, test diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index 92b52d733..68409ef5c 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -1,9 +1,7 @@ # # Tests for entry queries -from fastapi import HTTPException -from uuid_extensions import uuid7 -from ward import raises, test +from ward import test from agents_api.autogen.openapi_model import CreateFileRequest from agents_api.clients.pg import create_db_pool @@ -84,7 +82,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): assert any(f.id == file.id for f in files) -@test("model: get file") +@test("query: get file") async def _(dsn=pg_dsn, file=test_file, developer=test_developer): pool = await create_db_pool(dsn=dsn) file_test = await get_file( diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 171e56aa8..4673d6fc5 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -10,7 +10,6 @@ CreateOrUpdateSessionRequest, CreateSessionRequest, PatchSessionRequest, - ResourceCreatedResponse, ResourceDeletedResponse, ResourceUpdatedResponse, Session, diff --git a/integrations-service/integrations/autogen/Docs.py b/integrations-service/integrations/autogen/Docs.py index ffed27c1d..af5f60d6a 100644 --- a/integrations-service/integrations/autogen/Docs.py +++ b/integrations-service/integrations/autogen/Docs.py @@ -73,6 +73,30 @@ class Doc(BaseModel): """ Embeddings for the document """ + modality: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Modality of the document + """ + language: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Language of the document + """ + index: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Index of the document + """ + embedding_model: Annotated[ + str | None, Field(json_schema_extra={"readOnly": True}) + ] = None + """ + Embedding model to use for the document + """ + embedding_dimensions: Annotated[ + int | None, Field(json_schema_extra={"readOnly": True}) + ] = None + """ + Dimensions of the embedding model + """ class DocOwner(BaseModel): diff --git a/memory-store/migrations/000006_docs.up.sql b/memory-store/migrations/000006_docs.up.sql index 193fae122..97bdad43c 100644 --- a/memory-store/migrations/000006_docs.up.sql +++ b/memory-store/migrations/000006_docs.up.sql @@ -24,8 +24,7 @@ CREATE TABLE IF NOT EXISTS docs ( created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, metadata JSONB NOT NULL DEFAULT '{}'::JSONB, - CONSTRAINT pk_docs PRIMARY KEY (developer_id, doc_id), - CONSTRAINT uq_docs_doc_id_index UNIQUE (doc_id, index), + CONSTRAINT pk_docs PRIMARY KEY (developer_id, doc_id, index), CONSTRAINT ct_docs_embedding_dimensions_positive CHECK (embedding_dimensions > 0), CONSTRAINT ct_docs_valid_modality CHECK (modality IN ('text', 'image', 'mixed')), CONSTRAINT ct_docs_index_positive CHECK (index >= 0), @@ -67,10 +66,12 @@ END $$; CREATE TABLE IF NOT EXISTS doc_owners ( developer_id UUID NOT NULL, doc_id UUID NOT NULL, + index INTEGER NOT NULL, owner_type TEXT NOT NULL, -- 'user' or 'agent' owner_id UUID NOT NULL, - CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id), - CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id), + CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id, index), + -- TODO: Add foreign key constraint + -- CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id), CONSTRAINT ct_doc_owners_owner_type CHECK (owner_type IN ('user', 'agent')) ); diff --git a/memory-store/migrations/000018_doc_search.up.sql b/memory-store/migrations/000018_doc_search.up.sql index 5293cc81a..2f5b2baf1 100644 --- a/memory-store/migrations/000018_doc_search.up.sql +++ b/memory-store/migrations/000018_doc_search.up.sql @@ -101,6 +101,7 @@ END $$; -- Create the search function CREATE OR REPLACE FUNCTION search_by_vector ( + developer_id UUID, query_embedding vector (1024), owner_types TEXT[], owner_ids UUID [], @@ -134,9 +135,7 @@ BEGIN IF owner_types IS NOT NULL AND owner_ids IS NOT NULL THEN owner_filter_sql := ' AND ( - (ud.user_id = ANY($5) AND ''user'' = ANY($4)) - OR - (ad.agent_id = ANY($5) AND ''agent'' = ANY($4)) + doc_owners.owner_id = ANY($5::uuid[]) AND doc_owners.owner_type = ANY($4::text[]) )'; ELSE owner_filter_sql := ''; @@ -153,6 +152,7 @@ BEGIN RETURN QUERY EXECUTE format( 'WITH ranked_docs AS ( SELECT + d.developer_id, d.doc_id, d.index, d.title, @@ -160,15 +160,12 @@ BEGIN (1 - (d.embedding <=> $1)) as distance, d.embedding, d.metadata, - CASE - WHEN ud.user_id IS NOT NULL THEN ''user'' - WHEN ad.agent_id IS NOT NULL THEN ''agent'' - END as owner_type, - COALESCE(ud.user_id, ad.agent_id) as owner_id + doc_owners.owner_type, + doc_owners.owner_id FROM docs_embeddings d - LEFT JOIN user_docs ud ON d.doc_id = ud.doc_id - LEFT JOIN agent_docs ad ON d.doc_id = ad.doc_id - WHERE 1 - (d.embedding <=> $1) >= $2 + LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id + WHERE d.developer_id = $7 + AND 1 - (d.embedding <=> $1) >= $2 %s %s ) @@ -185,7 +182,9 @@ BEGIN k, owner_types, owner_ids, - metadata_filter; + metadata_filter, + developer_id; + END; $$; @@ -238,6 +237,7 @@ COMMENT ON FUNCTION embed_and_search_by_vector IS 'Convenience function that com -- Create the text search function CREATE OR REPLACE FUNCTION search_by_text ( + developer_id UUID, query_text text, owner_types TEXT[], owner_ids UUID [], @@ -267,9 +267,7 @@ BEGIN IF owner_types IS NOT NULL AND owner_ids IS NOT NULL THEN owner_filter_sql := ' AND ( - (ud.user_id = ANY($5) AND ''user'' = ANY($4)) - OR - (ad.agent_id = ANY($5) AND ''agent'' = ANY($4)) + doc_owners.owner_id = ANY($5::uuid[]) AND doc_owners.owner_type = ANY($4::text[]) )'; ELSE owner_filter_sql := ''; @@ -286,6 +284,7 @@ BEGIN RETURN QUERY EXECUTE format( 'WITH ranked_docs AS ( SELECT + d.developer_id, d.doc_id, d.index, d.title, @@ -293,15 +292,12 @@ BEGIN ts_rank_cd(d.search_tsv, $1, 32)::double precision as distance, d.embedding, d.metadata, - CASE - WHEN ud.user_id IS NOT NULL THEN ''user'' - WHEN ad.agent_id IS NOT NULL THEN ''agent'' - END as owner_type, - COALESCE(ud.user_id, ad.agent_id) as owner_id + doc_owners.owner_type, + doc_owners.owner_id FROM docs_embeddings d - LEFT JOIN user_docs ud ON d.doc_id = ud.doc_id - LEFT JOIN agent_docs ad ON d.doc_id = ad.doc_id - WHERE d.search_tsv @@ $1 + LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id + WHERE d.developer_id = $6 + AND d.search_tsv @@ $1 %s %s ) @@ -314,11 +310,11 @@ BEGIN ) USING ts_query, - search_language, k, owner_types, owner_ids, - metadata_filter; + metadata_filter, + developer_id; END; $$; @@ -372,6 +368,7 @@ $$ LANGUAGE plpgsql; -- Hybrid search function combining text and vector search CREATE OR REPLACE FUNCTION search_hybrid ( + developer_id UUID, query_text text, query_embedding vector (1024), owner_types TEXT[], @@ -397,6 +394,7 @@ BEGIN RETURN QUERY WITH text_results AS ( SELECT * FROM search_by_text( + developer_id, query_text, owner_types, owner_ids, @@ -407,6 +405,7 @@ BEGIN ), embedding_results AS ( SELECT * FROM search_by_vector( + developer_id, query_embedding, owner_types, owner_ids, @@ -426,6 +425,7 @@ BEGIN ), scores AS ( SELECT + r.developer_id, r.doc_id, r.title, r.content, @@ -437,8 +437,8 @@ BEGIN COALESCE(t.distance, 0.0) as text_score, COALESCE(e.distance, 0.0) as embedding_score FROM all_results r - LEFT JOIN text_results t ON r.doc_id = t.doc_id - LEFT JOIN embedding_results e ON r.doc_id = e.doc_id + LEFT JOIN text_results t ON r.doc_id = t.doc_id AND r.developer_id = t.developer_id + LEFT JOIN embedding_results e ON r.doc_id = e.doc_id AND r.developer_id = e.developer_id ), normalized_scores AS ( SELECT @@ -448,6 +448,7 @@ BEGIN FROM scores ) SELECT + developer_id, doc_id, index, title, @@ -468,6 +469,7 @@ COMMENT ON FUNCTION search_hybrid IS 'Hybrid search combining text and vector se -- Convenience function that handles embedding generation CREATE OR REPLACE FUNCTION embed_and_search_hybrid ( + developer_id UUID, query_text text, owner_types TEXT[], owner_ids UUID [], @@ -497,6 +499,7 @@ BEGIN -- Perform hybrid search RETURN QUERY SELECT * FROM search_hybrid( + developer_id, query_text, query_embedding, owner_types, diff --git a/typespec/docs/models.tsp b/typespec/docs/models.tsp index 055fc2003..f4d16cbd5 100644 --- a/typespec/docs/models.tsp +++ b/typespec/docs/models.tsp @@ -27,6 +27,26 @@ model Doc { /** Embeddings for the document */ @visibility("read") embeddings?: float32[] | float32[][]; + + @visibility("read") + /** Modality of the document */ + modality?: string; + + @visibility("read") + /** Language of the document */ + language?: string; + + @visibility("read") + /** Index of the document */ + index?: uint16; + + @visibility("read") + /** Embedding model to use for the document */ + embedding_model?: string; + + @visibility("read") + /** Dimensions of the embedding model */ + embedding_dimensions?: uint16; } /** Payload for creating a doc */ diff --git a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml index d4835a695..c19bc4ed2 100644 --- a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml +++ b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml @@ -2876,6 +2876,28 @@ components: format: float description: Embeddings for the document readOnly: true + modality: + type: string + description: Modality of the document + readOnly: true + language: + type: string + description: Language of the document + readOnly: true + index: + type: integer + format: uint16 + description: Index of the document + readOnly: true + embedding_model: + type: string + description: Embedding model to use for the document + readOnly: true + embedding_dimensions: + type: integer + format: uint16 + description: Dimensions of the embedding model + readOnly: true Docs.DocOwner: type: object required: