From 6c77490b60286343809faa91be80339bee6b6fc1 Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Thu, 19 Dec 2024 20:24:28 -0500 Subject: [PATCH 01/10] wip(agents-api): Doc queries --- .../agents_api/queries/docs/__init__.py | 25 +++ .../agents_api/queries/docs/create_doc.py | 135 +++++++++++++++ .../agents_api/queries/docs/delete_doc.py | 77 +++++++++ .../agents_api/queries/docs/embed_snippets.py | 0 agents-api/agents_api/queries/docs/get_doc.py | 52 ++++++ .../agents_api/queries/docs/list_docs.py | 91 ++++++++++ agents-api/agents_api/queries/docs/mmr.py | 109 ++++++++++++ .../queries/docs/search_docs_by_embedding.py | 70 ++++++++ .../queries/docs/search_docs_by_text.py | 65 +++++++ .../queries/docs/search_docs_hybrid.py | 159 ++++++++++++++++++ 10 files changed, 783 insertions(+) create mode 100644 agents-api/agents_api/queries/docs/__init__.py create mode 100644 agents-api/agents_api/queries/docs/create_doc.py create mode 100644 agents-api/agents_api/queries/docs/delete_doc.py create mode 100644 agents-api/agents_api/queries/docs/embed_snippets.py create mode 100644 agents-api/agents_api/queries/docs/get_doc.py create mode 100644 agents-api/agents_api/queries/docs/list_docs.py create mode 100644 agents-api/agents_api/queries/docs/mmr.py create mode 100644 agents-api/agents_api/queries/docs/search_docs_by_embedding.py create mode 100644 agents-api/agents_api/queries/docs/search_docs_by_text.py create mode 100644 agents-api/agents_api/queries/docs/search_docs_hybrid.py diff --git a/agents-api/agents_api/queries/docs/__init__.py b/agents-api/agents_api/queries/docs/__init__.py new file mode 100644 index 000000000..0ba3db0d4 --- /dev/null +++ b/agents-api/agents_api/queries/docs/__init__.py @@ -0,0 +1,25 @@ +""" +Module: agents_api/models/docs + +This module is responsible for managing document-related operations within the application, particularly for agents and possibly other entities. It serves as a core component of the document management system, enabling features such as document creation, listing, deletion, and embedding of snippets for enhanced search and retrieval capabilities. + +Main functionalities include: +- Creating new documents and associating them with agents or users. +- Listing documents based on various criteria, including ownership and metadata filters. +- Deleting documents by their unique identifiers. +- Embedding document snippets for retrieval purposes. + +The module interacts with other parts of the application, such as the agents and users modules, to provide a comprehensive document management system. Its role is crucial in enabling document search, retrieval, and management features within the context of agents and users. + +This documentation aims to provide clear, concise, and sufficient context for new developers or contributors to understand the module's role without needing to dive deep into the code immediately. +""" + +# ruff: noqa: F401, F403, F405 + +from .create_doc import create_doc +from .delete_doc import delete_doc +from .embed_snippets import embed_snippets +from .get_doc import get_doc +from .list_docs import list_docs +from .search_docs_by_embedding import search_docs_by_embedding +from .search_docs_by_text import search_docs_by_text diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py new file mode 100644 index 000000000..57be43bdf --- /dev/null +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -0,0 +1,135 @@ +""" +Timescale-based creation of docs. 
+ +Mirrors the structure of create_file.py, but uses the docs/doc_owners tables. +""" + +import base64 +import hashlib +from typing import Any, Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateDocRequest, Doc +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Base INSERT for docs +doc_query = parse_one(""" +INSERT INTO docs ( + developer_id, + doc_id, + title, + content, + index, + modality, + embedding_model, + embedding_dimensions, + language, + metadata +) +VALUES ( + $1, -- developer_id + $2, -- doc_id + $3, -- title + $4, -- content + $5, -- index + $6, -- modality + $7, -- embedding_model + $8, -- embedding_dimensions + $9, -- language + $10 -- metadata (JSONB) +) +RETURNING *; +""").sql(pretty=True) + +# Owner association query for doc_owners +doc_owner_query = parse_one(""" +WITH inserted_owner AS ( + INSERT INTO doc_owners ( + developer_id, + doc_id, + owner_type, + owner_id + ) + VALUES ($1, $2, $3, $4) + RETURNING doc_id +) +SELECT d.* +FROM inserted_owner io +JOIN docs d ON d.doc_id = io.doc_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="A document with this ID already exists for this developer", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="The specified owner does not exist", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="Developer or doc owner not found", + ), + } +) +@wrap_in_class( + Doc, + one=True, + transform=lambda d: { + **d, + "id": d["doc_id"], + # You could optionally return a computed hash or partial content if desired + }, +) +@increase_counter("create_doc") +@pg_query +@beartype +async def create_doc( + *, + developer_id: UUID, + doc_id: UUID | None = None, + data: CreateDocRequest, + owner_type: Literal["user", "agent", "org"] | None = None, + owner_id: UUID | None = None, +) -> list[tuple[str, list]]: + """ + Insert a new doc record into Timescale and optionally associate it with an owner. + """ + # Generate a UUID if not provided + doc_id = doc_id or uuid7() + + # Create the doc record + doc_params = [ + developer_id, + doc_id, + data.title, + data.content, + data.index or 0, # fallback if no snippet index + data.modality or "text", + data.embedding_model or "none", + data.embedding_dimensions or 0, + data.language or "english", + data.metadata or {}, + ] + + queries = [(doc_query, doc_params)] + + # If an owner is specified, associate it: + if owner_type and owner_id: + owner_params = [developer_id, doc_id, owner_type, owner_id] + queries.append((doc_owner_query, owner_params)) + + return queries diff --git a/agents-api/agents_api/queries/docs/delete_doc.py b/agents-api/agents_api/queries/docs/delete_doc.py new file mode 100644 index 000000000..d1e02faf1 --- /dev/null +++ b/agents-api/agents_api/queries/docs/delete_doc.py @@ -0,0 +1,77 @@ +""" +Timescale-based deletion of a doc record. 
+""" +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Delete doc query + ownership check +delete_doc_query = parse_one(""" +WITH deleted_owners AS ( + DELETE FROM doc_owners + WHERE developer_id = $1 + AND doc_id = $2 + AND ( + ($3::text IS NULL AND $4::uuid IS NULL) + OR (owner_type = $3 AND owner_id = $4) + ) +) +DELETE FROM docs +WHERE developer_id = $1 + AND doc_id = $2 + AND ( + $3::text IS NULL OR EXISTS ( + SELECT 1 FROM doc_owners + WHERE developer_id = $1 + AND doc_id = $2 + AND owner_type = $3 + AND owner_id = $4 + ) + ) +RETURNING doc_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Doc not found", + ) + } +) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["doc_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@pg_query +@beartype +async def delete_doc( + *, + developer_id: UUID, + doc_id: UUID, + owner_type: Literal["user", "agent", "org"] | None = None, + owner_id: UUID | None = None, +) -> tuple[str, list]: + """ + Deletes a doc (and associated doc_owners) for the given developer and doc_id. + If owner_type/owner_id is specified, only remove doc if that matches. + """ + return ( + delete_doc_query, + [developer_id, doc_id, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/docs/embed_snippets.py b/agents-api/agents_api/queries/docs/embed_snippets.py new file mode 100644 index 000000000..e69de29bb diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py new file mode 100644 index 000000000..a0345f5e3 --- /dev/null +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -0,0 +1,52 @@ +""" +Timescale-based retrieval of a single doc record. +""" +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import Doc +from ..utils import pg_query, wrap_in_class + +doc_query = parse_one(""" +SELECT d.* +FROM docs d +LEFT JOIN doc_owners do ON d.developer_id = do.developer_id AND d.doc_id = do.doc_id +WHERE d.developer_id = $1 + AND d.doc_id = $2 + AND ( + ($3::text IS NULL AND $4::uuid IS NULL) + OR (do.owner_type = $3 AND do.owner_id = $4) + ) +LIMIT 1; +""").sql(pretty=True) + + +@wrap_in_class( + Doc, + one=True, + transform=lambda d: { + **d, + "id": d["doc_id"], + }, +) +@pg_query +@beartype +async def get_doc( + *, + developer_id: UUID, + doc_id: UUID, + owner_type: Literal["user", "agent", "org"] | None = None, + owner_id: UUID | None = None +) -> tuple[str, list]: + """ + Fetch a single doc, optionally constrained to a given owner. + """ + return ( + doc_query, + [developer_id, doc_id, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py new file mode 100644 index 000000000..b145a1cbc --- /dev/null +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -0,0 +1,91 @@ +""" +Timescale-based listing of docs with optional owner filter and pagination. 
+""" +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import Doc +from ..utils import pg_query, wrap_in_class + +# Basic listing for all docs by developer +developer_docs_query = parse_one(""" +SELECT d.* +FROM docs d +LEFT JOIN doc_owners do ON d.developer_id = do.developer_id AND d.doc_id = do.doc_id +WHERE d.developer_id = $1 +ORDER BY +CASE + WHEN $4 = 'created_at' AND $5 = 'asc' THEN d.created_at + WHEN $4 = 'created_at' AND $5 = 'desc' THEN d.created_at + WHEN $4 = 'updated_at' AND $5 = 'asc' THEN d.updated_at + WHEN $4 = 'updated_at' AND $5 = 'desc' THEN d.updated_at +END DESC NULLS LAST +LIMIT $2 +OFFSET $3; +""").sql(pretty=True) + +# Listing for docs associated with a specific owner +owner_docs_query = parse_one(""" +SELECT d.* +FROM docs d +JOIN doc_owners do ON d.developer_id = do.developer_id AND d.doc_id = do.doc_id +WHERE do.developer_id = $1 + AND do.owner_id = $6 + AND do.owner_type = $7 +ORDER BY +CASE + WHEN $4 = 'created_at' AND $5 = 'asc' THEN d.created_at + WHEN $4 = 'created_at' AND $5 = 'desc' THEN d.created_at + WHEN $4 = 'updated_at' AND $5 = 'asc' THEN d.updated_at + WHEN $4 = 'updated_at' AND $5 = 'desc' THEN d.updated_at +END DESC NULLS LAST +LIMIT $2 +OFFSET $3; +""").sql(pretty=True) + + +@wrap_in_class( + Doc, + one=False, + transform=lambda d: { + **d, + "id": d["doc_id"], + }, +) +@pg_query +@beartype +async def list_docs( + *, + developer_id: UUID, + owner_id: UUID | None = None, + owner_type: Literal["user", "agent", "org"] | None = None, + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", +) -> tuple[str, list]: + """ + Lists docs with optional owner filtering, pagination, and sorting. + """ + if direction.lower() not in ["asc", "desc"]: + raise HTTPException(status_code=400, detail="Invalid sort direction") + + if limit > 100 or limit < 1: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") + + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be >= 0") + + params = [developer_id, limit, offset, sort_by, direction] + if owner_id and owner_type: + params.extend([owner_id, owner_type]) + query = owner_docs_query + else: + query = developer_docs_query + + return (query, params) diff --git a/agents-api/agents_api/queries/docs/mmr.py b/agents-api/agents_api/queries/docs/mmr.py new file mode 100644 index 000000000..d214e8c04 --- /dev/null +++ b/agents-api/agents_api/queries/docs/mmr.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import logging +from typing import Union + +import numpy as np + +Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray] + +logger = logging.getLogger(__name__) + + +def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: + """Row-wise cosine similarity between two equal-width matrices. + + Args: + x: A matrix of shape (n, m). + y: A matrix of shape (k, m). + + Returns: + A matrix of shape (n, k) where each element (i, j) is the cosine similarity + between the ith row of X and the jth row of Y. + + Raises: + ValueError: If the number of columns in X and Y are not the same. + ImportError: If numpy is not installed. 
+ """ + + if len(x) == 0 or len(y) == 0: + return np.array([]) + + x = [xx for xx in x if xx is not None] + y = [yy for yy in y if yy is not None] + + x = np.array(x) + y = np.array(y) + if x.shape[1] != y.shape[1]: + msg = ( + f"Number of columns in X and Y must be the same. X has shape {x.shape} " + f"and Y has shape {y.shape}." + ) + raise ValueError(msg) + try: + import simsimd as simd # type: ignore + + x = np.array(x, dtype=np.float32) + y = np.array(y, dtype=np.float32) + z = 1 - np.array(simd.cdist(x, y, metric="cosine")) + return z + except ImportError: + logger.debug( + "Unable to import simsimd, defaulting to NumPy implementation. If you want " + "to use simsimd please install with `pip install simsimd`." + ) + x_norm = np.linalg.norm(x, axis=1) + y_norm = np.linalg.norm(y, axis=1) + # Ignore divide by zero errors run time warnings as those are handled below. + with np.errstate(divide="ignore", invalid="ignore"): + similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm) + similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 + return similarity + + +def maximal_marginal_relevance( + query_embedding: np.ndarray, + embedding_list: list, + lambda_mult: float = 0.5, + k: int = 4, +) -> list[int]: + """Calculate maximal marginal relevance. + + Args: + query_embedding: The query embedding. + embedding_list: A list of embeddings. + lambda_mult: The lambda parameter for MMR. Default is 0.5. + k: The number of embeddings to return. Default is 4. + + Returns: + A list of indices of the embeddings to return. + + Raises: + ImportError: If numpy is not installed. + """ + + if min(k, len(embedding_list)) <= 0: + return [] + if query_embedding.ndim == 1: + query_embedding = np.expand_dims(query_embedding, axis=0) + similarity_to_query = _cosine_similarity(query_embedding, embedding_list)[0] + most_similar = int(np.argmax(similarity_to_query)) + idxs = [most_similar] + selected = np.array([embedding_list[most_similar]]) + while len(idxs) < min(k, len(embedding_list)): + best_score = -np.inf + idx_to_add = -1 + similarity_to_selected = _cosine_similarity(embedding_list, selected) + for i, query_score in enumerate(similarity_to_query): + if i in idxs: + continue + redundant_score = max(similarity_to_selected[i]) + equation_score = ( + lambda_mult * query_score - (1 - lambda_mult) * redundant_score + ) + if equation_score > best_score: + best_score = equation_score + idx_to_add = i + idxs.append(idx_to_add) + selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) + return idxs diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py new file mode 100644 index 000000000..c62188b61 --- /dev/null +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -0,0 +1,70 @@ +""" +Timescale-based doc embedding search using the `embedding` column. +""" + +import asyncpg +from typing import Literal, List +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import Doc +from ..utils import pg_query, wrap_in_class + +# If you're doing approximate ANN (DiskANN) or IVF, you might use a special function or hint. 
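+# (Assumption: `embedding` is a pgvector column and the asyncpg connection has the vector
+# type registered, or the parameter is cast with `$3::vector`, so that the Python list
+# passed as $3 works with the `<->` distance operator.)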
+# For a basic vector distance search, you can do something like: +search_docs_by_embedding_query = parse_one(""" +SELECT d.*, + (d.embedding <-> $3) AS distance +FROM docs d +LEFT JOIN doc_owners do + ON d.developer_id = do.developer_id + AND d.doc_id = do.doc_id +WHERE d.developer_id = $1 + AND ( + ($4::text IS NULL AND $5::uuid IS NULL) + OR (do.owner_type = $4 AND do.owner_id = $5) + ) + AND d.embedding IS NOT NULL +ORDER BY d.embedding <-> $3 +LIMIT $2; +""").sql(pretty=True) + +@wrap_in_class( + Doc, + one=False, + transform=lambda rec: { + **rec, + "id": rec["doc_id"], + }, +) +@pg_query +@beartype +async def search_docs_by_embedding( + *, + developer_id: UUID, + query_embedding: List[float], + k: int = 10, + owner_type: Literal["user", "agent", "org"] | None = None, + owner_id: UUID | None = None, +) -> tuple[str, list]: + """ + Vector-based doc search: + - developer_id is required + - query_embedding: the vector to query + - k: number of results to return + - owner_type/owner_id: optional doc ownership filter + """ + if k < 1: + raise HTTPException(status_code=400, detail="k must be >= 1") + + # Validate embedding length if needed; e.g. 1024 floats + if not query_embedding: + raise HTTPException(status_code=400, detail="Empty embedding provided") + + return ( + search_docs_by_embedding_query, + [developer_id, k, query_embedding, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py new file mode 100644 index 000000000..c9a5a93e2 --- /dev/null +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -0,0 +1,65 @@ +""" +Timescale-based doc text search using the `search_tsv` column. +""" + +import asyncpg +from typing import Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import Doc +from ..utils import pg_query, wrap_in_class + +search_docs_text_query = parse_one(""" +SELECT d.*, + ts_rank_cd(d.search_tsv, websearch_to_tsquery($3)) AS rank +FROM docs d +LEFT JOIN doc_owners do + ON d.developer_id = do.developer_id + AND d.doc_id = do.doc_id +WHERE d.developer_id = $1 + AND ( + ($4::text IS NULL AND $5::uuid IS NULL) + OR (do.owner_type = $4 AND do.owner_id = $5) + ) + AND d.search_tsv @@ websearch_to_tsquery($3) +ORDER BY rank DESC +LIMIT $2; +""").sql(pretty=True) + + +@wrap_in_class( + Doc, + one=False, + transform=lambda rec: { + **rec, + "id": rec["doc_id"], + }, +) +@pg_query +@beartype +async def search_docs_by_text( + *, + developer_id: UUID, + query: str, + k: int = 10, + owner_type: Literal["user", "agent", "org"] | None = None, + owner_id: UUID | None = None, +) -> tuple[str, list]: + """ + Full-text search on docs using the search_tsv column. 
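+    Results are ranked with ts_rank_cd against websearch_to_tsquery($3) and returned best-first.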
+ - developer_id: required + - query: the text to look for + - k: max results + - owner_type / owner_id: optional doc ownership filter + """ + if k < 1: + raise HTTPException(status_code=400, detail="k must be >= 1") + + return ( + search_docs_text_query, + [developer_id, k, query, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py new file mode 100644 index 000000000..9e8d84dc7 --- /dev/null +++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py @@ -0,0 +1,159 @@ +""" +Hybrid doc search that merges text search and embedding search results +via a simple distribution-based score fusion or direct weighting in Python. +""" + +from typing import Literal, List +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import Doc +from ..utils import run_concurrently +from .search_docs_by_text import search_docs_by_text +from .search_docs_by_embedding import search_docs_by_embedding + +def dbsf_normalize(scores: List[float]) -> List[float]: + """ + Example distribution-based normalization: clamp each score + from (mean - 3*stddev) to (mean + 3*stddev) and scale to 0..1 + """ + import statistics + if len(scores) < 2: + return scores + m = statistics.mean(scores) + sd = statistics.pstdev(scores) # population std + if sd == 0: + return scores + upper = m + 3*sd + lower = m - 3*sd + def clamp_scale(v): + c = min(upper, max(lower, v)) + return (c - lower) / (upper - lower) + return [clamp_scale(s) for s in scores] + +@beartype +def fuse_results( + text_docs: List[Doc], embedding_docs: List[Doc], alpha: float +) -> List[Doc]: + """ + Merges text search results (descending by text rank) with + embedding results (descending by closeness or inverse distance). + alpha ~ how much to weigh the embedding score + """ + # Suppose we stored each doc's "distance" from the embedding query, and + # for text search we store a rank or negative distance. We'll unify them: + # Make up a dictionary of doc_id -> text_score, doc_id -> embed_score + # For example, text_score = -distance if you want bigger = better + text_scores = {} + embed_scores = {} + for doc in text_docs: + # If you had "rank", you might store doc.distance = rank + # For demo, let's assume doc.distance is negative... 
up to you + text_scores[doc.id] = float(-doc.distance if doc.distance else 0) + + for doc in embedding_docs: + # Lower distance => better, so we do embed_score = -distance + embed_scores[doc.id] = float(-doc.distance if doc.distance else 0) + + # Normalize them + text_vals = list(text_scores.values()) + embed_vals = list(embed_scores.values()) + text_vals_norm = dbsf_normalize(text_vals) + embed_vals_norm = dbsf_normalize(embed_vals) + + # Map them back + t_keys = list(text_scores.keys()) + for i, key in enumerate(t_keys): + text_scores[key] = text_vals_norm[i] + e_keys = list(embed_scores.keys()) + for i, key in enumerate(e_keys): + embed_scores[key] = embed_vals_norm[i] + + # Gather all doc IDs + all_ids = set(text_scores.keys()) | set(embed_scores.keys()) + + # Weighted sum => combined + out = [] + for doc_id in all_ids: + # text and embed might be missing doc_id => 0 + t_score = text_scores.get(doc_id, 0) + e_score = embed_scores.get(doc_id, 0) + combined = alpha * e_score + (1 - alpha) * t_score + # We'll store final "distance" as -(combined) so bigger combined => smaller distance + out.append((doc_id, combined)) + + # Sort descending by combined + out.sort(key=lambda x: x[1], reverse=True) + + # Convert to doc objects. We can pick from text_docs or embedding_docs or whichever is found. + # If present in both, we can merge fields. For simplicity, just pick from text_docs then fallback embedding_docs. + + # Create a quick ID->doc map + text_map = {d.id: d for d in text_docs} + embed_map = {d.id: d for d in embedding_docs} + + final_docs = [] + for doc_id, score in out: + doc = text_map.get(doc_id) or embed_map.get(doc_id) + doc = doc.model_copy() # or copy if you are using Pydantic + doc.distance = float(-score) # so a higher combined => smaller distance + final_docs.append(doc) + return final_docs + + +@beartype +async def search_docs_hybrid( + developer_id: UUID, + text_query: str = "", + embedding: List[float] = None, + k: int = 10, + alpha: float = 0.5, + owner_type: Literal["user", "agent", "org"] | None = None, + owner_id: UUID | None = None, +) -> List[Doc]: + """ + Hybrid text-and-embedding doc search. We get top-K from each approach, + then fuse them client-side. Adjust concurrency or approach as you like. + """ + # We'll dispatch two queries in parallel + # (One full-text, one embedding-based) each limited to K + tasks = [] + if text_query.strip(): + tasks.append( + search_docs_by_text( + developer_id=developer_id, + query=text_query, + k=k, + owner_type=owner_type, + owner_id=owner_id, + ) + ) + else: + tasks.append([]) # no text results if query is empty + + if embedding and any(embedding): + tasks.append( + search_docs_by_embedding( + developer_id=developer_id, + query_embedding=embedding, + k=k, + owner_type=owner_type, + owner_id=owner_id, + ) + ) + else: + tasks.append([]) + + # Run concurrently (or sequentially, if you prefer) + # If you have a 'run_concurrently' from your old code, you can do: + # text_results, embed_results = await run_concurrently([task1, task2]) + # Otherwise just do them in parallel with e.g. 
asyncio.gather: + from asyncio import gather + text_results, embed_results = await gather(*tasks) + + # fuse them + fused = fuse_results(text_results, embed_results, alpha) + # Then pick top K overall + return fused[:k] From b427e38576eacd709e536cf24d0f65c0ba1a56f0 Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Fri, 20 Dec 2024 01:26:00 +0000 Subject: [PATCH 02/10] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/docs/delete_doc.py | 1 + agents-api/agents_api/queries/docs/get_doc.py | 3 ++- agents-api/agents_api/queries/docs/list_docs.py | 1 + .../queries/docs/search_docs_by_embedding.py | 5 +++-- .../agents_api/queries/docs/search_docs_by_text.py | 2 +- .../agents_api/queries/docs/search_docs_hybrid.py | 14 ++++++++++---- 6 files changed, 18 insertions(+), 8 deletions(-) diff --git a/agents-api/agents_api/queries/docs/delete_doc.py b/agents-api/agents_api/queries/docs/delete_doc.py index d1e02faf1..9d2075600 100644 --- a/agents-api/agents_api/queries/docs/delete_doc.py +++ b/agents-api/agents_api/queries/docs/delete_doc.py @@ -1,6 +1,7 @@ """ Timescale-based deletion of a doc record. """ + from typing import Literal from uuid import UUID diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index a0345f5e3..35d692c84 100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -1,6 +1,7 @@ """ Timescale-based retrieval of a single doc record. """ + from typing import Literal from uuid import UUID @@ -41,7 +42,7 @@ async def get_doc( developer_id: UUID, doc_id: UUID, owner_type: Literal["user", "agent", "org"] | None = None, - owner_id: UUID | None = None + owner_id: UUID | None = None, ) -> tuple[str, list]: """ Fetch a single doc, optionally constrained to a given owner. diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index b145a1cbc..678c1a5e6 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -1,6 +1,7 @@ """ Timescale-based listing of docs with optional owner filter and pagination. """ + from typing import Literal from uuid import UUID diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py index c62188b61..af89cc1b8 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -2,10 +2,10 @@ Timescale-based doc embedding search using the `embedding` column. """ -import asyncpg -from typing import Literal, List +from typing import List, Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one @@ -32,6 +32,7 @@ LIMIT $2; """).sql(pretty=True) + @wrap_in_class( Doc, one=False, diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py index c9a5a93e2..eed74e54b 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_text.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -2,10 +2,10 @@ Timescale-based doc text search using the `search_tsv` column. 
""" -import asyncpg from typing import Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py index 9e8d84dc7..ae107419d 100644 --- a/agents-api/agents_api/queries/docs/search_docs_hybrid.py +++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py @@ -3,7 +3,7 @@ via a simple distribution-based score fusion or direct weighting in Python. """ -from typing import Literal, List +from typing import List, Literal from uuid import UUID from beartype import beartype @@ -11,8 +11,9 @@ from ...autogen.openapi_model import Doc from ..utils import run_concurrently -from .search_docs_by_text import search_docs_by_text from .search_docs_by_embedding import search_docs_by_embedding +from .search_docs_by_text import search_docs_by_text + def dbsf_normalize(scores: List[float]) -> List[float]: """ @@ -20,19 +21,23 @@ def dbsf_normalize(scores: List[float]) -> List[float]: from (mean - 3*stddev) to (mean + 3*stddev) and scale to 0..1 """ import statistics + if len(scores) < 2: return scores m = statistics.mean(scores) sd = statistics.pstdev(scores) # population std if sd == 0: return scores - upper = m + 3*sd - lower = m - 3*sd + upper = m + 3 * sd + lower = m - 3 * sd + def clamp_scale(v): c = min(upper, max(lower, v)) return (c - lower) / (upper - lower) + return [clamp_scale(s) for s in scores] + @beartype def fuse_results( text_docs: List[Doc], embedding_docs: List[Doc], alpha: float @@ -151,6 +156,7 @@ async def search_docs_hybrid( # text_results, embed_results = await run_concurrently([task1, task2]) # Otherwise just do them in parallel with e.g. 
asyncio.gather: from asyncio import gather + text_results, embed_results = await gather(*tasks) # fuse them From 93673b732512199a77df585c6568a42f657c65f4 Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Fri, 20 Dec 2024 14:43:12 -0500 Subject: [PATCH 03/10] fix: fixed the CRD doc queries + added tests --- agents-api/agents_api/autogen/Docs.py | 24 ++ .../agents_api/queries/docs/__init__.py | 13 +- .../agents_api/queries/docs/create_doc.py | 40 +- .../agents_api/queries/docs/delete_doc.py | 6 +- agents-api/agents_api/queries/docs/get_doc.py | 15 +- .../agents_api/queries/docs/list_docs.py | 81 ++-- .../queries/docs/search_docs_by_embedding.py | 1 - .../queries/docs/search_docs_by_text.py | 3 +- .../queries/docs/search_docs_hybrid.py | 2 - .../agents_api/queries/entries/get_history.py | 1 - .../agents_api/queries/files/get_file.py | 6 +- .../agents_api/queries/files/list_files.py | 87 +--- .../queries/sessions/create_session.py | 2 - agents-api/tests/fixtures.py | 21 +- agents-api/tests/test_docs_queries.py | 406 +++++++++++------- agents-api/tests/test_entry_queries.py | 1 - agents-api/tests/test_files_queries.py | 4 +- agents-api/tests/test_session_queries.py | 1 - .../integrations/autogen/Docs.py | 24 ++ typespec/docs/models.tsp | 20 + .../@typespec/openapi3/openapi-1.0.0.yaml | 22 + 21 files changed, 454 insertions(+), 326 deletions(-) diff --git a/agents-api/agents_api/autogen/Docs.py b/agents-api/agents_api/autogen/Docs.py index ffed27c1d..af5f60d6a 100644 --- a/agents-api/agents_api/autogen/Docs.py +++ b/agents-api/agents_api/autogen/Docs.py @@ -73,6 +73,30 @@ class Doc(BaseModel): """ Embeddings for the document """ + modality: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Modality of the document + """ + language: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Language of the document + """ + index: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Index of the document + """ + embedding_model: Annotated[ + str | None, Field(json_schema_extra={"readOnly": True}) + ] = None + """ + Embedding model to use for the document + """ + embedding_dimensions: Annotated[ + int | None, Field(json_schema_extra={"readOnly": True}) + ] = None + """ + Dimensions of the embedding model + """ class DocOwner(BaseModel): diff --git a/agents-api/agents_api/queries/docs/__init__.py b/agents-api/agents_api/queries/docs/__init__.py index 0ba3db0d4..f7c207bf2 100644 --- a/agents-api/agents_api/queries/docs/__init__.py +++ b/agents-api/agents_api/queries/docs/__init__.py @@ -18,8 +18,15 @@ from .create_doc import create_doc from .delete_doc import delete_doc -from .embed_snippets import embed_snippets from .get_doc import get_doc from .list_docs import list_docs -from .search_docs_by_embedding import search_docs_by_embedding -from .search_docs_by_text import search_docs_by_text +# from .search_docs_by_embedding import search_docs_by_embedding +# from .search_docs_by_text import search_docs_by_text + +__all__ = [ + "create_doc", + "delete_doc", + "get_doc", + "list_docs", + # "search_docs_by_embct", +] diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py index 57be43bdf..4528e9fc5 100644 --- a/agents-api/agents_api/queries/docs/create_doc.py +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -1,12 +1,4 @@ -""" -Timescale-based creation of docs. - -Mirrors the structure of create_file.py, but uses the docs/doc_owners tables. 
-""" - -import base64 -import hashlib -from typing import Any, Literal +from typing import Literal from uuid import UUID import asyncpg @@ -15,6 +7,9 @@ from sqlglot import parse_one from uuid_extensions import uuid7 +import ast + + from ...autogen.openapi_model import CreateDocRequest, Doc from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class @@ -91,7 +86,7 @@ transform=lambda d: { **d, "id": d["doc_id"], - # You could optionally return a computed hash or partial content if desired + "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), }, ) @increase_counter("create_doc") @@ -102,26 +97,35 @@ async def create_doc( developer_id: UUID, doc_id: UUID | None = None, data: CreateDocRequest, - owner_type: Literal["user", "agent", "org"] | None = None, + owner_type: Literal["user", "agent"] | None = None, owner_id: UUID | None = None, -) -> list[tuple[str, list]]: + modality: Literal["text", "image", "mixed"] | None = "text", + embedding_model: str | None = "voyage-3", + embedding_dimensions: int | None = 1024, + language: str | None = "english", + index: int | None = 0, +) -> list[tuple[str, list] | tuple[str, list, str]]: """ Insert a new doc record into Timescale and optionally associate it with an owner. """ # Generate a UUID if not provided doc_id = doc_id or uuid7() + # check if content is a string + if isinstance(data.content, str): + data.content = [data.content] + # Create the doc record doc_params = [ developer_id, doc_id, data.title, - data.content, - data.index or 0, # fallback if no snippet index - data.modality or "text", - data.embedding_model or "none", - data.embedding_dimensions or 0, - data.language or "english", + str(data.content), + index, + modality, + embedding_model, + embedding_dimensions, + language, data.metadata or {}, ] diff --git a/agents-api/agents_api/queries/docs/delete_doc.py b/agents-api/agents_api/queries/docs/delete_doc.py index 9d2075600..adeb09bd8 100644 --- a/agents-api/agents_api/queries/docs/delete_doc.py +++ b/agents-api/agents_api/queries/docs/delete_doc.py @@ -1,7 +1,3 @@ -""" -Timescale-based deletion of a doc record. -""" - from typing import Literal from uuid import UUID @@ -65,7 +61,7 @@ async def delete_doc( *, developer_id: UUID, doc_id: UUID, - owner_type: Literal["user", "agent", "org"] | None = None, + owner_type: Literal["user", "agent"] | None = None, owner_id: UUID | None = None, ) -> tuple[str, list]: """ diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index 35d692c84..9155f500a 100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -1,14 +1,9 @@ -""" -Timescale-based retrieval of a single doc record. 
-""" - from typing import Literal from uuid import UUID -import asyncpg from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one +import ast from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class @@ -16,12 +11,12 @@ doc_query = parse_one(""" SELECT d.* FROM docs d -LEFT JOIN doc_owners do ON d.developer_id = do.developer_id AND d.doc_id = do.doc_id +LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id WHERE d.developer_id = $1 AND d.doc_id = $2 AND ( ($3::text IS NULL AND $4::uuid IS NULL) - OR (do.owner_type = $3 AND do.owner_id = $4) + OR (doc_own.owner_type = $3 AND doc_own.owner_id = $4) ) LIMIT 1; """).sql(pretty=True) @@ -33,6 +28,8 @@ transform=lambda d: { **d, "id": d["doc_id"], + "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), + # "embeddings": d["embeddings"], }, ) @pg_query @@ -41,7 +38,7 @@ async def get_doc( *, developer_id: UUID, doc_id: UUID, - owner_type: Literal["user", "agent", "org"] | None = None, + owner_type: Literal["user", "agent"] | None = None, owner_id: UUID | None = None, ) -> tuple[str, list]: """ diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index 678c1a5e6..a4df08e73 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -1,52 +1,20 @@ -""" -Timescale-based listing of docs with optional owner filter and pagination. -""" - -from typing import Literal +from typing import Any, Literal from uuid import UUID -import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one +import ast from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class -# Basic listing for all docs by developer -developer_docs_query = parse_one(""" +# Base query for listing docs +base_docs_query = parse_one(""" SELECT d.* FROM docs d -LEFT JOIN doc_owners do ON d.developer_id = do.developer_id AND d.doc_id = do.doc_id +LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id WHERE d.developer_id = $1 -ORDER BY -CASE - WHEN $4 = 'created_at' AND $5 = 'asc' THEN d.created_at - WHEN $4 = 'created_at' AND $5 = 'desc' THEN d.created_at - WHEN $4 = 'updated_at' AND $5 = 'asc' THEN d.updated_at - WHEN $4 = 'updated_at' AND $5 = 'desc' THEN d.updated_at -END DESC NULLS LAST -LIMIT $2 -OFFSET $3; -""").sql(pretty=True) - -# Listing for docs associated with a specific owner -owner_docs_query = parse_one(""" -SELECT d.* -FROM docs d -JOIN doc_owners do ON d.developer_id = do.developer_id AND d.doc_id = do.doc_id -WHERE do.developer_id = $1 - AND do.owner_id = $6 - AND do.owner_type = $7 -ORDER BY -CASE - WHEN $4 = 'created_at' AND $5 = 'asc' THEN d.created_at - WHEN $4 = 'created_at' AND $5 = 'desc' THEN d.created_at - WHEN $4 = 'updated_at' AND $5 = 'asc' THEN d.updated_at - WHEN $4 = 'updated_at' AND $5 = 'desc' THEN d.updated_at -END DESC NULLS LAST -LIMIT $2 -OFFSET $3; """).sql(pretty=True) @@ -56,6 +24,8 @@ transform=lambda d: { **d, "id": d["doc_id"], + "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), + # "embeddings": d["embeddings"], }, ) @pg_query @@ -64,11 +34,13 @@ async def list_docs( *, developer_id: UUID, owner_id: UUID | None = None, - owner_type: Literal["user", "agent", "org"] | 
None = None, + owner_type: Literal["user", "agent"] | None = None, limit: int = 100, offset: int = 0, sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", + metadata_filter: dict[str, Any] = {}, + include_without_embeddings: bool = False, ) -> tuple[str, list]: """ Lists docs with optional owner filtering, pagination, and sorting. @@ -76,17 +48,36 @@ async def list_docs( if direction.lower() not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Invalid sort direction") + if sort_by not in ["created_at", "updated_at"]: + raise HTTPException(status_code=400, detail="Invalid sort field") + if limit > 100 or limit < 1: raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") if offset < 0: raise HTTPException(status_code=400, detail="Offset must be >= 0") - params = [developer_id, limit, offset, sort_by, direction] - if owner_id and owner_type: - params.extend([owner_id, owner_type]) - query = owner_docs_query - else: - query = developer_docs_query + # Start with the base query + query = base_docs_query + params = [developer_id] + + # Add owner filtering + if owner_type and owner_id: + query += " AND doc_own.owner_type = $2 AND doc_own.owner_id = $3" + params.extend([owner_type, owner_id]) + + # Add metadata filtering + if metadata_filter: + for key, value in metadata_filter.items(): + query += f" AND d.metadata->>'{key}' = ${len(params) + 1}" + params.append(value) + + # Include or exclude documents without embeddings + # if not include_without_embeddings: + # query += " AND d.embeddings IS NOT NULL" + + # Add sorting and pagination + query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}" + params.extend([limit, offset]) - return (query, params) + return query, params diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py index af89cc1b8..e3120bd36 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -5,7 +5,6 @@ from typing import List, Literal from uuid import UUID -import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py index eed74e54b..9f434d438 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_text.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -5,7 +5,6 @@ from typing import Literal from uuid import UUID -import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one @@ -22,7 +21,7 @@ AND d.doc_id = do.doc_id WHERE d.developer_id = $1 AND ( - ($4::text IS NULL AND $5::uuid IS NULL) + ($4 IS NULL AND $5 IS NULL) OR (do.owner_type = $4 AND do.owner_id = $5) ) AND d.search_tsv @@ websearch_to_tsquery($3) diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py index ae107419d..a879e3b6b 100644 --- a/agents-api/agents_api/queries/docs/search_docs_hybrid.py +++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py @@ -7,10 +7,8 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from ...autogen.openapi_model import Doc -from ..utils import run_concurrently from .search_docs_by_embedding import 
search_docs_by_embedding from .search_docs_by_text import search_docs_by_text diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index e6967a6cc..ffa0746c0 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -1,5 +1,4 @@ import json -from typing import Any, List, Tuple from uuid import UUID import asyncpg diff --git a/agents-api/agents_api/queries/files/get_file.py b/agents-api/agents_api/queries/files/get_file.py index 4d5dca4c0..5ccb08d86 100644 --- a/agents-api/agents_api/queries/files/get_file.py +++ b/agents-api/agents_api/queries/files/get_file.py @@ -6,13 +6,11 @@ from typing import Literal from uuid import UUID -import asyncpg from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import File -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, wrap_in_class # Define the raw SQL query file_query = parse_one(""" @@ -47,8 +45,8 @@ File, one=True, transform=lambda d: { - "id": d["file_id"], **d, + "id": d["file_id"], "hash": d["hash"].hex(), "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", }, diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py index 2bc42f842..7c8b67887 100644 --- a/agents-api/agents_api/queries/files/list_files.py +++ b/agents-api/agents_api/queries/files/list_files.py @@ -3,51 +3,21 @@ It constructs and executes SQL queries to fetch a list of files based on developer ID with pagination. """ -from typing import Any, Literal +from typing import Literal from uuid import UUID -import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one - from ...autogen.openapi_model import File -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, wrap_in_class -# Query to list all files for a developer (uses developer_id index) -developer_files_query = parse_one(""" +# Base query for listing files +base_files_query = parse_one(""" SELECT f.* FROM files f LEFT JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id WHERE f.developer_id = $1 -ORDER BY - CASE - WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at - WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at - WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at - WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at - END DESC NULLS LAST -LIMIT $2 -OFFSET $3; -""").sql(pretty=True) - -# Query to list files for a specific owner (uses composite indexes) -owner_files_query = parse_one(""" -SELECT f.* -FROM files f -JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id -WHERE fo.developer_id = $1 -AND fo.owner_id = $6 -AND fo.owner_type = $7 -ORDER BY - CASE - WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at - WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at - WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at - WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at - END DESC NULLS LAST -LIMIT $2 -OFFSET $3; """).sql(pretty=True) @@ -74,49 +44,32 @@ async def list_files( direction: Literal["asc", "desc"] = "desc", ) -> tuple[str, list]: """ - Lists files with optimized queries for two cases: - 1. Owner specified: Returns files associated with that owner - 2. 
No owner: Returns all files for the developer - - Args: - developer_id: UUID of the developer - owner_id: Optional UUID of the owner (user or agent) - owner_type: Optional type of owner ("user" or "agent") - limit: Maximum number of records to return (1-100) - offset: Number of records to skip - sort_by: Field to sort by - direction: Sort direction ('asc' or 'desc') - - Returns: - Tuple of (query, params) - - Raises: - HTTPException: If parameters are invalid + Lists files with optional owner filtering, pagination, and sorting. """ # Validate parameters if direction.lower() not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Invalid sort direction") + if sort_by not in ["created_at", "updated_at"]: + raise HTTPException(status_code=400, detail="Invalid sort field") + if limit > 100 or limit < 1: raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") if offset < 0: raise HTTPException(status_code=400, detail="Offset must be non-negative") - # Base parameters used in all queries - params = [ - developer_id, - limit, - offset, - sort_by, - direction, - ] + # Start with the base query + query = base_files_query + params = [developer_id] + + # Add owner filtering + if owner_type and owner_id: + query += " AND fo.owner_type = $2 AND fo.owner_id = $3" + params.extend([owner_type, owner_id]) - # Choose appropriate query based on owner details - if owner_id and owner_type: - params.extend([owner_id, owner_type]) # Add owner_id as $6 and owner_type as $7 - query = owner_files_query # Use single query with owner_type parameter - else: - query = developer_files_query + # Add sorting and pagination + query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}" + params.extend([limit, offset]) - return (query, params) + return query, params diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index 63fbdc940..058462cf8 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -8,10 +8,8 @@ from ...autogen.openapi_model import ( CreateSessionRequest, - ResourceCreatedResponse, Session, ) -from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 286fd10fb..6689137d7 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,6 +1,5 @@ import random import string -import time from uuid import UUID from fastapi.testclient import TestClient @@ -12,6 +11,7 @@ CreateFileRequest, CreateSessionRequest, CreateUserRequest, + CreateDocRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode @@ -21,7 +21,8 @@ # from agents_api.queries.agents.delete_agent import delete_agent from agents_api.queries.developers.get_developer import get_developer -# from agents_api.queries.docs.create_doc import create_doc +from agents_api.queries.docs.create_doc import create_doc + # from agents_api.queries.docs.delete_doc import delete_doc # from agents_api.queries.execution.create_execution import create_execution # from agents_api.queries.execution.create_execution_transition import create_execution_transition @@ -149,6 +150,22 @@ async def test_file(dsn=pg_dsn, developer=test_developer, user=test_user): return file 
+@fixture(scope="test") +async def test_doc(dsn=pg_dsn, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + doc = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Hello", + content=["World"], + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + connection_pool=pool, + ) + return doc + + @fixture(scope="test") async def random_email(): return f"{"".join([random.choice(string.ascii_lowercase) for _ in range(10)])}@mail.com" diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index f2ff2c786..d6af42e57 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -1,163 +1,249 @@ -# # Tests for entry queries +from ward import test -# import asyncio +from agents_api.autogen.openapi_model import CreateDocRequest +from agents_api.clients.pg import create_db_pool +from agents_api.queries.docs.create_doc import create_doc +from agents_api.queries.docs.delete_doc import delete_doc +from agents_api.queries.docs.get_doc import get_doc +from agents_api.queries.docs.list_docs import list_docs -# from ward import test - -# from agents_api.autogen.openapi_model import CreateDocRequest -# from agents_api.queries.docs.create_doc import create_doc -# from agents_api.queries.docs.delete_doc import delete_doc -# from agents_api.queries.docs.embed_snippets import embed_snippets -# from agents_api.queries.docs.get_doc import get_doc -# from agents_api.queries.docs.list_docs import list_docs -# from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding +# If you wish to test text/embedding/hybrid search, import them: # from agents_api.queries.docs.search_docs_by_text import search_docs_by_text -# from tests.fixtures import ( -# EMBEDDING_SIZE, -# cozo_client, -# test_agent, -# test_developer_id, -# test_doc, -# test_user, -# ) - - -# @test("query: create docs") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user -# ): -# create_doc( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# data=CreateDocRequest(title="Hello", content=["World"]), -# client=client, -# ) - -# create_doc( -# developer_id=developer_id, -# owner_type="user", -# owner_id=user.id, -# data=CreateDocRequest(title="Hello", content=["World"]), -# client=client, -# ) - - -# @test("query: get docs") -# def _(client=cozo_client, doc=test_doc, developer_id=test_developer_id): -# get_doc( -# developer_id=developer_id, -# doc_id=doc.id, -# client=client, -# ) - - -# @test("query: delete doc") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# doc = create_doc( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# data=CreateDocRequest(title="Hello", content=["World"]), -# client=client, -# ) - -# delete_doc( -# developer_id=developer_id, -# doc_id=doc.id, -# owner_type="agent", -# owner_id=agent.id, -# client=client, -# ) - - -# @test("query: list docs") -# def _( -# client=cozo_client, developer_id=test_developer_id, doc=test_doc, agent=test_agent -# ): -# result = list_docs( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# client=client, -# include_without_embeddings=True, -# ) - -# assert len(result) >= 1 - - -# @test("query: search docs by text") -# async def _(client=cozo_client, agent=test_agent, developer_id=test_developer_id): -# create_doc( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# 
data=CreateDocRequest( -# title="Hello", content=["The world is a funny little thing"] -# ), -# client=client, -# ) - -# await asyncio.sleep(1) - -# result = search_docs_by_text( -# developer_id=developer_id, -# owners=[("agent", agent.id)], -# query="funny", -# client=client, -# ) - -# assert len(result) >= 1 -# assert result[0].metadata is not None - - -# @test("query: search docs by embedding") -# async def _(client=cozo_client, agent=test_agent, developer_id=test_developer_id): -# doc = create_doc( -# developer_id=developer_id, -# owner_type="agent", -# owner_id=agent.id, -# data=CreateDocRequest(title="Hello", content=["World"]), -# client=client, -# ) - -# ### Add embedding to the snippet -# embed_snippets( -# developer_id=developer_id, -# doc_id=doc.id, -# snippet_indices=[0], -# embeddings=[[1.0] * EMBEDDING_SIZE], -# client=client, -# ) - -# await asyncio.sleep(1) - -# ### Search -# query_embedding = [0.99] * EMBEDDING_SIZE - -# result = search_docs_by_embedding( -# developer_id=developer_id, -# owners=[("agent", agent.id)], -# query_embedding=query_embedding, -# client=client, -# ) - -# assert len(result) >= 1 -# assert result[0].metadata is not None - - -# @test("query: embed snippets") -# def _(client=cozo_client, developer_id=test_developer_id, doc=test_doc): -# snippet_indices = [0] -# embeddings = [[1.0] * EMBEDDING_SIZE] - -# result = embed_snippets( -# developer_id=developer_id, -# doc_id=doc.id, -# snippet_indices=snippet_indices, -# embeddings=embeddings, -# client=client, -# ) - -# assert result is not None -# assert result.id == doc.id +# from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding +# from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid + +# You can rename or remove these imports to match your actual fixtures +from tests.fixtures import pg_dsn, test_agent, test_developer, test_user, test_doc + + +@test("query: create doc") +async def _(dsn=pg_dsn, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + doc = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Hello Doc", + content="This is sample doc content", + embed_instruction="Embed the document", + metadata={"test": "test"}, + ), + connection_pool=pool, + ) + + assert doc.title == "Hello Doc" + assert doc.content == "This is sample doc content" + assert doc.modality == "text" + assert doc.embedding_model == "voyage-3" + assert doc.embedding_dimensions == 1024 + assert doc.language == "english" + assert doc.index == 0 + +@test("query: create user doc") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + doc = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="User Doc", + content="Docs for user testing", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert doc.title == "User Doc" + + # Verify doc appears in user's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert any(d.id == doc.id for d in docs_list) + +@test("query: create agent doc") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + doc = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Agent Doc", + content="Docs for agent testing", + metadata={"test": "test"}, + embed_instruction="Embed the 
document", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert doc.title == "Agent Doc" + + # Verify doc appears in agent's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert any(d.id == doc.id for d in docs_list) + +@test("model: get doc") +async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): + pool = await create_db_pool(dsn=dsn) + doc_test = await get_doc( + developer_id=developer.id, + doc_id=doc.id, + connection_pool=pool, + ) + assert doc_test.id == doc.id + assert doc_test.title == doc.title + +@test("query: list docs") +async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): + pool = await create_db_pool(dsn=dsn) + docs_list = await list_docs( + developer_id=developer.id, + connection_pool=pool, + ) + assert len(docs_list) >= 1 + assert any(d.id == doc.id for d in docs_list) + +@test("query: list user docs") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the user + doc_user = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="User List Test", + content="Some user doc content", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # List user's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert len(docs_list) >= 1 + assert any(d.id == doc_user.id for d in docs_list) + +@test("query: list agent docs") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the agent + doc_agent = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Agent List Test", + content="Some agent doc content", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # List agent's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert len(docs_list) >= 1 + assert any(d.id == doc_agent.id for d in docs_list) + +@test("query: delete user doc") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the user + doc_user = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="User Delete Test", + content="Doc for user deletion test", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Delete the doc + await delete_doc( + developer_id=developer.id, + doc_id=doc_user.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Verify doc is no longer in user's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert not any(d.id == doc_user.id for d in docs_list) + +@test("query: delete agent doc") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the agent + doc_agent = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Agent Delete Test", + content="Doc for agent 
deletion test", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Delete the doc + await delete_doc( + developer_id=developer.id, + doc_id=doc_agent.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Verify doc is no longer in agent's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert not any(d.id == doc_agent.id for d in docs_list) + +@test("query: delete doc") +async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): + pool = await create_db_pool(dsn=dsn) + await delete_doc( + developer_id=developer.id, + doc_id=doc.id, + connection_pool=pool, + ) diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 706185c7b..2a9746ef1 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,7 +3,6 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. """ -from uuid import UUID from fastapi import HTTPException from uuid_extensions import uuid7 diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index 92b52d733..c83c7a6f6 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -1,9 +1,7 @@ # # Tests for entry queries -from fastapi import HTTPException -from uuid_extensions import uuid7 -from ward import raises, test +from ward import test from agents_api.autogen.openapi_model import CreateFileRequest from agents_api.clients.pg import create_db_pool diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 171e56aa8..4673d6fc5 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -10,7 +10,6 @@ CreateOrUpdateSessionRequest, CreateSessionRequest, PatchSessionRequest, - ResourceCreatedResponse, ResourceDeletedResponse, ResourceUpdatedResponse, Session, diff --git a/integrations-service/integrations/autogen/Docs.py b/integrations-service/integrations/autogen/Docs.py index ffed27c1d..af5f60d6a 100644 --- a/integrations-service/integrations/autogen/Docs.py +++ b/integrations-service/integrations/autogen/Docs.py @@ -73,6 +73,30 @@ class Doc(BaseModel): """ Embeddings for the document """ + modality: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Modality of the document + """ + language: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Language of the document + """ + index: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Index of the document + """ + embedding_model: Annotated[ + str | None, Field(json_schema_extra={"readOnly": True}) + ] = None + """ + Embedding model to use for the document + """ + embedding_dimensions: Annotated[ + int | None, Field(json_schema_extra={"readOnly": True}) + ] = None + """ + Dimensions of the embedding model + """ class DocOwner(BaseModel): diff --git a/typespec/docs/models.tsp b/typespec/docs/models.tsp index 055fc2003..f4d16cbd5 100644 --- a/typespec/docs/models.tsp +++ b/typespec/docs/models.tsp @@ -27,6 +27,26 @@ model Doc { /** Embeddings for the document */ @visibility("read") embeddings?: float32[] | float32[][]; + + @visibility("read") + /** Modality of the document */ + modality?: string; + + @visibility("read") + /** 
Language of the document */ + language?: string; + + @visibility("read") + /** Index of the document */ + index?: uint16; + + @visibility("read") + /** Embedding model to use for the document */ + embedding_model?: string; + + @visibility("read") + /** Dimensions of the embedding model */ + embedding_dimensions?: uint16; } /** Payload for creating a doc */ diff --git a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml index d4835a695..c19bc4ed2 100644 --- a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml +++ b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml @@ -2876,6 +2876,28 @@ components: format: float description: Embeddings for the document readOnly: true + modality: + type: string + description: Modality of the document + readOnly: true + language: + type: string + description: Language of the document + readOnly: true + index: + type: integer + format: uint16 + description: Index of the document + readOnly: true + embedding_model: + type: string + description: Embedding model to use for the document + readOnly: true + embedding_dimensions: + type: integer + format: uint16 + description: Dimensions of the embedding model + readOnly: true Docs.DocOwner: type: object required: From 7b0be5c5ae15d7c8b2b6d34689b746278c79fdb4 Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Fri, 20 Dec 2024 19:44:02 +0000 Subject: [PATCH 04/10] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/docs/__init__.py | 1 + agents-api/agents_api/queries/docs/create_doc.py | 8 ++++---- agents-api/agents_api/queries/docs/get_doc.py | 6 ++++-- agents-api/agents_api/queries/docs/list_docs.py | 6 ++++-- agents-api/agents_api/queries/files/list_files.py | 1 + agents-api/tests/fixtures.py | 3 +-- agents-api/tests/test_docs_queries.py | 14 +++++++++++--- agents-api/tests/test_entry_queries.py | 1 - 8 files changed, 26 insertions(+), 14 deletions(-) diff --git a/agents-api/agents_api/queries/docs/__init__.py b/agents-api/agents_api/queries/docs/__init__.py index f7c207bf2..75f9516a6 100644 --- a/agents-api/agents_api/queries/docs/__init__.py +++ b/agents-api/agents_api/queries/docs/__init__.py @@ -20,6 +20,7 @@ from .delete_doc import delete_doc from .get_doc import get_doc from .list_docs import list_docs + # from .search_docs_by_embedding import search_docs_by_embedding # from .search_docs_by_text import search_docs_by_text diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py index 4528e9fc5..bf789fad2 100644 --- a/agents-api/agents_api/queries/docs/create_doc.py +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -1,3 +1,4 @@ +import ast from typing import Literal from uuid import UUID @@ -7,9 +8,6 @@ from sqlglot import parse_one from uuid_extensions import uuid7 -import ast - - from ...autogen.openapi_model import CreateDocRequest, Doc from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class @@ -86,7 +84,9 @@ transform=lambda d: { **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), + "content": ast.literal_eval(d["content"])[0] + if len(ast.literal_eval(d["content"])) == 1 + else ast.literal_eval(d["content"]), }, ) @increase_counter("create_doc") diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index 9155f500a..b46563dbb 
100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -1,9 +1,9 @@ +import ast from typing import Literal from uuid import UUID from beartype import beartype from sqlglot import parse_one -import ast from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class @@ -28,7 +28,9 @@ transform=lambda d: { **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), + "content": ast.literal_eval(d["content"])[0] + if len(ast.literal_eval(d["content"])) == 1 + else ast.literal_eval(d["content"]), # "embeddings": d["embeddings"], }, ) diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index a4df08e73..92cbacf7f 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -1,10 +1,10 @@ +import ast from typing import Any, Literal from uuid import UUID from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -import ast from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class @@ -24,7 +24,9 @@ transform=lambda d: { **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), + "content": ast.literal_eval(d["content"])[0] + if len(ast.literal_eval(d["content"])) == 1 + else ast.literal_eval(d["content"]), # "embeddings": d["embeddings"], }, ) diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py index 7c8b67887..2f36def4f 100644 --- a/agents-api/agents_api/queries/files/list_files.py +++ b/agents-api/agents_api/queries/files/list_files.py @@ -9,6 +9,7 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one + from ...autogen.openapi_model import File from ..utils import pg_query, wrap_in_class diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 6689137d7..2f7de580e 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -8,10 +8,10 @@ from agents_api.autogen.openapi_model import ( CreateAgentRequest, + CreateDocRequest, CreateFileRequest, CreateSessionRequest, CreateUserRequest, - CreateDocRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode @@ -20,7 +20,6 @@ # from agents_api.queries.agents.delete_agent import delete_agent from agents_api.queries.developers.get_developer import get_developer - from agents_api.queries.docs.create_doc import create_doc # from agents_api.queries.docs.delete_doc import delete_doc diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index d6af42e57..1410c88c9 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -11,9 +11,8 @@ # from agents_api.queries.docs.search_docs_by_text import search_docs_by_text # from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding # from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid - # You can rename or remove these imports to match your actual fixtures -from tests.fixtures import pg_dsn, test_agent, test_developer, test_user, test_doc +from tests.fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user @test("query: create 
doc") @@ -29,7 +28,7 @@ async def _(dsn=pg_dsn, developer=test_developer): ), connection_pool=pool, ) - + assert doc.title == "Hello Doc" assert doc.content == "This is sample doc content" assert doc.modality == "text" @@ -38,6 +37,7 @@ async def _(dsn=pg_dsn, developer=test_developer): assert doc.language == "english" assert doc.index == 0 + @test("query: create user doc") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -64,6 +64,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): ) assert any(d.id == doc.id for d in docs_list) + @test("query: create agent doc") async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) @@ -90,6 +91,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): ) assert any(d.id == doc.id for d in docs_list) + @test("model: get doc") async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): pool = await create_db_pool(dsn=dsn) @@ -101,6 +103,7 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): assert doc_test.id == doc.id assert doc_test.title == doc.title + @test("query: list docs") async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): pool = await create_db_pool(dsn=dsn) @@ -111,6 +114,7 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): assert len(docs_list) >= 1 assert any(d.id == doc.id for d in docs_list) + @test("query: list user docs") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -139,6 +143,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): assert len(docs_list) >= 1 assert any(d.id == doc_user.id for d in docs_list) + @test("query: list agent docs") async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) @@ -167,6 +172,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): assert len(docs_list) >= 1 assert any(d.id == doc_agent.id for d in docs_list) + @test("query: delete user doc") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -203,6 +209,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): ) assert not any(d.id == doc_user.id for d in docs_list) + @test("query: delete agent doc") async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) @@ -239,6 +246,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): ) assert not any(d.id == doc_agent.id for d in docs_list) + @test("query: delete doc") async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): pool = await create_db_pool(dsn=dsn) diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 2a9746ef1..ae825ed92 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,7 +3,6 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. 
""" - from fastapi import HTTPException from uuid_extensions import uuid7 from ward import raises, test From dc0ec364e7a250db8811108953338ffcdc0baf1e Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Fri, 20 Dec 2024 15:25:52 -0500 Subject: [PATCH 05/10] wip: initial set of exceptions added --- .../agents_api/queries/agents/create_agent.py | 58 +++++++++---------- .../queries/agents/create_or_update_agent.py | 37 +++++++++--- .../agents_api/queries/agents/delete_agent.py | 39 +++++++++---- .../agents_api/queries/agents/get_agent.py | 28 +++++---- .../agents_api/queries/agents/list_agents.py | 29 ++++++---- .../agents_api/queries/agents/patch_agent.py | 38 ++++++++---- .../agents_api/queries/agents/update_agent.py | 39 +++++++++---- .../queries/developers/create_developer.py | 4 +- .../queries/developers/patch_developer.py | 4 +- .../queries/developers/update_developer.py | 5 ++ .../agents_api/queries/files/create_file.py | 38 ++++++------ .../agents_api/queries/files/delete_file.py | 5 ++ .../agents_api/queries/files/get_file.py | 33 ++++++----- .../agents_api/queries/files/list_files.py | 13 ++++- .../sessions/create_or_update_session.py | 7 ++- .../queries/sessions/create_session.py | 7 ++- .../queries/sessions/delete_session.py | 2 +- .../queries/sessions/get_session.py | 2 +- .../queries/sessions/list_sessions.py | 18 +++--- .../queries/sessions/patch_session.py | 7 ++- .../queries/sessions/update_session.py | 7 ++- .../queries/users/create_or_update_user.py | 4 +- .../agents_api/queries/users/create_user.py | 6 +- .../agents_api/queries/users/delete_user.py | 2 +- .../agents_api/queries/users/get_user.py | 5 -- .../agents_api/queries/users/list_users.py | 5 -- .../agents_api/queries/users/patch_user.py | 4 +- .../agents_api/queries/users/update_user.py | 4 +- 28 files changed, 283 insertions(+), 167 deletions(-) diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 76c96f46b..0b7a7d208 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -8,13 +8,16 @@ from beartype import beartype from sqlglot import parse_one from uuid_extensions import uuid7 - +import asyncpg +from fastapi import HTTPException from ...autogen.openapi_model import Agent, CreateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) # Define the raw SQL query @@ -45,35 +48,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ), -# psycopg_errors.UniqueViolation: partialclass( -# HTTPException, -# status_code=409, -# detail="An agent with this canonical name already exists for this developer.", -# ), -# psycopg_errors.CheckViolation: partialclass( -# HTTPException, -# status_code=400, -# detail="The provided data violates one or more constraints. Please check the input values.", -# ), -# ValidationError: partialclass( -# HTTPException, -# status_code=400, -# detail="Input validation failed. Please check the provided data.", -# ), -# TypeError: partialclass( -# HTTPException, -# status_code=400, -# detail="A type mismatch occurred. 
Please review the input.", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class( Agent, one=True, diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index ef3a0abe5..fd70e5f8b 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -7,6 +7,8 @@ from beartype import beartype from sqlglot import parse_one +from fastapi import HTTPException +import asyncpg from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest from ...metrics.counters import increase_counter @@ -14,6 +16,8 @@ generate_canonical_name, pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) # Define the raw SQL query @@ -44,15 +48,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. 
Please check the input values.", + ), + } +) @wrap_in_class( Agent, one=True, diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index c0ca3919f..64b3e392e 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -7,12 +7,16 @@ from beartype import beartype from sqlglot import parse_one +from fastapi import HTTPException +import asyncpg from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow from ..utils import ( pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) # Define the raw SQL query @@ -59,17 +63,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class( ResourceDeletedResponse, one=True, diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index a731300fa..985937b0d 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -7,11 +7,15 @@ from beartype import beartype from sqlglot import parse_one +from fastapi import HTTPException +import asyncpg from ...autogen.openapi_model import Agent from ..utils import ( pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) # Define the raw SQL query @@ -35,16 +39,20 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. 
Please check the input values.", + ), + } +) @wrap_in_class(Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d}) @pg_query @beartype diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 87a0c942d..68ee3c73a 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -8,11 +8,13 @@ from beartype import beartype from fastapi import HTTPException - +import asyncpg from ...autogen.openapi_model import Agent from ..utils import ( pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) # Define the raw SQL query @@ -39,17 +41,20 @@ LIMIT $2 OFFSET $3; """ - -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class(Agent, transform=lambda d: {"id": d["agent_id"], **d}) @pg_query @beartype diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 69a5a6ca5..fef682858 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -7,12 +7,16 @@ from beartype import beartype from sqlglot import parse_one +from fastapi import HTTPException +import asyncpg from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter from ..utils import ( pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) # Define the raw SQL query @@ -44,16 +48,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. 
Please check the input values.", + ), + } +) @wrap_in_class( ResourceUpdatedResponse, one=True, diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index f28e28264..5e33fdddd 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -7,12 +7,15 @@ from beartype import beartype from sqlglot import parse_one - +from fastapi import HTTPException +import asyncpg from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) # Define the raw SQL query @@ -29,16 +32,30 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# psycopg_errors.ForeignKeyViolation: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist.", -# ) -# } -# # TODO: Add more exceptions -# ) +@rewrap_exceptions( + { + asyncpg.exceptions.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.exceptions.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="An agent with this canonical name already exists for this developer.", + ), + asyncpg.exceptions.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="The provided data violates one or more constraints. Please check the input values.", + ), + asyncpg.exceptions.DataError: partialclass( + HTTPException, + status_code=400, + detail="Invalid data provided. Please check the input values.", + ), + } +) @wrap_in_class( ResourceUpdatedResponse, one=True, diff --git a/agents-api/agents_api/queries/developers/create_developer.py b/agents-api/agents_api/queries/developers/create_developer.py index bed6371c4..51011a63b 100644 --- a/agents-api/agents_api/queries/developers/create_developer.py +++ b/agents-api/agents_api/queries/developers/create_developer.py @@ -38,8 +38,8 @@ { asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified developer does not exist.", + status_code=409, + detail="A developer with this email already exists.", ) } ) diff --git a/agents-api/agents_api/queries/developers/patch_developer.py b/agents-api/agents_api/queries/developers/patch_developer.py index af2ddb1f8..e14c8bbd0 100644 --- a/agents-api/agents_api/queries/developers/patch_developer.py +++ b/agents-api/agents_api/queries/developers/patch_developer.py @@ -26,8 +26,8 @@ { asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified developer does not exist.", + status_code=409, + detail="A developer with this email already exists.", ) } ) diff --git a/agents-api/agents_api/queries/developers/update_developer.py b/agents-api/agents_api/queries/developers/update_developer.py index d41b333d5..659dcb111 100644 --- a/agents-api/agents_api/queries/developers/update_developer.py +++ b/agents-api/agents_api/queries/developers/update_developer.py @@ -28,6 +28,11 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="A developer with this email already exists.", ) } ) diff --git a/agents-api/agents_api/queries/files/create_file.py b/agents-api/agents_api/queries/files/create_file.py index 48251fa5e..f2e35a6f4 100644 --- 
a/agents-api/agents_api/queries/files/create_file.py +++ b/agents-api/agents_api/queries/files/create_file.py @@ -60,25 +60,25 @@ # Add error handling decorator -# @rewrap_exceptions( -# { -# asyncpg.UniqueViolationError: partialclass( -# HTTPException, -# status_code=409, -# detail="A file with this name already exists for this developer", -# ), -# asyncpg.NoDataFoundError: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified owner does not exist", -# ), -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified developer does not exist", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="A file with this name already exists for this developer", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="File size must be positive and name must be between 1 and 255 characters", + ), + } +) @wrap_in_class( File, one=True, diff --git a/agents-api/agents_api/queries/files/delete_file.py b/agents-api/agents_api/queries/files/delete_file.py index 31cb43404..4cf0142ae 100644 --- a/agents-api/agents_api/queries/files/delete_file.py +++ b/agents-api/agents_api/queries/files/delete_file.py @@ -48,6 +48,11 @@ status_code=404, detail="File not found", ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/files/get_file.py b/agents-api/agents_api/queries/files/get_file.py index 5ccb08d86..882a93ab7 100644 --- a/agents-api/agents_api/queries/files/get_file.py +++ b/agents-api/agents_api/queries/files/get_file.py @@ -8,9 +8,12 @@ from beartype import beartype from sqlglot import parse_one +import asyncpg +from fastapi import HTTPException from ...autogen.openapi_model import File -from ..utils import pg_query, wrap_in_class +from ..utils import pg_query, wrap_in_class, rewrap_exceptions, partialclass, partialclass + # Define the raw SQL query file_query = parse_one(""" @@ -27,20 +30,20 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# asyncpg.NoDataFoundError: partialclass( -# HTTPException, -# status_code=404, -# detail="File not found", -# ), -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="Developer not found", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="File not found", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), + } +) @wrap_in_class( File, one=True, diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py index 2f36def4f..7908bf37d 100644 --- a/agents-api/agents_api/queries/files/list_files.py +++ b/agents-api/agents_api/queries/files/list_files.py @@ -9,9 +9,10 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one +import asyncpg from ...autogen.openapi_model import File -from ..utils import pg_query, wrap_in_class +from ..utils import pg_query, wrap_in_class, rewrap_exceptions, partialclass # Base query for listing files base_files_query = 
parse_one(""" @@ -21,7 +22,15 @@ WHERE f.developer_id = $1 """).sql(pretty=True) - +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), + } +) @wrap_in_class( File, one=False, diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py index 3c4dbf66e..b6c280b01 100644 --- a/agents-api/agents_api/queries/sessions/create_or_update_session.py +++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py @@ -70,13 +70,18 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer or participant does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.UniqueViolationError: partialclass( HTTPException, status_code=409, detail="A session with this ID already exists.", ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index 058462cf8..0bb967ce5 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -58,13 +58,18 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer or participant does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.UniqueViolationError: partialclass( HTTPException, status_code=409, detail="A session with this ID already exists.", ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/sessions/delete_session.py b/agents-api/agents_api/queries/sessions/delete_session.py index 2e3234fe2..ff5317f58 100644 --- a/agents-api/agents_api/queries/sessions/delete_session.py +++ b/agents-api/agents_api/queries/sessions/delete_session.py @@ -30,7 +30,7 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer does not exist.", + detail="The specified developer or session does not exist.", ), } ) diff --git a/agents-api/agents_api/queries/sessions/get_session.py b/agents-api/agents_api/queries/sessions/get_session.py index 1f704539e..cc12d0f88 100644 --- a/agents-api/agents_api/queries/sessions/get_session.py +++ b/agents-api/agents_api/queries/sessions/get_session.py @@ -51,7 +51,7 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.NoDataFoundError: partialclass( HTTPException, status_code=404, detail="Session not found" diff --git a/agents-api/agents_api/queries/sessions/list_sessions.py b/agents-api/agents_api/queries/sessions/list_sessions.py index 3aabaf32d..c113c0192 100644 --- a/agents-api/agents_api/queries/sessions/list_sessions.py +++ b/agents-api/agents_api/queries/sessions/list_sessions.py @@ -12,7 +12,7 @@ from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query -raw_query = """ +session_query = """ WITH session_participants AS ( SELECT sl.session_id, @@ -49,11 +49,6 @@ 
LIMIT $2 OFFSET $6; """ -# Parse and optimize the query -# query = parse_one(raw_query).sql(pretty=True) -query = raw_query - - @rewrap_exceptions( { asyncpg.ForeignKeyViolationError: partialclass( @@ -62,7 +57,14 @@ detail="The specified developer does not exist.", ), asyncpg.NoDataFoundError: partialclass( - HTTPException, status_code=404, detail="No sessions found" + HTTPException, + status_code=404, + detail="No sessions found", + ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", ), } ) @@ -94,7 +96,7 @@ async def list_sessions( tuple[str, list]: SQL query and parameters """ return ( - query, + session_query, [ developer_id, # $1 limit, # $2 diff --git a/agents-api/agents_api/queries/sessions/patch_session.py b/agents-api/agents_api/queries/sessions/patch_session.py index 7d526ae1a..d7533e124 100644 --- a/agents-api/agents_api/queries/sessions/patch_session.py +++ b/agents-api/agents_api/queries/sessions/patch_session.py @@ -37,13 +37,18 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer or participant does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.NoDataFoundError: partialclass( HTTPException, status_code=404, detail="Session not found", ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/sessions/update_session.py b/agents-api/agents_api/queries/sessions/update_session.py index 7c58d10e6..e3f46c0af 100644 --- a/agents-api/agents_api/queries/sessions/update_session.py +++ b/agents-api/agents_api/queries/sessions/update_session.py @@ -33,13 +33,18 @@ asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, - detail="The specified developer or participant does not exist.", + detail="The specified developer or session does not exist.", ), asyncpg.NoDataFoundError: partialclass( HTTPException, status_code=404, detail="Session not found", ), + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail="Invalid session data provided.", + ), } ) @wrap_in_class( diff --git a/agents-api/agents_api/queries/users/create_or_update_user.py b/agents-api/agents_api/queries/users/create_or_update_user.py index 965ae4ce4..0a2936a9b 100644 --- a/agents-api/agents_api/queries/users/create_or_update_user.py +++ b/agents-api/agents_api/queries/users/create_or_update_user.py @@ -40,10 +40,10 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.UniqueViolationError: partialclass( # Add handling for potential race conditions + asyncpg.UniqueViolationError: partialclass( HTTPException, status_code=409, - detail="A user with this ID already exists.", + detail="A user with this ID already exists for the specified developer.", ), } ) diff --git a/agents-api/agents_api/queries/users/create_user.py b/agents-api/agents_api/queries/users/create_user.py index 8f35a646c..e246c7255 100644 --- a/agents-api/agents_api/queries/users/create_user.py +++ b/agents-api/agents_api/queries/users/create_user.py @@ -37,10 +37,10 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.NullValueNoIndicatorParameterError: partialclass( + asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified developer does not exist.", + status_code=409, + detail="A user 
with this ID already exists for the specified developer.", ), } ) diff --git a/agents-api/agents_api/queries/users/delete_user.py b/agents-api/agents_api/queries/users/delete_user.py index ad5befd73..6b8497980 100644 --- a/agents-api/agents_api/queries/users/delete_user.py +++ b/agents-api/agents_api/queries/users/delete_user.py @@ -56,7 +56,7 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.UniqueViolationError: partialclass( + asyncpg.DataError: partialclass( HTTPException, status_code=404, detail="The specified user does not exist.", diff --git a/agents-api/agents_api/queries/users/get_user.py b/agents-api/agents_api/queries/users/get_user.py index 2b71f9192..07a840621 100644 --- a/agents-api/agents_api/queries/users/get_user.py +++ b/agents-api/agents_api/queries/users/get_user.py @@ -31,11 +31,6 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.UniqueViolationError: partialclass( - HTTPException, - status_code=404, - detail="The specified user does not exist.", - ), } ) @wrap_in_class(User, one=True) diff --git a/agents-api/agents_api/queries/users/list_users.py b/agents-api/agents_api/queries/users/list_users.py index 0f0818135..75fd62b4b 100644 --- a/agents-api/agents_api/queries/users/list_users.py +++ b/agents-api/agents_api/queries/users/list_users.py @@ -42,11 +42,6 @@ status_code=404, detail="The specified developer does not exist.", ), - asyncpg.UniqueViolationError: partialclass( - HTTPException, - status_code=404, - detail="The specified user does not exist.", - ), } ) @wrap_in_class(User) diff --git a/agents-api/agents_api/queries/users/patch_user.py b/agents-api/agents_api/queries/users/patch_user.py index c55ee31b7..fb2d8bfad 100644 --- a/agents-api/agents_api/queries/users/patch_user.py +++ b/agents-api/agents_api/queries/users/patch_user.py @@ -47,8 +47,8 @@ ), asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified user does not exist.", + status_code=409, + detail="A user with this ID already exists for the specified developer.", ), } ) diff --git a/agents-api/agents_api/queries/users/update_user.py b/agents-api/agents_api/queries/users/update_user.py index 91572e15d..975dc57c7 100644 --- a/agents-api/agents_api/queries/users/update_user.py +++ b/agents-api/agents_api/queries/users/update_user.py @@ -31,8 +31,8 @@ ), asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=404, - detail="The specified user does not exist.", + status_code=409, + detail="A user with this ID already exists for the specified developer.", ), } ) From 32d67bc9a5e7f286fb9008a104329e61858aa002 Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Fri, 20 Dec 2024 20:26:41 +0000 Subject: [PATCH 06/10] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/agents/create_agent.py | 9 +++++---- .../queries/agents/create_or_update_agent.py | 8 ++++---- agents-api/agents_api/queries/agents/delete_agent.py | 8 ++++---- agents-api/agents_api/queries/agents/get_agent.py | 8 ++++---- agents-api/agents_api/queries/agents/list_agents.py | 8 +++++--- agents-api/agents_api/queries/agents/patch_agent.py | 8 ++++---- agents-api/agents_api/queries/agents/update_agent.py | 9 +++++---- .../queries/developers/update_developer.py | 2 +- agents-api/agents_api/queries/files/get_file.py | 12 ++++++++---- agents-api/agents_api/queries/files/list_files.py | 5 +++-- .../agents_api/queries/sessions/list_sessions.py | 1 + 11 files changed, 44 insertions(+), 34 deletions(-) diff --git 
a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 0b7a7d208..5294cfa6d 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -5,19 +5,20 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from uuid_extensions import uuid7 -import asyncpg -from fastapi import HTTPException + from ...autogen.openapi_model import Agent, CreateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, + partialclass, pg_query, - wrap_in_class, rewrap_exceptions, - partialclass, + wrap_in_class, ) # Define the raw SQL query diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index fd70e5f8b..fcef53fd6 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -5,19 +5,19 @@ from uuid import UUID +import asyncpg from beartype import beartype -from sqlglot import parse_one from fastapi import HTTPException -import asyncpg +from sqlglot import parse_one from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, + partialclass, pg_query, - wrap_in_class, rewrap_exceptions, - partialclass, + wrap_in_class, ) # Define the raw SQL query diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index 64b3e392e..2fd1f1406 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -5,18 +5,18 @@ from uuid import UUID +import asyncpg from beartype import beartype -from sqlglot import parse_one from fastapi import HTTPException -import asyncpg +from sqlglot import parse_one from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow from ..utils import ( + partialclass, pg_query, - wrap_in_class, rewrap_exceptions, - partialclass, + wrap_in_class, ) # Define the raw SQL query diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index 985937b0d..79fa1c4fc 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -5,17 +5,17 @@ from uuid import UUID +import asyncpg from beartype import beartype -from sqlglot import parse_one from fastapi import HTTPException -import asyncpg +from sqlglot import parse_one from ...autogen.openapi_model import Agent from ..utils import ( + partialclass, pg_query, - wrap_in_class, rewrap_exceptions, - partialclass, + wrap_in_class, ) # Define the raw SQL query diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 68ee3c73a..11b9dc283 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -6,15 +6,16 @@ from typing import Any, Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException -import asyncpg + from ...autogen.openapi_model import Agent from ..utils import ( + partialclass, pg_query, - wrap_in_class, rewrap_exceptions, - partialclass, + wrap_in_class, 
) # Define the raw SQL query @@ -41,6 +42,7 @@ LIMIT $2 OFFSET $3; """ + @rewrap_exceptions( { asyncpg.exceptions.ForeignKeyViolationError: partialclass( diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index fef682858..06f0b9253 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -5,18 +5,18 @@ from uuid import UUID +import asyncpg from beartype import beartype -from sqlglot import parse_one from fastapi import HTTPException -import asyncpg +from sqlglot import parse_one from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter from ..utils import ( + partialclass, pg_query, - wrap_in_class, rewrap_exceptions, - partialclass, + wrap_in_class, ) # Define the raw SQL query diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index 5e33fdddd..4d19229d8 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -5,17 +5,18 @@ from uuid import UUID +import asyncpg from beartype import beartype -from sqlglot import parse_one from fastapi import HTTPException -import asyncpg +from sqlglot import parse_one + from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( + partialclass, pg_query, - wrap_in_class, rewrap_exceptions, - partialclass, + wrap_in_class, ) # Define the raw SQL query diff --git a/agents-api/agents_api/queries/developers/update_developer.py b/agents-api/agents_api/queries/developers/update_developer.py index 659dcb111..8f3e7cd87 100644 --- a/agents-api/agents_api/queries/developers/update_developer.py +++ b/agents-api/agents_api/queries/developers/update_developer.py @@ -33,7 +33,7 @@ HTTPException, status_code=409, detail="A developer with this email already exists.", - ) + ), } ) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) diff --git a/agents-api/agents_api/queries/files/get_file.py b/agents-api/agents_api/queries/files/get_file.py index 882a93ab7..04ba8ea71 100644 --- a/agents-api/agents_api/queries/files/get_file.py +++ b/agents-api/agents_api/queries/files/get_file.py @@ -6,14 +6,18 @@ from typing import Literal from uuid import UUID -from beartype import beartype -from sqlglot import parse_one import asyncpg +from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one from ...autogen.openapi_model import File -from ..utils import pg_query, wrap_in_class, rewrap_exceptions, partialclass, partialclass - +from ..utils import ( + partialclass, + pg_query, + rewrap_exceptions, + wrap_in_class, +) # Define the raw SQL query file_query = parse_one(""" diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py index 7908bf37d..d3866dacc 100644 --- a/agents-api/agents_api/queries/files/list_files.py +++ b/agents-api/agents_api/queries/files/list_files.py @@ -6,13 +6,13 @@ from typing import Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -import asyncpg from ...autogen.openapi_model import File -from ..utils import pg_query, wrap_in_class, rewrap_exceptions, partialclass +from ..utils import partialclass, pg_query, 
rewrap_exceptions, wrap_in_class # Base query for listing files base_files_query = parse_one(""" @@ -22,6 +22,7 @@ WHERE f.developer_id = $1 """).sql(pretty=True) + @rewrap_exceptions( { asyncpg.ForeignKeyViolationError: partialclass( diff --git a/agents-api/agents_api/queries/sessions/list_sessions.py b/agents-api/agents_api/queries/sessions/list_sessions.py index c113c0192..ac3573e61 100644 --- a/agents-api/agents_api/queries/sessions/list_sessions.py +++ b/agents-api/agents_api/queries/sessions/list_sessions.py @@ -49,6 +49,7 @@ LIMIT $2 OFFSET $6; """ + @rewrap_exceptions( { asyncpg.ForeignKeyViolationError: partialclass( From 831e950ead49c33eaed6972ff47f29067f8dac81 Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Fri, 20 Dec 2024 16:40:38 -0500 Subject: [PATCH 07/10] chore: added embedding reading + doctrings updates --- .../agents_api/queries/docs/create_doc.py | 13 +++++++ .../agents_api/queries/docs/delete_doc.py | 9 +++++ .../agents_api/queries/docs/embed_snippets.py | 37 +++++++++++++++++++ agents-api/agents_api/queries/docs/get_doc.py | 26 ++++++++++--- .../agents_api/queries/docs/list_docs.py | 29 ++++++++++----- .../queries/docs/search_docs_by_embedding.py | 29 ++++++++++----- .../queries/docs/search_docs_by_text.py | 29 ++++++++++----- .../queries/entries/create_entries.py | 22 +++++++++++ .../queries/entries/delete_entries.py | 11 +++++- .../agents_api/queries/entries/get_history.py | 10 +++++ .../queries/entries/list_entries.py | 15 ++++++++ 11 files changed, 194 insertions(+), 36 deletions(-) diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py index bf789fad2..59fd40004 100644 --- a/agents-api/agents_api/queries/docs/create_doc.py +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -107,6 +107,19 @@ async def create_doc( ) -> list[tuple[str, list] | tuple[str, list, str]]: """ Insert a new doc record into Timescale and optionally associate it with an owner. + + Parameters: + owner_type (Literal["user", "agent"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents. + modality (Literal["text", "image", "mixed"]): The modality of the documents. + embedding_model (str): The model used for embedding. + embedding_dimensions (int): The dimensions of the embedding. + language (str): The language of the documents. + index (int): The index of the documents. + data (CreateDocRequest): The data for the document. + + Returns: + list[tuple[str, list] | tuple[str, list, str]]: SQL query and parameters for creating the document. """ # Generate a UUID if not provided doc_id = doc_id or uuid7() diff --git a/agents-api/agents_api/queries/docs/delete_doc.py b/agents-api/agents_api/queries/docs/delete_doc.py index adeb09bd8..5697ca8d6 100644 --- a/agents-api/agents_api/queries/docs/delete_doc.py +++ b/agents-api/agents_api/queries/docs/delete_doc.py @@ -67,6 +67,15 @@ async def delete_doc( """ Deletes a doc (and associated doc_owners) for the given developer and doc_id. If owner_type/owner_id is specified, only remove doc if that matches. + + Parameters: + developer_id (UUID): The ID of the developer. + doc_id (UUID): The ID of the document. + owner_type (Literal["user", "agent"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents. + + Returns: + tuple[str, list]: SQL query and parameters for deleting the document. 
""" return ( delete_doc_query, diff --git a/agents-api/agents_api/queries/docs/embed_snippets.py b/agents-api/agents_api/queries/docs/embed_snippets.py index e69de29bb..1a20d6a34 100644 --- a/agents-api/agents_api/queries/docs/embed_snippets.py +++ b/agents-api/agents_api/queries/docs/embed_snippets.py @@ -0,0 +1,37 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from sqlglot import parse_one + +from ..utils import pg_query + +# TODO: This is a placeholder for the actual query +vectorizer_query = None + + +@pg_query +@beartype +async def embed_snippets( + *, + developer_id: UUID, + doc_id: UUID, + owner_type: Literal["user", "agent"] | None = None, + owner_id: UUID | None = None, +) -> tuple[str, list]: + """ + Trigger the vectorizer to generate embeddings for documents. + + Parameters: + developer_id (UUID): The ID of the developer. + doc_id (UUID): The ID of the document. + owner_type (Literal["user", "agent"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents. + + Returns: + tuple[str, list]: SQL query and parameters for embedding the snippets. + """ + return ( + vectorizer_query, + [developer_id, doc_id, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index b46563dbb..8575f77b0 100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -8,10 +8,15 @@ from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class -doc_query = parse_one(""" -SELECT d.* +# Combined query to fetch document details and embedding +doc_with_embedding_query = parse_one(""" +SELECT d.*, e.embedding FROM docs d -LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id +LEFT JOIN doc_owners doc_own + ON d.developer_id = doc_own.developer_id + AND d.doc_id = doc_own.doc_id +LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id WHERE d.developer_id = $1 AND d.doc_id = $2 AND ( @@ -31,7 +36,7 @@ "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), - # "embeddings": d["embeddings"], + "embedding": d["embedding"], # Add embedding to the transformation }, ) @pg_query @@ -44,9 +49,18 @@ async def get_doc( owner_id: UUID | None = None, ) -> tuple[str, list]: """ - Fetch a single doc, optionally constrained to a given owner. + Fetch a single doc with its embedding, optionally constrained to a given owner. + + Parameters: + developer_id (UUID): The ID of the developer. + doc_id (UUID): The ID of the document. + owner_type (Literal["user", "agent"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents. + + Returns: + tuple[str, list]: SQL query and parameters for fetching the document. 
""" return ( - doc_query, + doc_with_embedding_query, [developer_id, doc_id, owner_type, owner_id], ) diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index 92cbacf7f..8ea196958 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -9,11 +9,12 @@ from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class -# Base query for listing docs +# Base query for listing docs with optional embeddings base_docs_query = parse_one(""" -SELECT d.* +SELECT d.*, CASE WHEN $2 THEN NULL ELSE e.embedding END AS embedding FROM docs d LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id +LEFT JOIN docs_embeddings e ON d.doc_id = e.doc_id WHERE d.developer_id = $1 """).sql(pretty=True) @@ -27,7 +28,7 @@ "content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]), - # "embeddings": d["embeddings"], + "embedding": d.get("embedding"), # Add embedding to the transformation }, ) @pg_query @@ -46,6 +47,20 @@ async def list_docs( ) -> tuple[str, list]: """ Lists docs with optional owner filtering, pagination, and sorting. + + Parameters: + developer_id (UUID): The ID of the developer. + owner_id (UUID): The ID of the owner of the documents. + owner_type (Literal["user", "agent"]): The type of the owner of the documents. + limit (int): The number of documents to return. + offset (int): The number of documents to skip. + sort_by (Literal["created_at", "updated_at"]): The field to sort by. + direction (Literal["asc", "desc"]): The direction to sort by. + metadata_filter (dict[str, Any]): The metadata filter to apply. + include_without_embeddings (bool): Whether to include documents without embeddings. + + Returns: + tuple[str, list]: SQL query and parameters for listing the documents. """ if direction.lower() not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Invalid sort direction") @@ -61,11 +76,11 @@ async def list_docs( # Start with the base query query = base_docs_query - params = [developer_id] + params = [developer_id, include_without_embeddings] # Add owner filtering if owner_type and owner_id: - query += " AND doc_own.owner_type = $2 AND doc_own.owner_id = $3" + query += " AND doc_own.owner_type = $3 AND doc_own.owner_id = $4" params.extend([owner_type, owner_id]) # Add metadata filtering @@ -74,10 +89,6 @@ async def list_docs( query += f" AND d.metadata->>'{key}' = ${len(params) + 1}" params.append(value) - # Include or exclude documents without embeddings - # if not include_without_embeddings: - # query += " AND d.embeddings IS NOT NULL" - # Add sorting and pagination query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}" params.extend([limit, offset]) diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py index e3120bd36..c7b15ee64 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -9,7 +9,7 @@ from fastapi import HTTPException from sqlglot import parse_one -from ...autogen.openapi_model import Doc +from ...autogen.openapi_model import DocReference from ..utils import pg_query, wrap_in_class # If you're doing approximate ANN (DiskANN) or IVF, you might use a special function or hint. 
@@ -33,11 +33,14 @@ @wrap_in_class( - Doc, - one=False, - transform=lambda rec: { - **rec, - "id": rec["doc_id"], + DocReference, + transform=lambda d: { + "owner": { + "id": d["owner_id"], + "role": d["owner_type"], + }, + "metadata": d.get("metadata", {}), + **d, }, ) @pg_query @@ -52,10 +55,16 @@ async def search_docs_by_embedding( ) -> tuple[str, list]: """ Vector-based doc search: - - developer_id is required - - query_embedding: the vector to query - - k: number of results to return - - owner_type/owner_id: optional doc ownership filter + + Parameters: + developer_id (UUID): The ID of the developer. + query_embedding (List[float]): The vector to query. + k (int): The number of results to return. + owner_type (Literal["user", "agent", "org"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents. + + Returns: + tuple[str, list]: SQL query and parameters for searching the documents. """ if k < 1: raise HTTPException(status_code=400, detail="k must be >= 1") diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py index 9f434d438..0ab309ee8 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_text.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -9,7 +9,7 @@ from fastapi import HTTPException from sqlglot import parse_one -from ...autogen.openapi_model import Doc +from ...autogen.openapi_model import DocReference from ..utils import pg_query, wrap_in_class search_docs_text_query = parse_one(""" @@ -31,11 +31,14 @@ @wrap_in_class( - Doc, - one=False, - transform=lambda rec: { - **rec, - "id": rec["doc_id"], + DocReference, + transform=lambda d: { + "owner": { + "id": d["owner_id"], + "role": d["owner_type"], + }, + "metadata": d.get("metadata", {}), + **d, }, ) @pg_query @@ -50,10 +53,16 @@ async def search_docs_by_text( ) -> tuple[str, list]: """ Full-text search on docs using the search_tsv column. - - developer_id: required - - query: the text to look for - - k: max results - - owner_type / owner_id: optional doc ownership filter + + Parameters: + developer_id (UUID): The ID of the developer. + query (str): The text to search for. + k (int): The number of results to return. + owner_type (Literal["user", "agent", "org"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents. + + Returns: + tuple[str, list]: SQL query and parameters for searching the documents. """ if k < 1: raise HTTPException(status_code=400, detail="k must be >= 1") diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index 95973ad0b..d8439fa21 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -94,6 +94,17 @@ async def create_entries( session_id: UUID, data: list[CreateEntryRequest], ) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """ + Create entries in a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + data (list[CreateEntryRequest]): The list of entries to create. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for creating the entries. 
+ """ # Convert the data to a list of dictionaries data_dicts = [item.model_dump(mode="json") for item in data] @@ -163,6 +174,17 @@ async def add_entry_relations( session_id: UUID, data: list[Relation], ) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """ + Add relations between entries in a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + data (list[Relation]): The list of relations to add. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for adding the relations. + """ # Convert the data to a list of dictionaries data_dicts = [item.model_dump(mode="json") for item in data] diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py index 47b7379a4..14a9648e5 100644 --- a/agents-api/agents_api/queries/entries/delete_entries.py +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -134,7 +134,16 @@ async def delete_entries_for_session( async def delete_entries( *, developer_id: UUID, session_id: UUID, entry_ids: list[UUID] ) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: - """Delete specific entries by their IDs.""" + """Delete specific entries by their IDs. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + entry_ids (list[UUID]): The IDs of the entries to delete. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for deleting the entries. + """ return [ ( session_exists_query, diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index ffa0746c0..6a734d4c5 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -95,6 +95,16 @@ async def get_history( session_id: UUID, allowed_sources: list[str] = ["api_request", "api_response"], ) -> tuple[str, list] | tuple[str, list, str]: + """Get the history of a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + allowed_sources (list[str]): The sources to include in the history. + + Returns: + tuple[str, list] | tuple[str, list, str]: SQL query and parameters for getting the history. + """ return ( history_query, [session_id, allowed_sources, developer_id], diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index 89f432734..0153fe778 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -88,6 +88,21 @@ async def list_entries( direction: Literal["asc", "desc"] = "asc", exclude_relations: list[str] = [], ) -> list[tuple[str, list] | tuple[str, list, str]]: + """List entries in a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + allowed_sources (list[str]): The sources to include in the history. + limit (int): The number of entries to return. + offset (int): The number of entries to skip. + sort_by (Literal["created_at", "timestamp"]): The field to sort by. + direction (Literal["asc", "desc"]): The direction to sort by. + exclude_relations (list[str]): The relations to exclude. 
+ + Returns: + tuple[str, list] | tuple[str, list, str]: SQL query and parameters for listing the entries. + """ if limit < 1 or limit > 1000: raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") if offset < 0: From 74add36fd068a2c16942feb74c91d0cf3541489f Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Fri, 20 Dec 2024 21:41:35 +0000 Subject: [PATCH 08/10] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/docs/list_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index 8ea196958..bfbc2971e 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -48,7 +48,7 @@ async def list_docs( """ Lists docs with optional owner filtering, pagination, and sorting. - Parameters: + Parameters: developer_id (UUID): The ID of the developer. owner_id (UUID): The ID of the owner of the documents. owner_type (Literal["user", "agent"]): The type of the owner of the documents. From 249513d6c944f77ff579cb4cd7e51b362483178f Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Sat, 21 Dec 2024 03:12:06 -0500 Subject: [PATCH 09/10] chore: updated migrations + added indices support --- .../queries/developers/get_developer.py | 9 +- .../agents_api/queries/docs/__init__.py | 6 +- .../agents_api/queries/docs/create_doc.py | 141 +++++++++++++----- .../agents_api/queries/docs/delete_doc.py | 24 ++- .../agents_api/queries/docs/embed_snippets.py | 37 ----- agents-api/agents_api/queries/docs/get_doc.py | 68 +++++---- .../agents_api/queries/docs/list_docs.py | 96 ++++++++---- .../queries/docs/search_docs_by_embedding.py | 4 - .../queries/docs/search_docs_by_text.py | 76 ++++++---- .../queries/docs/search_docs_hybrid.py | 5 - agents-api/tests/fixtures.py | 23 +-- agents-api/tests/test_docs_queries.py | 72 ++++----- agents-api/tests/test_files_queries.py | 2 +- memory-store/migrations/000006_docs.up.sql | 9 +- .../migrations/000018_doc_search.up.sql | 57 +++---- 15 files changed, 349 insertions(+), 280 deletions(-) delete mode 100644 agents-api/agents_api/queries/docs/embed_snippets.py diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py index 373a2fb36..79b6e6067 100644 --- a/agents-api/agents_api/queries/developers/get_developer.py +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -24,9 +24,6 @@ SELECT * FROM developers WHERE developer_id = $1 -- developer_id """).sql(pretty=True) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - @rewrap_exceptions( { @@ -37,7 +34,11 @@ ) } ) -@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) +@wrap_in_class( + Developer, + one=True, + transform=lambda d: {**d, "id": d["developer_id"]}, +) @pg_query @beartype async def get_developer( diff --git a/agents-api/agents_api/queries/docs/__init__.py b/agents-api/agents_api/queries/docs/__init__.py index 75f9516a6..51bab2555 100644 --- a/agents-api/agents_api/queries/docs/__init__.py +++ b/agents-api/agents_api/queries/docs/__init__.py @@ -8,6 +8,7 @@ - Listing documents based on various criteria, including ownership and metadata filters. - Deleting documents by their unique identifiers. - Embedding document snippets for retrieval purposes. +- Searching documents by text. 
The module interacts with other parts of the application, such as the agents and users modules, to provide a comprehensive document management system. Its role is crucial in enabling document search, retrieval, and management features within the context of agents and users. @@ -22,12 +23,13 @@ from .list_docs import list_docs # from .search_docs_by_embedding import search_docs_by_embedding -# from .search_docs_by_text import search_docs_by_text +from .search_docs_by_text import search_docs_by_text __all__ = [ "create_doc", "delete_doc", "get_doc", "list_docs", - # "search_docs_by_embct", + # "search_docs_by_embedding", + "search_docs_by_text", ] diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py index 59fd40004..d8bcce7d3 100644 --- a/agents-api/agents_api/queries/docs/create_doc.py +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -47,15 +47,38 @@ INSERT INTO doc_owners ( developer_id, doc_id, + index, owner_type, owner_id ) - VALUES ($1, $2, $3, $4) + VALUES ($1, $2, $3, $4, $5) RETURNING doc_id ) -SELECT d.* +SELECT DISTINCT ON (docs.doc_id) + docs.doc_id, + docs.developer_id, + docs.title, + array_agg(docs.content ORDER BY docs.index) as content, + array_agg(docs.index ORDER BY docs.index) as indices, + docs.modality, + docs.embedding_model, + docs.embedding_dimensions, + docs.language, + docs.metadata, + docs.created_at + FROM inserted_owner io -JOIN docs d ON d.doc_id = io.doc_id; +JOIN docs ON docs.doc_id = io.doc_id +GROUP BY + docs.doc_id, + docs.developer_id, + docs.title, + docs.modality, + docs.embedding_model, + docs.embedding_dimensions, + docs.language, + docs.metadata, + docs.created_at; """).sql(pretty=True) @@ -82,11 +105,10 @@ Doc, one=True, transform=lambda d: { - **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] - if len(ast.literal_eval(d["content"])) == 1 - else ast.literal_eval(d["content"]), + "index": d["indices"][0], + "content": d["content"][0] if len(d["content"]) == 1 else d["content"], + **d, }, ) @increase_counter("create_doc") @@ -97,56 +119,99 @@ async def create_doc( developer_id: UUID, doc_id: UUID | None = None, data: CreateDocRequest, - owner_type: Literal["user", "agent"] | None = None, - owner_id: UUID | None = None, + owner_type: Literal["user", "agent"], + owner_id: UUID, modality: Literal["text", "image", "mixed"] | None = "text", embedding_model: str | None = "voyage-3", embedding_dimensions: int | None = 1024, language: str | None = "english", index: int | None = 0, -) -> list[tuple[str, list] | tuple[str, list, str]]: +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: """ - Insert a new doc record into Timescale and optionally associate it with an owner. + Insert a new doc record into Timescale and associate it with an owner. Parameters: - owner_type (Literal["user", "agent"]): The type of the owner of the documents. - owner_id (UUID): The ID of the owner of the documents. + developer_id (UUID): The ID of the developer. + doc_id (UUID | None): Optional custom UUID for the document. If not provided, one will be generated. + data (CreateDocRequest): The data for the document. + owner_type (Literal["user", "agent"]): The type of the owner (required). + owner_id (UUID): The ID of the owner (required). modality (Literal["text", "image", "mixed"]): The modality of the documents. embedding_model (str): The model used for embedding. embedding_dimensions (int): The dimensions of the embedding. language (str): The language of the documents. 
index (int): The index of the documents. - data (CreateDocRequest): The data for the document. Returns: list[tuple[str, list] | tuple[str, list, str]]: SQL query and parameters for creating the document. """ + queries = [] # Generate a UUID if not provided - doc_id = doc_id or uuid7() + current_doc_id = uuid7() if doc_id is None else doc_id - # check if content is a string - if isinstance(data.content, str): - data.content = [data.content] + # Check if content is a list + if isinstance(data.content, list): + final_params_doc = [] + final_params_owner = [] + + for idx, content in enumerate(data.content): + doc_params = [ + developer_id, + current_doc_id, + data.title, + content, + idx, + modality, + embedding_model, + embedding_dimensions, + language, + data.metadata or {}, + ] + final_params_doc.append(doc_params) - # Create the doc record - doc_params = [ - developer_id, - doc_id, - data.title, - str(data.content), - index, - modality, - embedding_model, - embedding_dimensions, - language, - data.metadata or {}, - ] - - queries = [(doc_query, doc_params)] - - # If an owner is specified, associate it: - if owner_type and owner_id: - owner_params = [developer_id, doc_id, owner_type, owner_id] - queries.append((doc_owner_query, owner_params)) + owner_params = [ + developer_id, + current_doc_id, + idx, + owner_type, + owner_id, + ] + final_params_owner.append(owner_params) + + # Add the doc query for each content + queries.append((doc_query, final_params_doc, "fetchmany")) + + # Add the owner query + queries.append((doc_owner_query, final_params_owner, "fetchmany")) + + else: + + # Create the doc record + doc_params = [ + developer_id, + current_doc_id, + data.title, + data.content, + index, + modality, + embedding_model, + embedding_dimensions, + language, + data.metadata or {}, + ] + + owner_params = [ + developer_id, + current_doc_id, + index, + owner_type, + owner_id, + ] + + # Add the doc query for single content + queries.append((doc_query, doc_params, "fetch")) + + # Add the owner query + queries.append((doc_owner_query, owner_params, "fetch")) return queries diff --git a/agents-api/agents_api/queries/docs/delete_doc.py b/agents-api/agents_api/queries/docs/delete_doc.py index 5697ca8d6..b0a9ea1a1 100644 --- a/agents-api/agents_api/queries/docs/delete_doc.py +++ b/agents-api/agents_api/queries/docs/delete_doc.py @@ -16,22 +16,18 @@ DELETE FROM doc_owners WHERE developer_id = $1 AND doc_id = $2 - AND ( - ($3::text IS NULL AND $4::uuid IS NULL) - OR (owner_type = $3 AND owner_id = $4) - ) + AND owner_type = $3 + AND owner_id = $4 ) DELETE FROM docs WHERE developer_id = $1 AND doc_id = $2 - AND ( - $3::text IS NULL OR EXISTS ( - SELECT 1 FROM doc_owners - WHERE developer_id = $1 - AND doc_id = $2 - AND owner_type = $3 - AND owner_id = $4 - ) + AND EXISTS ( + SELECT 1 FROM doc_owners + WHERE developer_id = $1 + AND doc_id = $2 + AND owner_type = $3 + AND owner_id = $4 ) RETURNING doc_id; """).sql(pretty=True) @@ -61,8 +57,8 @@ async def delete_doc( *, developer_id: UUID, doc_id: UUID, - owner_type: Literal["user", "agent"] | None = None, - owner_id: UUID | None = None, + owner_type: Literal["user", "agent"], + owner_id: UUID, ) -> tuple[str, list]: """ Deletes a doc (and associated doc_owners) for the given developer and doc_id. 
diff --git a/agents-api/agents_api/queries/docs/embed_snippets.py b/agents-api/agents_api/queries/docs/embed_snippets.py deleted file mode 100644 index 1a20d6a34..000000000 --- a/agents-api/agents_api/queries/docs/embed_snippets.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Literal -from uuid import UUID - -from beartype import beartype -from sqlglot import parse_one - -from ..utils import pg_query - -# TODO: This is a placeholder for the actual query -vectorizer_query = None - - -@pg_query -@beartype -async def embed_snippets( - *, - developer_id: UUID, - doc_id: UUID, - owner_type: Literal["user", "agent"] | None = None, - owner_id: UUID | None = None, -) -> tuple[str, list]: - """ - Trigger the vectorizer to generate embeddings for documents. - - Parameters: - developer_id (UUID): The ID of the developer. - doc_id (UUID): The ID of the document. - owner_type (Literal["user", "agent"]): The type of the owner of the documents. - owner_id (UUID): The ID of the owner of the documents. - - Returns: - tuple[str, list]: SQL query and parameters for embedding the snippets. - """ - return ( - vectorizer_query, - [developer_id, doc_id, owner_type, owner_id], - ) diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index 8575f77b0..3f071cf87 100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -8,35 +8,51 @@ from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class -# Combined query to fetch document details and embedding +# Update the query to use DISTINCT ON to prevent duplicates doc_with_embedding_query = parse_one(""" -SELECT d.*, e.embedding -FROM docs d -LEFT JOIN doc_owners doc_own - ON d.developer_id = doc_own.developer_id - AND d.doc_id = doc_own.doc_id -LEFT JOIN docs_embeddings e - ON d.doc_id = e.doc_id -WHERE d.developer_id = $1 - AND d.doc_id = $2 - AND ( - ($3::text IS NULL AND $4::uuid IS NULL) - OR (doc_own.owner_type = $3 AND doc_own.owner_id = $4) - ) -LIMIT 1; +WITH doc_data AS ( + SELECT DISTINCT ON (d.doc_id) + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(e.embedding ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at + FROM docs d + LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id + WHERE d.developer_id = $1 + AND d.doc_id = $2 + GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +) +SELECT * FROM doc_data; """).sql(pretty=True) @wrap_in_class( Doc, - one=True, + one=True, # Changed to True since we're now returning one grouped record transform=lambda d: { - **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] - if len(ast.literal_eval(d["content"])) == 1 - else ast.literal_eval(d["content"]), - "embedding": d["embedding"], # Add embedding to the transformation + "index": d["indices"][0], + "content": d["content"][0] if len(d["content"]) == 1 else d["content"], + "embeddings": d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"], + **d, }, ) @pg_query @@ -45,22 +61,18 @@ async def get_doc( *, developer_id: UUID, doc_id: UUID, - owner_type: Literal["user", "agent"] | None = None, - owner_id: UUID | None = None, ) -> tuple[str, list]: """ - Fetch a single doc with its embedding, optionally constrained 
to a given owner. - + Fetch a single doc with its embedding, grouping all content chunks and embeddings. + Parameters: developer_id (UUID): The ID of the developer. doc_id (UUID): The ID of the document. - owner_type (Literal["user", "agent"]): The type of the owner of the documents. - owner_id (UUID): The ID of the owner of the documents. Returns: tuple[str, list]: SQL query and parameters for fetching the document. """ return ( doc_with_embedding_query, - [developer_id, doc_id, owner_type, owner_id], + [developer_id, doc_id], ) diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index bfbc2971e..2b31df250 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -1,34 +1,82 @@ -import ast +""" +This module contains the functionality for listing documents from the PostgreSQL database. +It constructs and executes SQL queries to fetch document details based on various filters. +""" + from typing import Any, Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import Doc -from ..utils import pg_query, wrap_in_class +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Base query for listing docs with optional embeddings +# Base query for listing docs with aggregated content and embeddings base_docs_query = parse_one(""" -SELECT d.*, CASE WHEN $2 THEN NULL ELSE e.embedding END AS embedding -FROM docs d -LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id -LEFT JOIN docs_embeddings e ON d.doc_id = e.doc_id -WHERE d.developer_id = $1 +WITH doc_data AS ( + SELECT DISTINCT ON (d.doc_id) + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(CASE WHEN $2 THEN NULL ELSE e.embedding END ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at + FROM docs d + JOIN doc_owners doc_own + ON d.developer_id = doc_own.developer_id + AND d.doc_id = doc_own.doc_id + LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id + WHERE d.developer_id = $1 + AND doc_own.owner_type = $3 + AND doc_own.owner_id = $4 + GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +) +SELECT * FROM doc_data """).sql(pretty=True) +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No documents found", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), + } +) @wrap_in_class( Doc, one=False, transform=lambda d: { - **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] - if len(ast.literal_eval(d["content"])) == 1 - else ast.literal_eval(d["content"]), - "embedding": d.get("embedding"), # Add embedding to the transformation + "index": d["indices"][0], + "content": d["content"][0] if len(d["content"]) == 1 else d["content"], + "embedding": d["embeddings"][0] if d.get("embeddings") and len(d["embeddings"]) == 1 else d.get("embeddings"), + **d, }, ) @pg_query @@ -36,8 +84,8 @@ async def list_docs( *, developer_id: UUID, - owner_id: UUID | None = None, - owner_type: 
Literal["user", "agent"] | None = None, + owner_id: UUID, + owner_type: Literal["user", "agent"], limit: int = 100, offset: int = 0, sort_by: Literal["created_at", "updated_at"] = "created_at", @@ -46,12 +94,12 @@ async def list_docs( include_without_embeddings: bool = False, ) -> tuple[str, list]: """ - Lists docs with optional owner filtering, pagination, and sorting. + Lists docs with pagination and sorting, aggregating content chunks and embeddings. Parameters: developer_id (UUID): The ID of the developer. - owner_id (UUID): The ID of the owner of the documents. - owner_type (Literal["user", "agent"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents (required). + owner_type (Literal["user", "agent"]): The type of the owner of the documents (required). limit (int): The number of documents to return. offset (int): The number of documents to skip. sort_by (Literal["created_at", "updated_at"]): The field to sort by. @@ -61,6 +109,9 @@ async def list_docs( Returns: tuple[str, list]: SQL query and parameters for listing the documents. + + Raises: + HTTPException: If invalid parameters are provided. """ if direction.lower() not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Invalid sort direction") @@ -76,17 +127,12 @@ async def list_docs( # Start with the base query query = base_docs_query - params = [developer_id, include_without_embeddings] - - # Add owner filtering - if owner_type and owner_id: - query += " AND doc_own.owner_type = $3 AND doc_own.owner_id = $4" - params.extend([owner_type, owner_id]) + params = [developer_id, include_without_embeddings, owner_type, owner_id] # Add metadata filtering if metadata_filter: for key, value in metadata_filter.items(): - query += f" AND d.metadata->>'{key}' = ${len(params) + 1}" + query += f" AND metadata->>'{key}' = ${len(params) + 1}" params.append(value) # Add sorting and pagination diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py index c7b15ee64..5a89803ee 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -1,7 +1,3 @@ -""" -Timescale-based doc embedding search using the `embedding` column. -""" - from typing import List, Literal from uuid import UUID diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py index 0ab309ee8..79f9ac305 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_text.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -1,35 +1,36 @@ -""" -Timescale-based doc text search using the `search_tsv` column. 
-""" - -from typing import Literal +from typing import Any, Literal, List from uuid import UUID from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one +import asyncpg +import json from ...autogen.openapi_model import DocReference -from ..utils import pg_query, wrap_in_class +from ..utils import pg_query, wrap_in_class, rewrap_exceptions, partialclass -search_docs_text_query = parse_one(""" -SELECT d.*, - ts_rank_cd(d.search_tsv, websearch_to_tsquery($3)) AS rank -FROM docs d -LEFT JOIN doc_owners do - ON d.developer_id = do.developer_id - AND d.doc_id = do.doc_id -WHERE d.developer_id = $1 - AND ( - ($4 IS NULL AND $5 IS NULL) - OR (do.owner_type = $4 AND do.owner_id = $5) - ) - AND d.search_tsv @@ websearch_to_tsquery($3) -ORDER BY rank DESC -LIMIT $2; -""").sql(pretty=True) +search_docs_text_query = ( + """ + SELECT * FROM search_by_text( + $1, -- developer_id + $2, -- query + $3, -- owner_types + ( SELECT array_agg(*)::UUID[] FROM jsonb_array_elements($4) ) + ) + """ +) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class( DocReference, transform=lambda d: { @@ -41,15 +42,16 @@ **d, }, ) -@pg_query +@pg_query(debug=True) @beartype async def search_docs_by_text( *, developer_id: UUID, + owners: list[tuple[Literal["user", "agent"], UUID]], query: str, - k: int = 10, - owner_type: Literal["user", "agent", "org"] | None = None, - owner_id: UUID | None = None, + k: int = 3, + metadata_filter: dict[str, Any] = {}, + search_language: str | None = "english", ) -> tuple[str, list]: """ Full-text search on docs using the search_tsv column. @@ -57,9 +59,11 @@ async def search_docs_by_text( Parameters: developer_id (UUID): The ID of the developer. query (str): The text to search for. - k (int): The number of results to return. - owner_type (Literal["user", "agent", "org"]): The type of the owner of the documents. - owner_id (UUID): The ID of the owner of the documents. + owners (list[tuple[Literal["user", "agent"], UUID]]): List of (owner_type, owner_id) tuples. + k (int): Maximum number of results to return. + search_language (str): Language for text search (default: "english"). + metadata_filter (dict): Metadata filter criteria. + connection_pool (asyncpg.Pool): Database connection pool. Returns: tuple[str, list]: SQL query and parameters for searching the documents. @@ -67,7 +71,19 @@ async def search_docs_by_text( if k < 1: raise HTTPException(status_code=400, detail="k must be >= 1") + # Extract owner types and IDs + owner_types = [owner[0] for owner in owners] + owner_ids = [owner[1] for owner in owners] + return ( search_docs_text_query, - [developer_id, k, query, owner_type, owner_id], + [ + developer_id, + query, + owner_types, + owner_ids, + search_language, + k, + metadata_filter, + ], ) diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py index a879e3b6b..184ba7e8e 100644 --- a/agents-api/agents_api/queries/docs/search_docs_hybrid.py +++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py @@ -1,8 +1,3 @@ -""" -Hybrid doc search that merges text search and embedding search results -via a simple distribution-based score fusion or direct weighting in Python. 
-""" - from typing import List, Literal from uuid import UUID diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 2f7de580e..a34c7e2aa 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -63,23 +63,6 @@ def test_developer_id(): developer_id = uuid7() return developer_id - -# @fixture(scope="global") -# async def test_file(dsn=pg_dsn, developer_id=test_developer_id): -# async with get_pg_client(dsn=dsn) as client: -# file = await create_file( -# developer_id=developer_id, -# data=CreateFileRequest( -# name="Hello", -# description="World", -# mime_type="text/plain", -# content="eyJzYW1wbGUiOiAidGVzdCJ9", -# ), -# client=client, -# ) -# yield file - - @fixture(scope="global") async def test_developer(dsn=pg_dsn, developer_id=test_developer_id): pool = await create_db_pool(dsn=dsn) @@ -150,16 +133,18 @@ async def test_file(dsn=pg_dsn, developer=test_developer, user=test_user): @fixture(scope="test") -async def test_doc(dsn=pg_dsn, developer=test_developer): +async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) doc = await create_doc( developer_id=developer.id, data=CreateDocRequest( title="Hello", - content=["World"], + content=["World", "World2", "World3"], metadata={"test": "test"}, embed_instruction="Embed the document", ), + owner_type="agent", + owner_id=agent.id, connection_pool=pool, ) return doc diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index 1410c88c9..71553ee83 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -8,36 +8,13 @@ from agents_api.queries.docs.list_docs import list_docs # If you wish to test text/embedding/hybrid search, import them: -# from agents_api.queries.docs.search_docs_by_text import search_docs_by_text +from agents_api.queries.docs.search_docs_by_text import search_docs_by_text # from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding # from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid # You can rename or remove these imports to match your actual fixtures from tests.fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user -@test("query: create doc") -async def _(dsn=pg_dsn, developer=test_developer): - pool = await create_db_pool(dsn=dsn) - doc = await create_doc( - developer_id=developer.id, - data=CreateDocRequest( - title="Hello Doc", - content="This is sample doc content", - embed_instruction="Embed the document", - metadata={"test": "test"}, - ), - connection_pool=pool, - ) - - assert doc.title == "Hello Doc" - assert doc.content == "This is sample doc content" - assert doc.modality == "text" - assert doc.embedding_model == "voyage-3" - assert doc.embedding_dimensions == 1024 - assert doc.language == "english" - assert doc.index == 0 - - @test("query: create user doc") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -92,7 +69,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): assert any(d.id == doc.id for d in docs_list) -@test("model: get doc") +@test("query: get doc") async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): pool = await create_db_pool(dsn=dsn) doc_test = await get_doc( @@ -102,18 +79,7 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): ) assert doc_test.id == doc.id assert doc_test.title == doc.title - - -@test("query: list docs") -async def _(dsn=pg_dsn, 
developer=test_developer, doc=test_doc): - pool = await create_db_pool(dsn=dsn) - docs_list = await list_docs( - developer_id=developer.id, - connection_pool=pool, - ) - assert len(docs_list) >= 1 - assert any(d.id == doc.id for d in docs_list) - + assert doc_test.content == doc.content @test("query: list user docs") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): @@ -246,12 +212,34 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): ) assert not any(d.id == doc_agent.id for d in docs_list) - -@test("query: delete doc") -async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): +@test("query: search docs by text") +async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): pool = await create_db_pool(dsn=dsn) - await delete_doc( + + # Create a test document + await create_doc( developer_id=developer.id, - doc_id=doc.id, + owner_type="agent", + owner_id=agent.id, + data=CreateDocRequest( + title="Hello", + content="The world is a funny little thing", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), connection_pool=pool, ) + + # Search using the correct parameter types + result = await search_docs_by_text( + developer_id=developer.id, + owners=[("agent", agent.id)], + query="funny", + k=3, # Add k parameter + search_language="english", # Add language parameter + metadata_filter={}, # Add metadata filter + connection_pool=pool, + ) + + assert len(result) >= 1 + assert result[0].metadata is not None \ No newline at end of file diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index c83c7a6f6..68409ef5c 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -82,7 +82,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): assert any(f.id == file.id for f in files) -@test("model: get file") +@test("query: get file") async def _(dsn=pg_dsn, file=test_file, developer=test_developer): pool = await create_db_pool(dsn=dsn) file_test = await get_file( diff --git a/memory-store/migrations/000006_docs.up.sql b/memory-store/migrations/000006_docs.up.sql index 193fae122..97bdad43c 100644 --- a/memory-store/migrations/000006_docs.up.sql +++ b/memory-store/migrations/000006_docs.up.sql @@ -24,8 +24,7 @@ CREATE TABLE IF NOT EXISTS docs ( created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, metadata JSONB NOT NULL DEFAULT '{}'::JSONB, - CONSTRAINT pk_docs PRIMARY KEY (developer_id, doc_id), - CONSTRAINT uq_docs_doc_id_index UNIQUE (doc_id, index), + CONSTRAINT pk_docs PRIMARY KEY (developer_id, doc_id, index), CONSTRAINT ct_docs_embedding_dimensions_positive CHECK (embedding_dimensions > 0), CONSTRAINT ct_docs_valid_modality CHECK (modality IN ('text', 'image', 'mixed')), CONSTRAINT ct_docs_index_positive CHECK (index >= 0), @@ -67,10 +66,12 @@ END $$; CREATE TABLE IF NOT EXISTS doc_owners ( developer_id UUID NOT NULL, doc_id UUID NOT NULL, + index INTEGER NOT NULL, owner_type TEXT NOT NULL, -- 'user' or 'agent' owner_id UUID NOT NULL, - CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id), - CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id), + CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id, index), + -- TODO: Add foreign key constraint + -- CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id), CONSTRAINT ct_doc_owners_owner_type CHECK 
(owner_type IN ('user', 'agent')) ); diff --git a/memory-store/migrations/000018_doc_search.up.sql b/memory-store/migrations/000018_doc_search.up.sql index 5293cc81a..2f5b2baf1 100644 --- a/memory-store/migrations/000018_doc_search.up.sql +++ b/memory-store/migrations/000018_doc_search.up.sql @@ -101,6 +101,7 @@ END $$; -- Create the search function CREATE OR REPLACE FUNCTION search_by_vector ( + developer_id UUID, query_embedding vector (1024), owner_types TEXT[], owner_ids UUID [], @@ -134,9 +135,7 @@ BEGIN IF owner_types IS NOT NULL AND owner_ids IS NOT NULL THEN owner_filter_sql := ' AND ( - (ud.user_id = ANY($5) AND ''user'' = ANY($4)) - OR - (ad.agent_id = ANY($5) AND ''agent'' = ANY($4)) + doc_owners.owner_id = ANY($5::uuid[]) AND doc_owners.owner_type = ANY($4::text[]) )'; ELSE owner_filter_sql := ''; @@ -153,6 +152,7 @@ BEGIN RETURN QUERY EXECUTE format( 'WITH ranked_docs AS ( SELECT + d.developer_id, d.doc_id, d.index, d.title, @@ -160,15 +160,12 @@ BEGIN (1 - (d.embedding <=> $1)) as distance, d.embedding, d.metadata, - CASE - WHEN ud.user_id IS NOT NULL THEN ''user'' - WHEN ad.agent_id IS NOT NULL THEN ''agent'' - END as owner_type, - COALESCE(ud.user_id, ad.agent_id) as owner_id + doc_owners.owner_type, + doc_owners.owner_id FROM docs_embeddings d - LEFT JOIN user_docs ud ON d.doc_id = ud.doc_id - LEFT JOIN agent_docs ad ON d.doc_id = ad.doc_id - WHERE 1 - (d.embedding <=> $1) >= $2 + LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id + WHERE d.developer_id = $7 + AND 1 - (d.embedding <=> $1) >= $2 %s %s ) @@ -185,7 +182,9 @@ BEGIN k, owner_types, owner_ids, - metadata_filter; + metadata_filter, + developer_id; + END; $$; @@ -238,6 +237,7 @@ COMMENT ON FUNCTION embed_and_search_by_vector IS 'Convenience function that com -- Create the text search function CREATE OR REPLACE FUNCTION search_by_text ( + developer_id UUID, query_text text, owner_types TEXT[], owner_ids UUID [], @@ -267,9 +267,7 @@ BEGIN IF owner_types IS NOT NULL AND owner_ids IS NOT NULL THEN owner_filter_sql := ' AND ( - (ud.user_id = ANY($5) AND ''user'' = ANY($4)) - OR - (ad.agent_id = ANY($5) AND ''agent'' = ANY($4)) + doc_owners.owner_id = ANY($5::uuid[]) AND doc_owners.owner_type = ANY($4::text[]) )'; ELSE owner_filter_sql := ''; @@ -286,6 +284,7 @@ BEGIN RETURN QUERY EXECUTE format( 'WITH ranked_docs AS ( SELECT + d.developer_id, d.doc_id, d.index, d.title, @@ -293,15 +292,12 @@ BEGIN ts_rank_cd(d.search_tsv, $1, 32)::double precision as distance, d.embedding, d.metadata, - CASE - WHEN ud.user_id IS NOT NULL THEN ''user'' - WHEN ad.agent_id IS NOT NULL THEN ''agent'' - END as owner_type, - COALESCE(ud.user_id, ad.agent_id) as owner_id + doc_owners.owner_type, + doc_owners.owner_id FROM docs_embeddings d - LEFT JOIN user_docs ud ON d.doc_id = ud.doc_id - LEFT JOIN agent_docs ad ON d.doc_id = ad.doc_id - WHERE d.search_tsv @@ $1 + LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id + WHERE d.developer_id = $6 + AND d.search_tsv @@ $1 %s %s ) @@ -314,11 +310,11 @@ BEGIN ) USING ts_query, - search_language, k, owner_types, owner_ids, - metadata_filter; + metadata_filter, + developer_id; END; $$; @@ -372,6 +368,7 @@ $$ LANGUAGE plpgsql; -- Hybrid search function combining text and vector search CREATE OR REPLACE FUNCTION search_hybrid ( + developer_id UUID, query_text text, query_embedding vector (1024), owner_types TEXT[], @@ -397,6 +394,7 @@ BEGIN RETURN QUERY WITH text_results AS ( SELECT * FROM search_by_text( + developer_id, query_text, owner_types, owner_ids, @@ -407,6 +405,7 @@ BEGIN ), 
embedding_results AS ( SELECT * FROM search_by_vector( + developer_id, query_embedding, owner_types, owner_ids, @@ -426,6 +425,7 @@ BEGIN ), scores AS ( SELECT + r.developer_id, r.doc_id, r.title, r.content, @@ -437,8 +437,8 @@ BEGIN COALESCE(t.distance, 0.0) as text_score, COALESCE(e.distance, 0.0) as embedding_score FROM all_results r - LEFT JOIN text_results t ON r.doc_id = t.doc_id - LEFT JOIN embedding_results e ON r.doc_id = e.doc_id + LEFT JOIN text_results t ON r.doc_id = t.doc_id AND r.developer_id = t.developer_id + LEFT JOIN embedding_results e ON r.doc_id = e.doc_id AND r.developer_id = e.developer_id ), normalized_scores AS ( SELECT @@ -448,6 +448,7 @@ BEGIN FROM scores ) SELECT + developer_id, doc_id, index, title, @@ -468,6 +469,7 @@ COMMENT ON FUNCTION search_hybrid IS 'Hybrid search combining text and vector se -- Convenience function that handles embedding generation CREATE OR REPLACE FUNCTION embed_and_search_hybrid ( + developer_id UUID, query_text text, owner_types TEXT[], owner_ids UUID [], @@ -497,6 +499,7 @@ BEGIN -- Perform hybrid search RETURN QUERY SELECT * FROM search_hybrid( + developer_id, query_text, query_embedding, owner_types, From d7d9cd49f83b6606c0c6bd2aa68cd1c044eae5cb Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Sat, 21 Dec 2024 08:13:04 +0000 Subject: [PATCH 10/10] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/docs/create_doc.py | 3 +-- agents-api/agents_api/queries/docs/get_doc.py | 6 ++++-- agents-api/agents_api/queries/docs/list_docs.py | 4 +++- .../queries/docs/search_docs_by_text.py | 16 +++++++--------- agents-api/tests/fixtures.py | 1 + agents-api/tests/test_docs_queries.py | 9 ++++++--- 6 files changed, 22 insertions(+), 17 deletions(-) diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py index d8bcce7d3..d3c2fe3c1 100644 --- a/agents-api/agents_api/queries/docs/create_doc.py +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -153,7 +153,7 @@ async def create_doc( if isinstance(data.content, list): final_params_doc = [] final_params_owner = [] - + for idx, content in enumerate(data.content): doc_params = [ developer_id, @@ -185,7 +185,6 @@ async def create_doc( queries.append((doc_owner_query, final_params_owner, "fetchmany")) else: - # Create the doc record doc_params = [ developer_id, diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index 3f071cf87..1cee8f354 100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -51,7 +51,9 @@ "id": d["doc_id"], "index": d["indices"][0], "content": d["content"][0] if len(d["content"]) == 1 else d["content"], - "embeddings": d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"], + "embeddings": d["embeddings"][0] + if len(d["embeddings"]) == 1 + else d["embeddings"], **d, }, ) @@ -64,7 +66,7 @@ async def get_doc( ) -> tuple[str, list]: """ Fetch a single doc with its embedding, grouping all content chunks and embeddings. - + Parameters: developer_id (UUID): The ID of the developer. doc_id (UUID): The ID of the document. 
diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index 2b31df250..9788b0daa 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -75,7 +75,9 @@ "id": d["doc_id"], "index": d["indices"][0], "content": d["content"][0] if len(d["content"]) == 1 else d["content"], - "embedding": d["embeddings"][0] if d.get("embeddings") and len(d["embeddings"]) == 1 else d.get("embeddings"), + "embedding": d["embeddings"][0] + if d.get("embeddings") and len(d["embeddings"]) == 1 + else d.get("embeddings"), **d, }, ) diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py index 79f9ac305..9c22a60ce 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_text.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -1,17 +1,16 @@ -from typing import Any, Literal, List +import json +from typing import Any, List, Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -import asyncpg -import json from ...autogen.openapi_model import DocReference -from ..utils import pg_query, wrap_in_class, rewrap_exceptions, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -search_docs_text_query = ( - """ +search_docs_text_query = """ SELECT * FROM search_by_text( $1, -- developer_id $2, -- query @@ -19,7 +18,6 @@ ( SELECT array_agg(*)::UUID[] FROM jsonb_array_elements($4) ) ) """ -) @rewrap_exceptions( @@ -74,10 +72,10 @@ async def search_docs_by_text( # Extract owner types and IDs owner_types = [owner[0] for owner in owners] owner_ids = [owner[1] for owner in owners] - + return ( search_docs_text_query, - [ + [ developer_id, query, owner_types, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index a34c7e2aa..2ad6bfeeb 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -63,6 +63,7 @@ def test_developer_id(): developer_id = uuid7() return developer_id + @fixture(scope="global") async def test_developer(dsn=pg_dsn, developer_id=test_developer_id): pool = await create_db_pool(dsn=dsn) diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index 71553ee83..82490cb77 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -9,6 +9,7 @@ # If you wish to test text/embedding/hybrid search, import them: from agents_api.queries.docs.search_docs_by_text import search_docs_by_text + # from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding # from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid # You can rename or remove these imports to match your actual fixtures @@ -81,6 +82,7 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): assert doc_test.title == doc.title assert doc_test.content == doc.content + @test("query: list user docs") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -212,17 +214,18 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): ) assert not any(d.id == doc_agent.id for d in docs_list) + @test("query: search docs by text") async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): pool = await create_db_pool(dsn=dsn) - + # Create a test document await create_doc( developer_id=developer.id, 
owner_type="agent", owner_id=agent.id, data=CreateDocRequest( - title="Hello", + title="Hello", content="The world is a funny little thing", metadata={"test": "test"}, embed_instruction="Embed the document", @@ -242,4 +245,4 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): ) assert len(result) >= 1 - assert result[0].metadata is not None \ No newline at end of file + assert result[0].metadata is not None