fix: fixed the CRD doc queries + added tests

Vedantsahai18 committed Dec 20, 2024
1 parent b427e38 commit 93673b7
Showing 21 changed files with 454 additions and 326 deletions.
24 changes: 24 additions & 0 deletions agents-api/agents_api/autogen/Docs.py
@@ -73,6 +73,30 @@ class Doc(BaseModel):
"""
Embeddings for the document
"""
modality: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None
"""
Modality of the document
"""
language: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None
"""
Language of the document
"""
index: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None
"""
Index of the document
"""
embedding_model: Annotated[
str | None, Field(json_schema_extra={"readOnly": True})
] = None
"""
Embedding model to use for the document
"""
embedding_dimensions: Annotated[
int | None, Field(json_schema_extra={"readOnly": True})
] = None
"""
Dimensions of the embedding model
"""


class DocOwner(BaseModel):
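
All five new attributes share the same shape: optional, defaulting to None, with readOnly recorded via json_schema_extra so it only annotates the generated JSON schema rather than restricting assignment. A standalone sketch of that pattern (ExampleDoc is illustrative and not part of the generated module):

from typing import Annotated
from pydantic import BaseModel, Field

class ExampleDoc(BaseModel):
    # Same field pattern as the generated Doc attributes above
    modality: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None
    embedding_dimensions: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None

doc = ExampleDoc(modality="text")
print(doc.model_dump())  # {'modality': 'text', 'embedding_dimensions': None}
print(ExampleDoc.model_json_schema()["properties"]["modality"].get("readOnly"))  # True
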
13 changes: 10 additions & 3 deletions agents-api/agents_api/queries/docs/__init__.py
@@ -18,8 +18,15 @@

from .create_doc import create_doc
from .delete_doc import delete_doc
from .embed_snippets import embed_snippets
from .get_doc import get_doc
from .list_docs import list_docs
from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_text import search_docs_by_text
# from .search_docs_by_embedding import search_docs_by_embedding
# from .search_docs_by_text import search_docs_by_text

__all__ = [
"create_doc",
"delete_doc",
"get_doc",
"list_docs",
# "search_docs_by_embct",
]
40 changes: 22 additions & 18 deletions agents-api/agents_api/queries/docs/create_doc.py
@@ -1,12 +1,4 @@
"""
Timescale-based creation of docs.
Mirrors the structure of create_file.py, but uses the docs/doc_owners tables.
"""

import base64
import hashlib
from typing import Any, Literal
from typing import Literal
from uuid import UUID

import asyncpg
@@ -15,6 +7,9 @@
from sqlglot import parse_one
from uuid_extensions import uuid7

import ast


from ...autogen.openapi_model import CreateDocRequest, Doc
from ...metrics.counters import increase_counter
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
@@ -91,7 +86,7 @@
transform=lambda d: {
**d,
"id": d["doc_id"],
# You could optionally return a computed hash or partial content if desired
"content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]),
},
)
@increase_counter("create_doc")
@@ -102,26 +97,35 @@ async def create_doc(
developer_id: UUID,
doc_id: UUID | None = None,
data: CreateDocRequest,
owner_type: Literal["user", "agent", "org"] | None = None,
owner_type: Literal["user", "agent"] | None = None,
owner_id: UUID | None = None,
) -> list[tuple[str, list]]:
modality: Literal["text", "image", "mixed"] | None = "text",
embedding_model: str | None = "voyage-3",
embedding_dimensions: int | None = 1024,
language: str | None = "english",
index: int | None = 0,
) -> list[tuple[str, list] | tuple[str, list, str]]:
"""
Insert a new doc record into Timescale and optionally associate it with an owner.
"""
# Generate a UUID if not provided
doc_id = doc_id or uuid7()

# check if content is a string
if isinstance(data.content, str):
data.content = [data.content]

# Create the doc record
doc_params = [
developer_id,
doc_id,
data.title,
data.content,
data.index or 0, # fallback if no snippet index
data.modality or "text",
data.embedding_model or "none",
data.embedding_dimensions or 0,
data.language or "english",
str(data.content),
index,
modality,
embedding_model,
embedding_dimensions,
language,
data.metadata or {},
]

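The new body normalizes string content into a list, stores it via str(...), and the result transform recovers it with ast.literal_eval, unwrapping single-snippet lists back to a plain string. A standalone sketch of that round trip (values are illustrative):

import ast

content = "A single snippet"
if isinstance(content, str):        # same normalization as create_doc
    content = [content]

stored = str(content)               # the value placed into doc_params
parsed = ast.literal_eval(stored)   # what the transform reads back
restored = parsed[0] if len(parsed) == 1 else parsed

print(stored)    # ['A single snippet']
print(restored)  # A single snippet
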
6 changes: 1 addition & 5 deletions agents-api/agents_api/queries/docs/delete_doc.py
@@ -1,7 +1,3 @@
"""
Timescale-based deletion of a doc record.
"""

from typing import Literal
from uuid import UUID

@@ -65,7 +61,7 @@ async def delete_doc(
*,
developer_id: UUID,
doc_id: UUID,
owner_type: Literal["user", "agent", "org"] | None = None,
owner_type: Literal["user", "agent"] | None = None,
owner_id: UUID | None = None,
) -> tuple[str, list]:
"""
15 changes: 6 additions & 9 deletions agents-api/agents_api/queries/docs/get_doc.py
@@ -1,27 +1,22 @@
"""
Timescale-based retrieval of a single doc record.
"""

from typing import Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
import ast

from ...autogen.openapi_model import Doc
from ..utils import pg_query, wrap_in_class

doc_query = parse_one("""
SELECT d.*
FROM docs d
LEFT JOIN doc_owners do ON d.developer_id = do.developer_id AND d.doc_id = do.doc_id
LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id
WHERE d.developer_id = $1
AND d.doc_id = $2
AND (
($3::text IS NULL AND $4::uuid IS NULL)
OR (do.owner_type = $3 AND do.owner_id = $4)
OR (doc_own.owner_type = $3 AND doc_own.owner_id = $4)
)
LIMIT 1;
""").sql(pretty=True)
@@ -33,6 +28,8 @@
transform=lambda d: {
**d,
"id": d["doc_id"],
"content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]),
# "embeddings": d["embeddings"],
},
)
@pg_query
@@ -41,7 +38,7 @@ async def get_doc(
*,
developer_id: UUID,
doc_id: UUID,
owner_type: Literal["user", "agent", "org"] | None = None,
owner_type: Literal["user", "agent"] | None = None,
owner_id: UUID | None = None,
) -> tuple[str, list]:
"""
81 changes: 36 additions & 45 deletions agents-api/agents_api/queries/docs/list_docs.py
@@ -1,52 +1,20 @@
"""
Timescale-based listing of docs with optional owner filter and pagination.
"""

from typing import Literal
from typing import Any, Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
import ast

from ...autogen.openapi_model import Doc
from ..utils import pg_query, wrap_in_class

# Basic listing for all docs by developer
developer_docs_query = parse_one("""
# Base query for listing docs
base_docs_query = parse_one("""
SELECT d.*
FROM docs d
LEFT JOIN doc_owners do ON d.developer_id = do.developer_id AND d.doc_id = do.doc_id
LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id
WHERE d.developer_id = $1
ORDER BY
CASE
WHEN $4 = 'created_at' AND $5 = 'asc' THEN d.created_at
WHEN $4 = 'created_at' AND $5 = 'desc' THEN d.created_at
WHEN $4 = 'updated_at' AND $5 = 'asc' THEN d.updated_at
WHEN $4 = 'updated_at' AND $5 = 'desc' THEN d.updated_at
END DESC NULLS LAST
LIMIT $2
OFFSET $3;
""").sql(pretty=True)

# Listing for docs associated with a specific owner
owner_docs_query = parse_one("""
SELECT d.*
FROM docs d
JOIN doc_owners do ON d.developer_id = do.developer_id AND d.doc_id = do.doc_id
WHERE do.developer_id = $1
AND do.owner_id = $6
AND do.owner_type = $7
ORDER BY
CASE
WHEN $4 = 'created_at' AND $5 = 'asc' THEN d.created_at
WHEN $4 = 'created_at' AND $5 = 'desc' THEN d.created_at
WHEN $4 = 'updated_at' AND $5 = 'asc' THEN d.updated_at
WHEN $4 = 'updated_at' AND $5 = 'desc' THEN d.updated_at
END DESC NULLS LAST
LIMIT $2
OFFSET $3;
""").sql(pretty=True)


@@ -56,6 +24,8 @@
transform=lambda d: {
**d,
"id": d["doc_id"],
"content": ast.literal_eval(d["content"])[0] if len(ast.literal_eval(d["content"])) == 1 else ast.literal_eval(d["content"]),
# "embeddings": d["embeddings"],
},
)
@pg_query
@@ -64,29 +34,50 @@ async def list_docs(
*,
developer_id: UUID,
owner_id: UUID | None = None,
owner_type: Literal["user", "agent", "org"] | None = None,
owner_type: Literal["user", "agent"] | None = None,
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at", "updated_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
metadata_filter: dict[str, Any] = {},
include_without_embeddings: bool = False,
) -> tuple[str, list]:
"""
Lists docs with optional owner filtering, pagination, and sorting.
"""
if direction.lower() not in ["asc", "desc"]:
raise HTTPException(status_code=400, detail="Invalid sort direction")

if sort_by not in ["created_at", "updated_at"]:
raise HTTPException(status_code=400, detail="Invalid sort field")

if limit > 100 or limit < 1:
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")

if offset < 0:
raise HTTPException(status_code=400, detail="Offset must be >= 0")

params = [developer_id, limit, offset, sort_by, direction]
if owner_id and owner_type:
params.extend([owner_id, owner_type])
query = owner_docs_query
else:
query = developer_docs_query
# Start with the base query
query = base_docs_query
params = [developer_id]

# Add owner filtering
if owner_type and owner_id:
query += " AND doc_own.owner_type = $2 AND doc_own.owner_id = $3"
params.extend([owner_type, owner_id])

# Add metadata filtering
if metadata_filter:
for key, value in metadata_filter.items():
query += f" AND d.metadata->>'{key}' = ${len(params) + 1}"
params.append(value)

# Include or exclude documents without embeddings
# if not include_without_embeddings:
# query += " AND d.embeddings IS NOT NULL"

# Add sorting and pagination
query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}"
params.extend([limit, offset])

return (query, params)
return query, params
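
A standalone sketch of how the final statement comes together for an owner-scoped, metadata-filtered call, mirroring the parameter numbering above (the base SQL is abbreviated and the values are placeholders):

base_docs_query = "SELECT d.* FROM docs d LEFT JOIN doc_owners doc_own ON ... WHERE d.developer_id = $1"

query = base_docs_query
params = ["<developer_id>"]

# Owner filter binds $2 and $3
query += " AND doc_own.owner_type = $2 AND doc_own.owner_id = $3"
params.extend(["agent", "<owner_id>"])

# Each metadata key binds the next free placeholder ($4 here)
for key, value in {"topic": "billing"}.items():
    query += f" AND d.metadata->>'{key}' = ${len(params) + 1}"
    params.append(value)

# Sorting is interpolated; limit/offset take the last two placeholders
query += f" ORDER BY created_at desc LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}"
params.extend([100, 0])

print(query.endswith("LIMIT $5 OFFSET $6"))  # True
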
agents-api/agents_api/queries/docs/search_docs_by_embedding.py
@@ -5,7 +5,6 @@
from typing import List, Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
3 changes: 1 addition & 2 deletions agents-api/agents_api/queries/docs/search_docs_by_text.py
@@ -5,7 +5,6 @@
from typing import Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
@@ -22,7 +21,7 @@
AND d.doc_id = do.doc_id
WHERE d.developer_id = $1
AND (
($4::text IS NULL AND $5::uuid IS NULL)
($4 IS NULL AND $5 IS NULL)
OR (do.owner_type = $4 AND do.owner_id = $5)
)
AND d.search_tsv @@ websearch_to_tsquery($3)
2 changes: 0 additions & 2 deletions agents-api/agents_api/queries/docs/search_docs_hybrid.py
@@ -7,10 +7,8 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException

from ...autogen.openapi_model import Doc
from ..utils import run_concurrently
from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_text import search_docs_by_text

1 change: 0 additions & 1 deletion agents-api/agents_api/queries/entries/get_history.py
@@ -1,5 +1,4 @@
import json
from typing import Any, List, Tuple
from uuid import UUID

import asyncpg
6 changes: 2 additions & 4 deletions agents-api/agents_api/queries/files/get_file.py
@@ -6,13 +6,11 @@
from typing import Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import File
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
from ..utils import pg_query, wrap_in_class

# Define the raw SQL query
file_query = parse_one("""
@@ -47,8 +45,8 @@
File,
one=True,
transform=lambda d: {
"id": d["file_id"],
**d,
"id": d["file_id"],
"hash": d["hash"].hex(),
"content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE",
},
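
Moving the "id" override after **d matters because later keys in a dict literal win; a quick sketch of the difference (the pre-existing "id" key in the row dict is hypothetical, purely for illustration):

d = {"file_id": "f-123", "id": "internal-row-id"}

print({"id": d["file_id"], **d})  # {'id': 'internal-row-id', 'file_id': 'f-123'} -- override lost
print({**d, "id": d["file_id"]})  # {'file_id': 'f-123', 'id': 'f-123'}           -- override kept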