diff --git a/agents-api/agents_api/queries/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py index dcbaa36e9..039624762 100644 --- a/agents-api/agents_api/queries/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -120,8 +120,8 @@ async def gather_messages( doc_references: list[DocReference] = await search_docs_hybrid( developer_id=developer.id, owners=owners, - query=query_text, - query_embedding=query_embedding, + text_query=query_text, + embedding=query_embedding, connection_pool=connection_pool, ) case "text": diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index 8ba29a445..e30fc5ed8 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -3,6 +3,7 @@ It constructs and executes SQL queries to fetch document details based on various filters. """ +import ast from typing import Any, Literal from uuid import UUID @@ -57,6 +58,14 @@ def transform_list_docs(d: dict) -> dict: content = d["content"][0] if len(d["content"]) == 1 else d["content"] embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"] + + # try: + # # Embeddings are retreived as a string, so we need to evaluate it + # embeddings = ast.literal_eval(embeddings) + # except Exception as e: + # msg = f"Error evaluating embeddings: {e}" + # raise ValueError(msg) + if embeddings and all((e is None) for e in embeddings): embeddings = None diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py index 0f56b9cb7..fb5110a56 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -7,6 +7,7 @@ from ...autogen.openapi_model import DocReference from ...common.utils.db_exceptions import common_db_exceptions from ..utils import pg_query, rewrap_exceptions, wrap_in_class +from .utils import transform_to_doc_reference # Raw query for vector search search_docs_by_embedding_query = """ @@ -25,14 +26,7 @@ @rewrap_exceptions(common_db_exceptions("doc", ["search"])) @wrap_in_class( DocReference, - transform=lambda d: { - "owner": { - "id": d["owner_id"], - "role": d["owner_type"], - }, - "metadata": d.get("metadata", {}), - **d, - }, + transform=transform_to_doc_reference, ) @pg_query @beartype diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py index 93982b731..77fb3a0e6 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_text.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -7,6 +7,7 @@ from ...autogen.openapi_model import DocReference from ...common.utils.db_exceptions import common_db_exceptions from ..utils import pg_query, rewrap_exceptions, wrap_in_class +from .utils import transform_to_doc_reference # Raw query for text search search_docs_text_query = """ @@ -25,14 +26,7 @@ @rewrap_exceptions(common_db_exceptions("doc", ["search"])) @wrap_in_class( DocReference, - transform=lambda d: { - "owner": { - "id": d["owner_id"], - "role": d["owner_type"], - }, - "metadata": d.get("metadata", {}), - **d, - }, + transform=transform_to_doc_reference, ) @pg_query @beartype diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py index 4b6cca893..5c09b802c 100644 --- a/agents-api/agents_api/queries/docs/search_docs_hybrid.py +++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py @@ -11,6 +11,7 @@ rewrap_exceptions, wrap_in_class, ) +from .utils import transform_to_doc_reference # Raw query for hybrid search search_docs_hybrid_query = """ @@ -32,14 +33,7 @@ @rewrap_exceptions(common_db_exceptions("doc", ["search"])) @wrap_in_class( DocReference, - transform=lambda d: { - "owner": { - "id": d["owner_id"], - "role": d["owner_type"], - }, - "metadata": d.get("metadata", {}), - **d, - }, + transform=transform_to_doc_reference, ) @pg_query @beartype diff --git a/agents-api/agents_api/queries/docs/utils.py b/agents-api/agents_api/queries/docs/utils.py new file mode 100644 index 000000000..4d1cbaf45 --- /dev/null +++ b/agents-api/agents_api/queries/docs/utils.py @@ -0,0 +1,35 @@ +import ast + + +def transform_to_doc_reference(d: dict) -> dict: + id = d.pop("doc_id") + content = d.pop("content") + index = d.pop("index") + + embedding = d.pop("embedding") + + try: + # Embeddings are retreived as a string, so we need to evaluate it + embedding = ast.literal_eval(embedding) + except Exception as e: + msg = f"Error evaluating embeddings: {e}" + raise ValueError(msg) + + owner = { + "id": d.pop("owner_id"), + "role": d.pop("owner_type"), + } + snippet = { + "content": content, + "index": index, + "embedding": embedding, + } + metadata = d.pop("metadata") + + return { + "id": id, + "owner": owner, + "snippet": snippet, + "metadata": metadata, + **d, + } diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index 6b7fcff26..48e32dafd 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -89,21 +89,6 @@ async def create_entries( # Convert the data to a list of dictionaries data_dicts = [item.model_dump(mode="json") for item in data] - # Prepare the parameters for the query - # $1 - # $2 - # $3 - # $4 - # $5 - # $6 - # $7 - # $8 - # $9 - # $10 - # $11 - # $12 - # $13 - # $14 params = [ [ session_id, # $1 diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index 3cbbdcd0a..be4eebb5d 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -44,12 +44,14 @@ """ -@rewrap_exceptions(common_db_exceptions("history", ["get"])) -@wrap_in_class( - History, - one=True, - transform=lambda d: { - "entries": json.loads(d.get("entries") or "[]"), +def _transform(d): + return { + "entries": [ + { + **entry, + } + for entry in json.loads(d.get("entries") or "[]") + ], "relations": [ { "head": r["head"], @@ -60,7 +62,14 @@ ], "session_id": d.get("session_id"), "created_at": utcnow(), - }, + } + + +@rewrap_exceptions(common_db_exceptions("history", ["get"])) +@wrap_in_class( + History, + one=True, + transform=_transform, ) @pg_query @beartype diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py index c5d278c8c..3da2126f6 100644 --- a/agents-api/agents_api/queries/sessions/create_or_update_session.py +++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py @@ -119,7 +119,7 @@ async def create_or_update_session( data.token_budget, # $7 data.context_overflow, # $8 data.forward_tool_calls, # $9 - data.recall_options or {}, # $10 + data.recall_options.model_dump() if data.recall_options else {}, # $10 ] # Prepare lookup parameters diff --git a/agents-api/agents_api/queries/sessions/patch_session.py b/agents-api/agents_api/queries/sessions/patch_session.py index 033df9e5f..fe6848959 100644 --- a/agents-api/agents_api/queries/sessions/patch_session.py +++ b/agents-api/agents_api/queries/sessions/patch_session.py @@ -65,7 +65,7 @@ async def patch_session( data.token_budget, # $7 data.context_overflow, # $8 data.forward_tool_calls, # $9 - data.recall_options or {}, # $10 + data.recall_options.model_dump() if data.recall_options else {}, # $10 ] return [(session_query, session_params)] diff --git a/agents-api/agents_api/queries/sessions/update_session.py b/agents-api/agents_api/queries/sessions/update_session.py index bb4cc6590..6ad90bef3 100644 --- a/agents-api/agents_api/queries/sessions/update_session.py +++ b/agents-api/agents_api/queries/sessions/update_session.py @@ -64,7 +64,7 @@ async def update_session( data.token_budget, # $7 data.context_overflow, # $8 data.forward_tool_calls, # $9 - data.recall_options or {}, # $10 + data.recall_options.model_dump() if data.recall_options else {}, # $10 ] return [ diff --git a/agents-api/agents_api/routers/docs/create_doc.py b/agents-api/agents_api/routers/docs/create_doc.py index 3ffd81da2..d089f2802 100644 --- a/agents-api/agents_api/routers/docs/create_doc.py +++ b/agents-api/agents_api/routers/docs/create_doc.py @@ -1,7 +1,7 @@ from typing import Annotated from uuid import UUID -from fastapi import BackgroundTasks, Depends +from fastapi import Depends from starlette.status import HTTP_201_CREATED from ...autogen.openapi_model import CreateDocRequest, Doc, ResourceCreatedResponse @@ -15,7 +15,6 @@ async def create_user_doc( user_id: UUID, data: CreateDocRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], - background_tasks: BackgroundTasks, ) -> ResourceCreatedResponse: """ Creates a new document for a user. @@ -24,7 +23,6 @@ async def create_user_doc( user_id (UUID): The unique identifier of the user associated with the document. data (CreateDocRequest): The data to create the document with. x_developer_id (UUID): The unique identifier of the developer associated with the document. - background_tasks (BackgroundTasks): The background tasks to run. Returns: ResourceCreatedResponse: The created document. @@ -45,7 +43,6 @@ async def create_agent_doc( agent_id: UUID, data: CreateDocRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], - background_tasks: BackgroundTasks, ) -> ResourceCreatedResponse: doc: Doc = await create_doc_query( developer_id=x_developer_id, diff --git a/agents-api/agents_api/routers/docs/search_docs.py b/agents-api/agents_api/routers/docs/search_docs.py index c01f16770..0c463a83a 100644 --- a/agents-api/agents_api/routers/docs/search_docs.py +++ b/agents-api/agents_api/routers/docs/search_docs.py @@ -20,7 +20,7 @@ from .router import router -async def get_search_fn_and_params( +def get_search_fn_and_params( search_params, ) -> tuple[Any, dict[str, float | int | str | dict[str, float] | list[float]] | None]: search_fn, params = None, None @@ -58,10 +58,10 @@ async def get_search_fn_and_params( ): search_fn = search_docs_hybrid params = { - "text_query": query, - "embedding": query_embedding, + "query": query, + "query_embedding": query_embedding, "k": k * 3 if search_params.mmr_strength > 0 else k, - "confidence": confidence, + "embed_search_options": {"confidence": confidence}, "alpha": alpha, "metadata_filter": metadata_filter, } @@ -88,7 +88,7 @@ async def search_user_docs( """ # MMR here - search_fn, params = await get_search_fn_and_params(search_params) + search_fn, params = get_search_fn_and_params(search_params) start = time.time() docs: list[DocReference] = await search_fn( @@ -137,7 +137,7 @@ async def search_agent_docs( DocSearchResponse: The search results. """ - search_fn, params = await get_search_fn_and_params(search_params) + search_fn, params = get_search_fn_and_params(search_params) start = time.time() docs: list[DocReference] = await search_fn( diff --git a/memory-store/migrations/000015_entries.up.sql b/memory-store/migrations/000015_entries.up.sql index 10e7693a4..5b9302f05 100644 --- a/memory-store/migrations/000015_entries.up.sql +++ b/memory-store/migrations/000015_entries.up.sql @@ -37,15 +37,14 @@ CREATE TABLE IF NOT EXISTS entries ( name TEXT, content JSONB[] NOT NULL, tool_call_id TEXT DEFAULT NULL, - tool_calls JSONB[] NOT NULL DEFAULT '{}'::JSONB[], + tool_calls JSONB[] DEFAULT NULL, model TEXT NOT NULL, token_count INTEGER DEFAULT NULL, tokenizer TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, timestamp DOUBLE PRECISION NOT NULL, CONSTRAINT pk_entries PRIMARY KEY (session_id, entry_id, created_at), - CONSTRAINT ct_content_is_array_of_objects CHECK (all_jsonb_elements_are_objects (content)), - CONSTRAINT ct_tool_calls_is_array_of_objects CHECK (all_jsonb_elements_are_objects (tool_calls)) + CONSTRAINT ct_content_is_array_of_objects CHECK (all_jsonb_elements_are_objects (content)) ); -- Convert to hypertable if not already