From c00d08d22bda5e484a14bb6626f26424805e8d21 Mon Sep 17 00:00:00 2001
From: vedantsahai18
Date: Fri, 3 Jan 2025 12:11:40 -0500
Subject: [PATCH] feat(agents-api): added mmr to chat

---
 .../queries/chat/gather_messages.py           | 28 ++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/agents-api/agents_api/queries/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py
index 683eb1cf0..053208426 100644
--- a/agents-api/agents_api/queries/chat/gather_messages.py
+++ b/agents-api/agents_api/queries/chat/gather_messages.py
@@ -16,6 +16,8 @@
 from ..entries.get_history import get_history
 from ..sessions.get_session import get_session
 from ..utils import rewrap_exceptions
+from ..docs.mmr import maximal_marginal_relevance
+import numpy as np
 
 T = TypeVar("T")
 
@@ -133,6 +135,30 @@ async def gather_messages(
                 connection_pool=connection_pool,
             )
 
-    # TODO: Add missing MMR implementation
+    # Apply MMR if enabled
+    if (
+        # MMR is enabled
+        recall_options.mmr_strength > 0
+        # The number of doc references is greater than the limit
+        and len(doc_references) > recall_options.limit
+        # MMR is not applied to text search
+        and recall_options.mode != "text"
+    ):
+        # FIXME: This is a temporary fix to ensure that the MMR algorithm works.
+        # We shouldn't be having references without embeddings.
+        doc_references = [
+            doc for doc in doc_references if doc.snippet.embedding is not None
+        ]
+
+        # Run MMR. NOTE(review): mmr_strength only gates this; it is not forwarded.
+        indices = maximal_marginal_relevance(
+            np.asarray(query_embedding),
+            [doc.snippet.embedding for doc in doc_references],
+            k=recall_options.limit,
+        )
+
+        # Keep only the selected docs; build the index set once, not per doc.
+        selected = set(indices)
+        doc_references = [doc for i, doc in enumerate(doc_references) if i in selected]
 
     return past_messages, doc_references