feat(agents-api): Hybrid docs search (#444)
Signed-off-by: Diwank Tomer <[email protected]>
Co-authored-by: Diwank Tomer <[email protected]>
Showing 5 changed files with 281 additions and 13 deletions.
153 changes: 153 additions & 0 deletions
agents-api/agents_api/models/docs/search_docs_by_text.py

@@ -0,0 +1,153 @@
"""This module contains functions for searching documents in the CozoDB based on embedding queries.""" | ||
|
||
from typing import Literal | ||
from uuid import UUID | ||
|
||
from beartype import beartype | ||
from fastapi import HTTPException | ||
from pycozo.client import QueryException | ||
from pydantic import ValidationError | ||
|
||
from ...autogen.openapi_model import DocReference | ||
from ..utils import ( | ||
cozo_query, | ||
partialclass, | ||
rewrap_exceptions, | ||
verify_developer_id_query, | ||
verify_developer_owns_resource_query, | ||
wrap_in_class, | ||
) | ||
|
||
|
||
@rewrap_exceptions( | ||
{ | ||
QueryException: partialclass(HTTPException, status_code=400), | ||
ValidationError: partialclass(HTTPException, status_code=400), | ||
TypeError: partialclass(HTTPException, status_code=400), | ||
} | ||
) | ||
@wrap_in_class( | ||
DocReference, | ||
transform=lambda d: { | ||
"owner": { | ||
"id": d["owner_id"], | ||
"role": d["owner_type"], | ||
}, | ||
**d, | ||
}, | ||
) | ||
@cozo_query | ||
@beartype | ||
def search_docs_by_text( | ||
*, | ||
developer_id: UUID, | ||
owner_type: Literal["user", "agent"], | ||
owner_id: UUID, | ||
query: str, | ||
k: int = 3, | ||
) -> tuple[list[str], dict]: | ||
""" | ||
Searches for document snippets in CozoDB by embedding query. | ||
Parameters: | ||
- owner_type (Literal["user", "agent"]): The type of the owner of the documents. | ||
- owner_id (UUID): The unique identifier of the owner. | ||
- query (str): The query string. | ||
- k (int, optional): The number of nearest neighbors to retrieve. Defaults to 3. | ||
""" | ||
|
||
owner_id = str(owner_id) | ||
|
||
# Construct the datalog query for searching document snippets | ||
search_query = f""" | ||
input[ | ||
owner_id, | ||
query, | ||
] <- [[ | ||
to_uuid($owner_id), | ||
$query, | ||
]] | ||
candidate[doc_id] := | ||
input[owner_id, _], | ||
*docs {{ | ||
owner_type: $owner_type, | ||
owner_id, | ||
doc_id | ||
}} | ||
search_result[ | ||
doc_id, | ||
snippet_data, | ||
distance, | ||
] := | ||
input[owner_id, query], | ||
candidate[doc_id], | ||
~snippets:fts {{ | ||
doc_id, | ||
index, | ||
content | ||
| | ||
query: query, | ||
k: {k}, | ||
score_kind: 'tf_idf', | ||
bind_score: score, | ||
}}, | ||
distance = -score, | ||
snippet_data = [index, content] | ||
m[ | ||
doc_id, | ||
collect(snippet), | ||
distance, | ||
title, | ||
] := | ||
candidate[doc_id], | ||
*docs {{ | ||
owner_type: $owner_type, | ||
owner_id, | ||
doc_id, | ||
title, | ||
}}, | ||
search_result [ | ||
doc_id, | ||
snippet_data, | ||
distance, | ||
], | ||
snippet = {{ | ||
"index": snippet_data->0, | ||
"content": snippet_data->1, | ||
}} | ||
?[ | ||
id, | ||
owner_type, | ||
owner_id, | ||
snippets, | ||
distance, | ||
title, | ||
] := m[ | ||
id, | ||
snippets, | ||
distance, | ||
title, | ||
], owner_type = $owner_type, owner_id = $owner_id | ||
# Sort the results by distance to find the closest matches | ||
:sort distance | ||
:limit {k} | ||
""" | ||
|
||
queries = [ | ||
verify_developer_id_query(developer_id), | ||
verify_developer_owns_resource_query( | ||
developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} | ||
), | ||
search_query, | ||
] | ||
|
||
return ( | ||
queries, | ||
{"owner_type": owner_type, "owner_id": owner_id, "query": query}, | ||
) |
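
For context, a minimal usage sketch, assuming the cozo_query and wrap_in_class decorators execute the generated Datalog and wrap each result row as a DocReference (which is how search_docs_hybrid below consumes these results). The IDs and query string here are hypothetical; a real call needs an existing developer and owner in the database.

from uuid import uuid4

# Hypothetical identifiers for illustration only
results = search_docs_by_text(
    developer_id=uuid4(),
    owner_type="agent",
    owner_id=uuid4(),
    query="how do I configure webhooks",  # keyword query matched against the snippets FTS index
    k=3,
)

for doc in results:
    # distance is the negated tf_idf score, so lower means a better match
    print(doc.id, doc.distance)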
121 changes: 121 additions & 0 deletions
agents-api/agents_api/models/docs/search_docs_hybrid.py

@@ -0,0 +1,121 @@
"""This module contains functions for searching documents in the CozoDB based on embedding queries.""" | ||
|
||
from statistics import mean, stdev | ||
from typing import Literal | ||
from uuid import UUID | ||
|
||
from beartype import beartype | ||
|
||
from ...autogen.openapi_model import DocReference | ||
from .search_docs_by_embedding import search_docs_by_embedding | ||
from .search_docs_by_text import search_docs_by_text | ||
|
||
|
||
# Distribution based score normalization | ||
# https://medium.com/plain-simple-software/distribution-based-score-fusion-dbsf-a-new-approach-to-vector-search-ranking-f87c37488b18 | ||
def dbsf_normalize(scores: list[float]) -> list[float]: | ||
""" | ||
Scores scaled using minmax scaler with our custom feature range | ||
(extremes indicated as 3 standard deviations from the mean) | ||
""" | ||
sd = stdev(scores) | ||
if sd == 0: | ||
return scores | ||
|
||
m = mean(scores) | ||
m3d = 3 * sd + m | ||
m_3d = m - 3 * sd | ||
|
||
return [(s - m_3d) / (m3d - m_3d) for s in scores] | ||
|
||
|
||
def dbsf_fuse( | ||
text_results: list[DocReference], | ||
embedding_results: list[DocReference], | ||
alpha: float = 0.7, # Weight of the embedding search results (this is a good default) | ||
) -> list[DocReference]: | ||
""" | ||
Weighted reciprocal-rank fusion of text and embedding search results | ||
""" | ||
all_docs = {doc.id: doc for doc in text_results + embedding_results} | ||
|
||
text_scores: dict[UUID, float] = {doc.id: -doc.distance for doc in text_results} | ||
|
||
# Because these are cosine distances, we need to invert them | ||
embedding_scores: dict[UUID, float] = { | ||
doc.id: 1.0 - doc.distance for doc in embedding_results | ||
} | ||
|
||
# normalize the scores | ||
text_scores_normalized = dbsf_normalize(list(text_scores.values())) | ||
text_scores = { | ||
doc_id: score | ||
for doc_id, score in zip(text_scores.keys(), text_scores_normalized) | ||
} | ||
|
||
embedding_scores_normalized = dbsf_normalize(list(embedding_scores.values())) | ||
embedding_scores = { | ||
doc_id: score | ||
for doc_id, score in zip(embedding_scores.keys(), embedding_scores_normalized) | ||
} | ||
|
||
# Combine the scores | ||
text_weight: float = 1 - alpha | ||
embedding_weight: float = alpha | ||
|
||
combined_scores = [] | ||
|
||
for id in all_docs.keys(): | ||
text_score = text_weight * text_scores.get(id, 0) | ||
embedding_score = embedding_weight * embedding_scores.get(id, 0) | ||
|
||
combined_scores.append((id, text_score + embedding_score)) | ||
|
||
# Sort by the combined score | ||
combined_scores = sorted(combined_scores, key=lambda x: x[1], reverse=True) | ||
|
||
# Rank the results | ||
ranked_results = [] | ||
for id, score in combined_scores: | ||
doc = all_docs[id].model_copy() | ||
doc.distance = 1.0 - score | ||
ranked_results.append(doc) | ||
|
||
return ranked_results | ||
|
||
|
||
@beartype | ||
def search_docs_hybrid( | ||
*, | ||
developer_id: UUID, | ||
owner_type: Literal["user", "agent"], | ||
owner_id: UUID, | ||
query: str, | ||
query_embedding: list[float], | ||
k: int = 3, | ||
embed_search_options: dict = {}, | ||
text_search_options: dict = {}, | ||
**kwargs, | ||
) -> list[DocReference]: | ||
# TODO: We should probably parallelize these queries | ||
text_results = search_docs_by_text( | ||
developer_id=developer_id, | ||
owner_type=owner_type, | ||
owner_id=owner_id, | ||
query=query, | ||
k=2 * k, | ||
**text_search_options, | ||
**kwargs, | ||
) | ||
|
||
embedding_results = search_docs_by_embedding( | ||
developer_id=developer_id, | ||
owner_type=owner_type, | ||
owner_id=owner_id, | ||
query_embedding=query_embedding, | ||
k=2 * k, | ||
**embed_search_options, | ||
**kwargs, | ||
) | ||
|
||
return dbsf_fuse(text_results, embedding_results)[:k] |
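
To make the fusion arithmetic concrete, a toy walk-through with made-up scores, assuming for simplicity that the same three documents appear in both result lists in the same order:

# Hypothetical raw scores: text scores are negated FTS distances,
# embedding scores are 1 - cosine distance (as in dbsf_fuse above)
text_scores = [2.0, 1.5, 0.5]
embedding_scores = [0.91, 0.87, 0.42]

# dbsf_normalize rescales each list so that mean +/- 3 standard deviations
# maps onto roughly the [0, 1] range
norm_text = dbsf_normalize(text_scores)
norm_embed = dbsf_normalize(embedding_scores)

alpha = 0.7  # same default embedding weight as dbsf_fuse
combined = [(1 - alpha) * t + alpha * e for t, e in zip(norm_text, norm_embed)]

# Documents would then be ranked by combined score, descending,
# and reported with distance = 1 - combined_score
print(sorted(combined, reverse=True))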