From 48d27a920f231a8b0cf18efecf04acd2c64b4829 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Wed, 18 Sep 2024 13:08:44 -0400 Subject: [PATCH] fix(agents-api): Fix doc recall using search by text Signed-off-by: Diwank Singh Tomer --- agents-api/agents_api/models/docs/search_docs_by_text.py | 5 +++++ agents-api/agents_api/models/docs/search_docs_hybrid.py | 2 -- agents-api/agents_api/routers/sessions/chat.py | 9 ++++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/agents-api/agents_api/models/docs/search_docs_by_text.py b/agents-api/agents_api/models/docs/search_docs_by_text.py index 0662aa84d..eeae8362c 100644 --- a/agents-api/agents_api/models/docs/search_docs_by_text.py +++ b/agents-api/agents_api/models/docs/search_docs_by_text.py @@ -1,5 +1,6 @@ """This module contains functions for searching documents in the CozoDB based on embedding queries.""" +import json from typing import Any, Literal, TypeVar from uuid import UUID @@ -61,6 +62,10 @@ def search_docs_by_text( [owner_type, str(owner_id)] for owner_type, owner_id in owners ] + # Need to use NEAR/3($query) to search for arbitrary text within 3 words of each other + # See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts + query = f"NEAR/3({json.dumps(query)})" + # Construct the datalog query for searching document snippets search_query = f""" owners[owner_type, owner_id] <- $owners diff --git a/agents-api/agents_api/models/docs/search_docs_hybrid.py b/agents-api/agents_api/models/docs/search_docs_hybrid.py index 03fb44037..598600511 100644 --- a/agents-api/agents_api/models/docs/search_docs_hybrid.py +++ b/agents-api/agents_api/models/docs/search_docs_hybrid.py @@ -42,8 +42,6 @@ def dbsf_fuse( """ all_docs = {doc.id: doc for doc in text_results + embedding_results} - assert all(doc.distance is not None in all_docs for doc in text_results) - text_scores: dict[UUID, float] = { doc.id: -(doc.distance or 0.0) for doc in text_results } diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index b7bf96d3a..7af5d02c0 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -58,7 +58,13 @@ async def chat( # Prepare the environment env: dict = chat_context.get_chat_environment() - env["docs"] = doc_references + env["docs"] = [ + dict( + title=ref.title, + content=[snippet.content for snippet in ref.snippets], + ) + for ref in doc_references + ] # Render the system message if situation := chat_context.session.situation: @@ -120,6 +126,7 @@ async def chat( ] # Add the response to the new entries + # FIXME: We need to save all the choices new_entries.append( CreateEntryRequest.from_model_input( model=settings["model"],