Skip to content

Commit

Permalink
chore: added embedding reading + doctrings updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Dec 20, 2024
1 parent 32d67bc commit 831e950
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 36 deletions.
13 changes: 13 additions & 0 deletions agents-api/agents_api/queries/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@ async def create_doc(
) -> list[tuple[str, list] | tuple[str, list, str]]:
"""
Insert a new doc record into Timescale and optionally associate it with an owner.
Parameters:
owner_type (Literal["user", "agent"]): The type of the owner of the documents.
owner_id (UUID): The ID of the owner of the documents.
modality (Literal["text", "image", "mixed"]): The modality of the documents.
embedding_model (str): The model used for embedding.
embedding_dimensions (int): The dimensions of the embedding.
language (str): The language of the documents.
index (int): The index of the documents.
data (CreateDocRequest): The data for the document.
Returns:
list[tuple[str, list] | tuple[str, list, str]]: SQL query and parameters for creating the document.
"""
# Generate a UUID if not provided
doc_id = doc_id or uuid7()
Expand Down
9 changes: 9 additions & 0 deletions agents-api/agents_api/queries/docs/delete_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ async def delete_doc(
"""
Deletes a doc (and associated doc_owners) for the given developer and doc_id.
If owner_type/owner_id is specified, only remove doc if that matches.
Parameters:
developer_id (UUID): The ID of the developer.
doc_id (UUID): The ID of the document.
owner_type (Literal["user", "agent"]): The type of the owner of the documents.
owner_id (UUID): The ID of the owner of the documents.
Returns:
tuple[str, list]: SQL query and parameters for deleting the document.
"""
return (
delete_doc_query,
Expand Down
37 changes: 37 additions & 0 deletions agents-api/agents_api/queries/docs/embed_snippets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Literal
from uuid import UUID

from beartype import beartype
from sqlglot import parse_one

from ..utils import pg_query

# TODO: This is a placeholder for the actual query
vectorizer_query = None


@pg_query
@beartype
async def embed_snippets(
*,
developer_id: UUID,
doc_id: UUID,
owner_type: Literal["user", "agent"] | None = None,
owner_id: UUID | None = None,
) -> tuple[str, list]:
"""
Trigger the vectorizer to generate embeddings for documents.
Parameters:
developer_id (UUID): The ID of the developer.
doc_id (UUID): The ID of the document.
owner_type (Literal["user", "agent"]): The type of the owner of the documents.
owner_id (UUID): The ID of the owner of the documents.
Returns:
tuple[str, list]: SQL query and parameters for embedding the snippets.
"""
return (
vectorizer_query,
[developer_id, doc_id, owner_type, owner_id],
)
26 changes: 20 additions & 6 deletions agents-api/agents_api/queries/docs/get_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@
from ...autogen.openapi_model import Doc
from ..utils import pg_query, wrap_in_class

doc_query = parse_one("""
SELECT d.*
# Combined query to fetch document details and embedding
doc_with_embedding_query = parse_one("""
SELECT d.*, e.embedding
FROM docs d
LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id
LEFT JOIN doc_owners doc_own
ON d.developer_id = doc_own.developer_id
AND d.doc_id = doc_own.doc_id
LEFT JOIN docs_embeddings e
ON d.doc_id = e.doc_id
WHERE d.developer_id = $1
AND d.doc_id = $2
AND (
Expand All @@ -31,7 +36,7 @@
"content": ast.literal_eval(d["content"])[0]
if len(ast.literal_eval(d["content"])) == 1
else ast.literal_eval(d["content"]),
# "embeddings": d["embeddings"],
"embedding": d["embedding"], # Add embedding to the transformation
},
)
@pg_query
Expand All @@ -44,9 +49,18 @@ async def get_doc(
owner_id: UUID | None = None,
) -> tuple[str, list]:
"""
Fetch a single doc, optionally constrained to a given owner.
Fetch a single doc with its embedding, optionally constrained to a given owner.
Parameters:
developer_id (UUID): The ID of the developer.
doc_id (UUID): The ID of the document.
owner_type (Literal["user", "agent"]): The type of the owner of the documents.
owner_id (UUID): The ID of the owner of the documents.
Returns:
tuple[str, list]: SQL query and parameters for fetching the document.
"""
return (
doc_query,
doc_with_embedding_query,
[developer_id, doc_id, owner_type, owner_id],
)
29 changes: 20 additions & 9 deletions agents-api/agents_api/queries/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from ...autogen.openapi_model import Doc
from ..utils import pg_query, wrap_in_class

# Base query for listing docs
# Base query for listing docs with optional embeddings
base_docs_query = parse_one("""
SELECT d.*
SELECT d.*, CASE WHEN $2 THEN NULL ELSE e.embedding END AS embedding
FROM docs d
LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id
LEFT JOIN docs_embeddings e ON d.doc_id = e.doc_id
WHERE d.developer_id = $1
""").sql(pretty=True)

Expand All @@ -27,7 +28,7 @@
"content": ast.literal_eval(d["content"])[0]
if len(ast.literal_eval(d["content"])) == 1
else ast.literal_eval(d["content"]),
# "embeddings": d["embeddings"],
"embedding": d.get("embedding"), # Add embedding to the transformation
},
)
@pg_query
Expand All @@ -46,6 +47,20 @@ async def list_docs(
) -> tuple[str, list]:
"""
Lists docs with optional owner filtering, pagination, and sorting.
Parameters:
developer_id (UUID): The ID of the developer.
owner_id (UUID): The ID of the owner of the documents.
owner_type (Literal["user", "agent"]): The type of the owner of the documents.
limit (int): The number of documents to return.
offset (int): The number of documents to skip.
sort_by (Literal["created_at", "updated_at"]): The field to sort by.
direction (Literal["asc", "desc"]): The direction to sort by.
metadata_filter (dict[str, Any]): The metadata filter to apply.
include_without_embeddings (bool): Whether to include documents without embeddings.
Returns:
tuple[str, list]: SQL query and parameters for listing the documents.
"""
if direction.lower() not in ["asc", "desc"]:
raise HTTPException(status_code=400, detail="Invalid sort direction")
Expand All @@ -61,11 +76,11 @@ async def list_docs(

# Start with the base query
query = base_docs_query
params = [developer_id]
params = [developer_id, include_without_embeddings]

# Add owner filtering
if owner_type and owner_id:
query += " AND doc_own.owner_type = $2 AND doc_own.owner_id = $3"
query += " AND doc_own.owner_type = $3 AND doc_own.owner_id = $4"
params.extend([owner_type, owner_id])

# Add metadata filtering
Expand All @@ -74,10 +89,6 @@ async def list_docs(
query += f" AND d.metadata->>'{key}' = ${len(params) + 1}"
params.append(value)

# Include or exclude documents without embeddings
# if not include_without_embeddings:
# query += " AND d.embeddings IS NOT NULL"

# Add sorting and pagination
query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}"
params.extend([limit, offset])
Expand Down
29 changes: 19 additions & 10 deletions agents-api/agents_api/queries/docs/search_docs_by_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import Doc
from ...autogen.openapi_model import DocReference
from ..utils import pg_query, wrap_in_class

# If you're doing approximate ANN (DiskANN) or IVF, you might use a special function or hint.
Expand All @@ -33,11 +33,14 @@


@wrap_in_class(
Doc,
one=False,
transform=lambda rec: {
**rec,
"id": rec["doc_id"],
DocReference,
transform=lambda d: {
"owner": {
"id": d["owner_id"],
"role": d["owner_type"],
},
"metadata": d.get("metadata", {}),
**d,
},
)
@pg_query
Expand All @@ -52,10 +55,16 @@ async def search_docs_by_embedding(
) -> tuple[str, list]:
"""
Vector-based doc search:
- developer_id is required
- query_embedding: the vector to query
- k: number of results to return
- owner_type/owner_id: optional doc ownership filter
Parameters:
developer_id (UUID): The ID of the developer.
query_embedding (List[float]): The vector to query.
k (int): The number of results to return.
owner_type (Literal["user", "agent", "org"]): The type of the owner of the documents.
owner_id (UUID): The ID of the owner of the documents.
Returns:
tuple[str, list]: SQL query and parameters for searching the documents.
"""
if k < 1:
raise HTTPException(status_code=400, detail="k must be >= 1")
Expand Down
29 changes: 19 additions & 10 deletions agents-api/agents_api/queries/docs/search_docs_by_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import Doc
from ...autogen.openapi_model import DocReference
from ..utils import pg_query, wrap_in_class

search_docs_text_query = parse_one("""
Expand All @@ -31,11 +31,14 @@


@wrap_in_class(
Doc,
one=False,
transform=lambda rec: {
**rec,
"id": rec["doc_id"],
DocReference,
transform=lambda d: {
"owner": {
"id": d["owner_id"],
"role": d["owner_type"],
},
"metadata": d.get("metadata", {}),
**d,
},
)
@pg_query
Expand All @@ -50,10 +53,16 @@ async def search_docs_by_text(
) -> tuple[str, list]:
"""
Full-text search on docs using the search_tsv column.
- developer_id: required
- query: the text to look for
- k: max results
- owner_type / owner_id: optional doc ownership filter
Parameters:
developer_id (UUID): The ID of the developer.
query (str): The text to search for.
k (int): The number of results to return.
owner_type (Literal["user", "agent", "org"]): The type of the owner of the documents.
owner_id (UUID): The ID of the owner of the documents.
Returns:
tuple[str, list]: SQL query and parameters for searching the documents.
"""
if k < 1:
raise HTTPException(status_code=400, detail="k must be >= 1")
Expand Down
22 changes: 22 additions & 0 deletions agents-api/agents_api/queries/entries/create_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ async def create_entries(
session_id: UUID,
data: list[CreateEntryRequest],
) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]:
"""
Create entries in a session.
Parameters:
developer_id (UUID): The ID of the developer.
session_id (UUID): The ID of the session.
data (list[CreateEntryRequest]): The list of entries to create.
Returns:
list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for creating the entries.
"""
# Convert the data to a list of dictionaries
data_dicts = [item.model_dump(mode="json") for item in data]

Expand Down Expand Up @@ -163,6 +174,17 @@ async def add_entry_relations(
session_id: UUID,
data: list[Relation],
) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]:
"""
Add relations between entries in a session.
Parameters:
developer_id (UUID): The ID of the developer.
session_id (UUID): The ID of the session.
data (list[Relation]): The list of relations to add.
Returns:
list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for adding the relations.
"""
# Convert the data to a list of dictionaries
data_dicts = [item.model_dump(mode="json") for item in data]

Expand Down
11 changes: 10 additions & 1 deletion agents-api/agents_api/queries/entries/delete_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,16 @@ async def delete_entries_for_session(
async def delete_entries(
*, developer_id: UUID, session_id: UUID, entry_ids: list[UUID]
) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]:
"""Delete specific entries by their IDs."""
"""Delete specific entries by their IDs.
Parameters:
developer_id (UUID): The ID of the developer.
session_id (UUID): The ID of the session.
entry_ids (list[UUID]): The IDs of the entries to delete.
Returns:
list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for deleting the entries.
"""
return [
(
session_exists_query,
Expand Down
10 changes: 10 additions & 0 deletions agents-api/agents_api/queries/entries/get_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ async def get_history(
session_id: UUID,
allowed_sources: list[str] = ["api_request", "api_response"],
) -> tuple[str, list] | tuple[str, list, str]:
"""Get the history of a session.
Parameters:
developer_id (UUID): The ID of the developer.
session_id (UUID): The ID of the session.
allowed_sources (list[str]): The sources to include in the history.
Returns:
tuple[str, list] | tuple[str, list, str]: SQL query and parameters for getting the history.
"""
return (
history_query,
[session_id, allowed_sources, developer_id],
Expand Down
15 changes: 15 additions & 0 deletions agents-api/agents_api/queries/entries/list_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,21 @@ async def list_entries(
direction: Literal["asc", "desc"] = "asc",
exclude_relations: list[str] = [],
) -> list[tuple[str, list] | tuple[str, list, str]]:
"""List entries in a session.
Parameters:
developer_id (UUID): The ID of the developer.
session_id (UUID): The ID of the session.
allowed_sources (list[str]): The sources to include in the history.
limit (int): The number of entries to return.
offset (int): The number of entries to skip.
sort_by (Literal["created_at", "timestamp"]): The field to sort by.
direction (Literal["asc", "desc"]): The direction to sort by.
exclude_relations (list[str]): The relations to exclude.
Returns:
tuple[str, list] | tuple[str, list, str]: SQL query and parameters for listing the entries.
"""
if limit < 1 or limit > 1000:
raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000")
if offset < 0:
Expand Down

0 comments on commit 831e950

Please sign in to comment.