Skip to content

Commit

Permalink
Merge pull request #1000 from julep-ai/x/misc-session-fixes
Browse files Browse the repository at this point in the history
Fix(agents-api): Miscellaneous fixes related to sessions & entries
  • Loading branch information
creatorrr authored Dec 31, 2024
2 parents 81841e4 + 7f76bc3 commit a5f00b1
Show file tree
Hide file tree
Showing 14 changed files with 80 additions and 64 deletions.
4 changes: 2 additions & 2 deletions agents-api/agents_api/queries/chat/gather_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ async def gather_messages(
doc_references: list[DocReference] = await search_docs_hybrid(
developer_id=developer.id,
owners=owners,
query=query_text,
query_embedding=query_embedding,
text_query=query_text,
embedding=query_embedding,
connection_pool=connection_pool,
)
case "text":
Expand Down
9 changes: 9 additions & 0 deletions agents-api/agents_api/queries/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
It constructs and executes SQL queries to fetch document details based on various filters.
"""

import ast
from typing import Any, Literal
from uuid import UUID

Expand Down Expand Up @@ -57,6 +58,14 @@ def transform_list_docs(d: dict) -> dict:
content = d["content"][0] if len(d["content"]) == 1 else d["content"]

embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"]

# try:
# # Embeddings are retreived as a string, so we need to evaluate it
# embeddings = ast.literal_eval(embeddings)
# except Exception as e:
# msg = f"Error evaluating embeddings: {e}"
# raise ValueError(msg)

if embeddings and all((e is None) for e in embeddings):
embeddings = None

Expand Down
10 changes: 2 additions & 8 deletions agents-api/agents_api/queries/docs/search_docs_by_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...autogen.openapi_model import DocReference
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
from .utils import transform_to_doc_reference

# Raw query for vector search
search_docs_by_embedding_query = """
Expand All @@ -25,14 +26,7 @@
@rewrap_exceptions(common_db_exceptions("doc", ["search"]))
@wrap_in_class(
DocReference,
transform=lambda d: {
"owner": {
"id": d["owner_id"],
"role": d["owner_type"],
},
"metadata": d.get("metadata", {}),
**d,
},
transform=transform_to_doc_reference,
)
@pg_query
@beartype
Expand Down
10 changes: 2 additions & 8 deletions agents-api/agents_api/queries/docs/search_docs_by_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...autogen.openapi_model import DocReference
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
from .utils import transform_to_doc_reference

# Raw query for text search
search_docs_text_query = """
Expand All @@ -25,14 +26,7 @@
@rewrap_exceptions(common_db_exceptions("doc", ["search"]))
@wrap_in_class(
DocReference,
transform=lambda d: {
"owner": {
"id": d["owner_id"],
"role": d["owner_type"],
},
"metadata": d.get("metadata", {}),
**d,
},
transform=transform_to_doc_reference,
)
@pg_query
@beartype
Expand Down
10 changes: 2 additions & 8 deletions agents-api/agents_api/queries/docs/search_docs_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
rewrap_exceptions,
wrap_in_class,
)
from .utils import transform_to_doc_reference

# Raw query for hybrid search
search_docs_hybrid_query = """
Expand All @@ -32,14 +33,7 @@
@rewrap_exceptions(common_db_exceptions("doc", ["search"]))
@wrap_in_class(
DocReference,
transform=lambda d: {
"owner": {
"id": d["owner_id"],
"role": d["owner_type"],
},
"metadata": d.get("metadata", {}),
**d,
},
transform=transform_to_doc_reference,
)
@pg_query
@beartype
Expand Down
35 changes: 35 additions & 0 deletions agents-api/agents_api/queries/docs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import ast


def transform_to_doc_reference(d: dict) -> dict:
id = d.pop("doc_id")
content = d.pop("content")
index = d.pop("index")

embedding = d.pop("embedding")

try:
# Embeddings are retreived as a string, so we need to evaluate it
embedding = ast.literal_eval(embedding)
except Exception as e:
msg = f"Error evaluating embeddings: {e}"
raise ValueError(msg)

owner = {
"id": d.pop("owner_id"),
"role": d.pop("owner_type"),
}
snippet = {
"content": content,
"index": index,
"embedding": embedding,
}
metadata = d.pop("metadata")

return {
"id": id,
"owner": owner,
"snippet": snippet,
"metadata": metadata,
**d,
}
15 changes: 0 additions & 15 deletions agents-api/agents_api/queries/entries/create_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,6 @@ async def create_entries(
# Convert the data to a list of dictionaries
data_dicts = [item.model_dump(mode="json") for item in data]

# Prepare the parameters for the query
# $1
# $2
# $3
# $4
# $5
# $6
# $7
# $8
# $9
# $10
# $11
# $12
# $13
# $14
params = [
[
session_id, # $1
Expand Down
23 changes: 16 additions & 7 deletions agents-api/agents_api/queries/entries/get_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@
"""


@rewrap_exceptions(common_db_exceptions("history", ["get"]))
@wrap_in_class(
History,
one=True,
transform=lambda d: {
"entries": json.loads(d.get("entries") or "[]"),
def _transform(d):
return {
"entries": [
{
**entry,
}
for entry in json.loads(d.get("entries") or "[]")
],
"relations": [
{
"head": r["head"],
Expand All @@ -60,7 +62,14 @@
],
"session_id": d.get("session_id"),
"created_at": utcnow(),
},
}


@rewrap_exceptions(common_db_exceptions("history", ["get"]))
@wrap_in_class(
History,
one=True,
transform=_transform,
)
@pg_query
@beartype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def create_or_update_session(
data.token_budget, # $7
data.context_overflow, # $8
data.forward_tool_calls, # $9
data.recall_options or {}, # $10
data.recall_options.model_dump() if data.recall_options else {}, # $10
]

# Prepare lookup parameters
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/sessions/patch_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def patch_session(
data.token_budget, # $7
data.context_overflow, # $8
data.forward_tool_calls, # $9
data.recall_options or {}, # $10
data.recall_options.model_dump() if data.recall_options else {}, # $10
]

return [(session_query, session_params)]
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/sessions/update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def update_session(
data.token_budget, # $7
data.context_overflow, # $8
data.forward_tool_calls, # $9
data.recall_options or {}, # $10
data.recall_options.model_dump() if data.recall_options else {}, # $10
]

return [
Expand Down
5 changes: 1 addition & 4 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import BackgroundTasks, Depends
from fastapi import Depends
from starlette.status import HTTP_201_CREATED

from ...autogen.openapi_model import CreateDocRequest, Doc, ResourceCreatedResponse
Expand All @@ -15,7 +15,6 @@ async def create_user_doc(
user_id: UUID,
data: CreateDocRequest,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
background_tasks: BackgroundTasks,
) -> ResourceCreatedResponse:
"""
Creates a new document for a user.
Expand All @@ -24,7 +23,6 @@ async def create_user_doc(
user_id (UUID): The unique identifier of the user associated with the document.
data (CreateDocRequest): The data to create the document with.
x_developer_id (UUID): The unique identifier of the developer associated with the document.
background_tasks (BackgroundTasks): The background tasks to run.
Returns:
ResourceCreatedResponse: The created document.
Expand All @@ -45,7 +43,6 @@ async def create_agent_doc(
agent_id: UUID,
data: CreateDocRequest,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
background_tasks: BackgroundTasks,
) -> ResourceCreatedResponse:
doc: Doc = await create_doc_query(
developer_id=x_developer_id,
Expand Down
12 changes: 6 additions & 6 deletions agents-api/agents_api/routers/docs/search_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .router import router


async def get_search_fn_and_params(
def get_search_fn_and_params(
search_params,
) -> tuple[Any, dict[str, float | int | str | dict[str, float] | list[float]] | None]:
search_fn, params = None, None
Expand Down Expand Up @@ -58,10 +58,10 @@ async def get_search_fn_and_params(
):
search_fn = search_docs_hybrid
params = {
"text_query": query,
"embedding": query_embedding,
"query": query,
"query_embedding": query_embedding,
"k": k * 3 if search_params.mmr_strength > 0 else k,
"confidence": confidence,
"embed_search_options": {"confidence": confidence},
"alpha": alpha,
"metadata_filter": metadata_filter,
}
Expand All @@ -88,7 +88,7 @@ async def search_user_docs(
"""

# MMR here
search_fn, params = await get_search_fn_and_params(search_params)
search_fn, params = get_search_fn_and_params(search_params)

start = time.time()
docs: list[DocReference] = await search_fn(
Expand Down Expand Up @@ -137,7 +137,7 @@ async def search_agent_docs(
DocSearchResponse: The search results.
"""

search_fn, params = await get_search_fn_and_params(search_params)
search_fn, params = get_search_fn_and_params(search_params)

start = time.time()
docs: list[DocReference] = await search_fn(
Expand Down
5 changes: 2 additions & 3 deletions memory-store/migrations/000015_entries.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,14 @@ CREATE TABLE IF NOT EXISTS entries (
name TEXT,
content JSONB[] NOT NULL,
tool_call_id TEXT DEFAULT NULL,
tool_calls JSONB[] NOT NULL DEFAULT '{}'::JSONB[],
tool_calls JSONB[] DEFAULT NULL,
model TEXT NOT NULL,
token_count INTEGER DEFAULT NULL,
tokenizer TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
timestamp DOUBLE PRECISION NOT NULL,
CONSTRAINT pk_entries PRIMARY KEY (session_id, entry_id, created_at),
CONSTRAINT ct_content_is_array_of_objects CHECK (all_jsonb_elements_are_objects (content)),
CONSTRAINT ct_tool_calls_is_array_of_objects CHECK (all_jsonb_elements_are_objects (tool_calls))
CONSTRAINT ct_content_is_array_of_objects CHECK (all_jsonb_elements_are_objects (content))
);

-- Convert to hypertable if not already
Expand Down

0 comments on commit a5f00b1

Please sign in to comment.