feat(agents-api): Hybrid docs search (#444)
Signed-off-by: Diwank Tomer <[email protected]>
Co-authored-by: Diwank Tomer <[email protected]>
creatorrr and Diwank Tomer authored Aug 5, 2024
1 parent 50c898e commit 33e86a2
Showing 5 changed files with 281 additions and 13 deletions.
3 changes: 2 additions & 1 deletion agents-api/agents_api/models/docs/__init__.py
@@ -21,4 +21,5 @@
from .embed_snippets import embed_snippets
from .get_doc import get_doc
from .list_docs import list_docs
from .search_docs import search_docs_by_embedding
from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_text import search_docs_by_text
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/docs/search_docs_by_embedding.py
@@ -48,6 +48,7 @@ def search_docs_by_embedding(
confidence: float = 0.7,
ef: int = 128,
mmr_lambda: float = 0.25,
embedding_size: int = 1024,
) -> tuple[list[str], dict]:
"""
Searches for document snippets in CozoDB by embedding query.
@@ -61,6 +62,9 @@
- mmr_lambda (float, optional): The lambda parameter for MMR. Defaults to 0.25.
"""

assert len(query_embedding) == embedding_size
assert sum(query_embedding)

owner_id = str(owner_id)

# Calculate the search radius based on confidence level
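The new embedding_size parameter and the two asserts are a fail-fast guard: a query vector of the wrong dimensionality, or one that sums to zero (notably the all-zero vector), would otherwise produce empty or meaningless nearest-neighbour results. A minimal standalone sketch of the same check, outside the Cozo query machinery (the helper name is illustrative, not part of this commit):

def validate_query_embedding(query_embedding: list[float], embedding_size: int = 1024) -> None:
    # A vector of the wrong dimensionality cannot be compared against the stored vectors.
    if len(query_embedding) != embedding_size:
        raise ValueError(f"expected {embedding_size} dimensions, got {len(query_embedding)}")
    # A zero-sum vector (e.g. all zeros) carries no usable signal for a
    # similarity search; this mirrors `assert sum(query_embedding)` above.
    if not sum(query_embedding):
        raise ValueError("query embedding must not sum to zero")

validate_query_embedding([0.1] * 1024)  # passes; a 512-dim vector would raise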
153 changes: 153 additions & 0 deletions agents-api/agents_api/models/docs/search_docs_by_text.py
@@ -0,0 +1,153 @@
"""This module contains functions for searching documents in the CozoDB based on embedding queries."""

from typing import Literal
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import DocReference
from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
wrap_in_class,
)


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
DocReference,
transform=lambda d: {
"owner": {
"id": d["owner_id"],
"role": d["owner_type"],
},
**d,
},
)
@cozo_query
@beartype
def search_docs_by_text(
*,
developer_id: UUID,
owner_type: Literal["user", "agent"],
owner_id: UUID,
query: str,
k: int = 3,
) -> tuple[list[str], dict]:
"""
    Searches for document snippets in CozoDB by text query.
Parameters:
- owner_type (Literal["user", "agent"]): The type of the owner of the documents.
- owner_id (UUID): The unique identifier of the owner.
- query (str): The query string.
    - k (int, optional): The number of top matches to retrieve. Defaults to 3.
"""

owner_id = str(owner_id)

# Construct the datalog query for searching document snippets
search_query = f"""
input[
owner_id,
query,
] <- [[
to_uuid($owner_id),
$query,
]]
candidate[doc_id] :=
input[owner_id, _],
*docs {{
owner_type: $owner_type,
owner_id,
doc_id
}}
search_result[
doc_id,
snippet_data,
distance,
] :=
input[owner_id, query],
candidate[doc_id],
~snippets:fts {{
doc_id,
index,
content
|
query: query,
k: {k},
score_kind: 'tf_idf',
bind_score: score,
}},
distance = -score,
snippet_data = [index, content]
m[
doc_id,
collect(snippet),
distance,
title,
] :=
candidate[doc_id],
*docs {{
owner_type: $owner_type,
owner_id,
doc_id,
title,
}},
search_result [
doc_id,
snippet_data,
distance,
],
snippet = {{
"index": snippet_data->0,
"content": snippet_data->1,
}}
?[
id,
owner_type,
owner_id,
snippets,
distance,
title,
] := m[
id,
snippets,
distance,
title,
], owner_type = $owner_type, owner_id = $owner_id
# Sort the results by distance to find the closest matches
:sort distance
:limit {k}
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id}
),
search_query,
]

return (
queries,
{"owner_type": owner_type, "owner_id": owner_id, "query": query},
)
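Two details of the query above are easy to miss: the ~snippets:fts clause searches the snippets full-text index with tf_idf scoring, and the relevance score is negated (distance = -score) so that, like the cosine distance from the embedding search, smaller values mean better matches and the shared :sort distance ordering works for both paths. A hedged usage sketch, assuming the @cozo_query / @wrap_in_class decorators execute the generated queries and return DocReference objects as they do for the sibling search functions (the IDs below are placeholders):

from uuid import uuid4

from agents_api.models.docs.search_docs_by_text import search_docs_by_text

refs = search_docs_by_text(
    developer_id=uuid4(),
    owner_type="agent",
    owner_id=uuid4(),
    query="how does hybrid search rank results?",
    k=3,
)
for ref in refs:
    # ref.distance is the negated tf_idf score: lower means more relevant.
    print(ref.id, ref.distance)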
121 changes: 121 additions & 0 deletions agents-api/agents_api/models/docs/search_docs_hybrid.py
@@ -0,0 +1,121 @@
"""This module contains functions for searching documents in the CozoDB based on embedding queries."""

from statistics import mean, stdev
from typing import Literal
from uuid import UUID

from beartype import beartype

from ...autogen.openapi_model import DocReference
from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_text import search_docs_by_text


# Distribution based score normalization
# https://medium.com/plain-simple-software/distribution-based-score-fusion-dbsf-a-new-approach-to-vector-search-ranking-f87c37488b18
def dbsf_normalize(scores: list[float]) -> list[float]:
"""
Scores scaled using minmax scaler with our custom feature range
(extremes indicated as 3 standard deviations from the mean)
"""
sd = stdev(scores)
if sd == 0:
return scores

m = mean(scores)
m3d = 3 * sd + m
m_3d = m - 3 * sd

return [(s - m_3d) / (m3d - m_3d) for s in scores]


def dbsf_fuse(
text_results: list[DocReference],
embedding_results: list[DocReference],
alpha: float = 0.7, # Weight of the embedding search results (this is a good default)
) -> list[DocReference]:
"""
    Weighted distribution-based score fusion (DBSF) of text and embedding search results
"""
all_docs = {doc.id: doc for doc in text_results + embedding_results}

text_scores: dict[UUID, float] = {doc.id: -doc.distance for doc in text_results}

# Because these are cosine distances, we need to invert them
embedding_scores: dict[UUID, float] = {
doc.id: 1.0 - doc.distance for doc in embedding_results
}

# normalize the scores
text_scores_normalized = dbsf_normalize(list(text_scores.values()))
text_scores = {
doc_id: score
for doc_id, score in zip(text_scores.keys(), text_scores_normalized)
}

embedding_scores_normalized = dbsf_normalize(list(embedding_scores.values()))
embedding_scores = {
doc_id: score
for doc_id, score in zip(embedding_scores.keys(), embedding_scores_normalized)
}

# Combine the scores
text_weight: float = 1 - alpha
embedding_weight: float = alpha

combined_scores = []

for id in all_docs.keys():
text_score = text_weight * text_scores.get(id, 0)
embedding_score = embedding_weight * embedding_scores.get(id, 0)

combined_scores.append((id, text_score + embedding_score))

# Sort by the combined score
combined_scores = sorted(combined_scores, key=lambda x: x[1], reverse=True)

# Rank the results
ranked_results = []
for id, score in combined_scores:
doc = all_docs[id].model_copy()
doc.distance = 1.0 - score
ranked_results.append(doc)

return ranked_results


@beartype
def search_docs_hybrid(
*,
developer_id: UUID,
owner_type: Literal["user", "agent"],
owner_id: UUID,
query: str,
query_embedding: list[float],
k: int = 3,
embed_search_options: dict = {},
text_search_options: dict = {},
**kwargs,
) -> list[DocReference]:
# TODO: We should probably parallelize these queries
text_results = search_docs_by_text(
developer_id=developer_id,
owner_type=owner_type,
owner_id=owner_id,
query=query,
k=2 * k,
**text_search_options,
**kwargs,
)

embedding_results = search_docs_by_embedding(
developer_id=developer_id,
owner_type=owner_type,
owner_id=owner_id,
query_embedding=query_embedding,
k=2 * k,
**embed_search_options,
**kwargs,
)

return dbsf_fuse(text_results, embedding_results)[:k]
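dbsf_fuse implements distribution-based score fusion: each result list's scores are min-max scaled against a range whose extremes sit three standard deviations either side of the mean, and the two lists are then blended with weight alpha on the embedding side. A self-contained worked example of that arithmetic on made-up scores (toy numbers, illustration only):

from statistics import mean, stdev

def normalize(scores: list[float]) -> list[float]:
    # Same scaling as dbsf_normalize above: extremes at mean ± 3 * stdev.
    sd = stdev(scores)
    if sd == 0:
        return scores
    low, high = mean(scores) - 3 * sd, mean(scores) + 3 * sd
    return [(s - low) / (high - low) for s in scores]

# Toy per-document scores after the sign conventions used in dbsf_fuse
# (text: -distance, embedding: 1 - cosine distance).
text = {"doc_a": 0.9, "doc_b": 0.4}
embedding = {"doc_a": 0.2, "doc_c": 0.8}

text_n = dict(zip(text, normalize(list(text.values()))))
embedding_n = dict(zip(embedding, normalize(list(embedding.values()))))

alpha = 0.7  # weight of the embedding results, as in dbsf_fuse
fused = {
    doc_id: (1 - alpha) * text_n.get(doc_id, 0) + alpha * embedding_n.get(doc_id, 0)
    for doc_id in set(text_n) | set(embedding_n)
}
# doc_a appears in both lists, so it accumulates both weighted contributions.
print(sorted(fused.items(), key=lambda kv: kv[1], reverse=True))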
13 changes: 1 addition & 12 deletions agents-api/agents_api/models/session/prepare_chat_context.py
@@ -47,15 +47,10 @@ def prepare_chat_context(
developer_id: UUID,
agent_id: UUID,
session_id: UUID,
# doc_query_embedding: list[float],
# docs_confidence: float = 0.4,
# k_docs: int = 3,
) -> tuple[list[str], dict]:
"""
Executes a complex query to retrieve memory context based on session ID, tool and document embeddings.
Executes a complex query to retrieve memory context based on session ID.
"""
# VECTOR_SIZE = 1024
# docs_radius: float = 1.0 - docs_confidence

session_data_query, sd_vars = prepare_session_data.__wrapped__(
developer_id=developer_id, session_id=session_id
@@ -89,9 +84,6 @@
}}
"""

# TODO: Implement the following queries
# docs_query = ...

entries_query, e_vars = list_entries.__wrapped__(
developer_id=developer_id,
session_id=session_id,
@@ -143,8 +135,5 @@
**sd_vars,
**t_vars,
**e_vars,
# "doc_query_embedding": doc_query_embedding,
# "k_docs": k_docs,
# "docs_radius": round(docs_radius, 2),
},
)
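The prepare_chat_context change drops the commented-out, embedding-only document retrieval from the session query, presumably so that callers fetch doc snippets through the hybrid search added above instead. A hypothetical sketch of that split (the import path matches this commit, but the surrounding wiring and the placeholder query vector are assumptions, not code from the repository):

from uuid import uuid4

from agents_api.models.docs.search_docs_hybrid import search_docs_hybrid

query = "What did we decide about rate limits?"
# Placeholder 1024-dim vector; a real caller would embed `query` with the
# same model used for the stored snippets.
query_embedding = [0.01] * 1024

doc_refs = search_docs_hybrid(
    developer_id=uuid4(),
    owner_type="agent",
    owner_id=uuid4(),
    query=query,
    query_embedding=query_embedding,
    k=3,
)
# doc_refs is the fused, re-ranked top-k list of DocReference objects.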
