Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow thresholding on vector and fulltext indexes for Hybrid retrievers #239

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Iterable, Optional, TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, Iterable, Optional, cast

from pydantic import ValidationError

Expand Down
1 change: 1 addition & 0 deletions src/neo4j_graphrag/llm/cohere_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterable, Optional, cast

from pydantic import ValidationError

from neo4j_graphrag.exceptions import LLMGenerationError
Expand Down
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/llm/mistralai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import os
from typing import Any, Iterable, Optional, cast

from pydantic import ValidationError

from neo4j_graphrag.exceptions import LLMGenerationError
Expand All @@ -30,7 +31,7 @@
)

try:
from mistralai import Mistral, Messages
from mistralai import Messages, Mistral
from mistralai.models.sdkerror import SDKError
except ImportError:
Mistral = None # type: ignore
Expand Down
5 changes: 3 additions & 2 deletions src/neo4j_graphrag/llm/ollama_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any, Iterable, Optional, Sequence, TYPE_CHECKING, cast

from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, cast

from pydantic import ValidationError

Expand All @@ -24,9 +25,9 @@
BaseMessage,
LLMMessage,
LLMResponse,
MessageList,
SystemMessage,
UserMessage,
MessageList,
)

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
BaseMessage,
LLMMessage,
LLMResponse,
MessageList,
SystemMessage,
UserMessage,
MessageList,
)

if TYPE_CHECKING:
Expand Down
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/llm/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import BaseModel
from typing import Literal, TypedDict

from pydantic import BaseModel


class LLMResponse(BaseModel):
content: str
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_graphrag/llm/vertexai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@

try:
from vertexai.generative_models import (
Content,
GenerativeModel,
ResponseValidationError,
Part,
Content,
ResponseValidationError,
)
except ImportError:
GenerativeModel = None
Expand Down
15 changes: 8 additions & 7 deletions src/neo4j_graphrag/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,25 +122,29 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
f"CALL () {{ {VECTOR_INDEX_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
f"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index "
f"THEN (n.score / vector_index_max_score) ELSE 0 END AS score "
f"UNION "
f"{FULL_TEXT_SEARCH_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
f"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index "
f"THEN (n.score / ft_index_max_score) ELSE 0 END AS score }} "
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
)
else:
return (
f"CALL {{ {VECTOR_INDEX_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
f"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index "
f"THEN (n.score / vector_index_max_score) ELSE 0 END AS score "
f"UNION "
f"{FULL_TEXT_SEARCH_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
f"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index "
f"THEN (n.score / ft_index_max_score) ELSE 0 END AS score }} "
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
)

Expand Down Expand Up @@ -186,7 +190,6 @@ def get_search_query(
neo4j_version_is_5_23_or_above: bool = False,
) -> tuple[str, dict[str, Any]]:
"""Build the search query, including pre-filtering if needed, and return clause.

Args
search_type: Search type we want to search for:
return_properties (list[str]): list of property names to return.
Expand All @@ -197,10 +200,8 @@ def get_search_query(
embedding_node_property (str): the name of the property holding the embeddings
embedding_dimension (int): the dimension of the embeddings
filters (dict[str, Any]): filters used to pre-filter the nodes before vector search

Returns:
tuple[str, dict[str, Any]]: query and parameters

"""
warnings.warn(
"The default returned 'id' field in the search results will be removed. Please switch to using 'elementId' instead.",
Expand Down
14 changes: 14 additions & 0 deletions src/neo4j_graphrag/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def get_search_results(
query_text: str,
query_vector: Optional[list[float]] = None,
top_k: int = 5,
threshold_vector_index: float = 0.0,
threshold_fulltext_index: float = 0.0,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -159,6 +161,8 @@ def get_search_results(
query_text (str): The text to get the closest neighbors of.
query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
threshold_vector_index (float, optional): The minimum normalized score from the vector index to include in the top k search.
threshold_fulltext_index (float, optional): The minimum normalized score from the fulltext index to include in the top k search.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor point, but users might not know or understand the normalisation process. It could be worth adding something to the docs, with an example, that explains a bit about what's going on with these parameters, as I think these descriptions might not be enough on their own.


Raises:
SearchValidationError: If validation of the input arguments fail.
Expand All @@ -180,6 +184,9 @@ def get_search_results(
parameters["vector_index_name"] = self.vector_index_name
parameters["fulltext_index_name"] = self.fulltext_index_name

parameters["threshold_vector_index"] = threshold_vector_index
parameters["threshold_fulltext_index"] = threshold_fulltext_index

if query_text and not query_vector:
if not self.embedder:
raise EmbeddingRequiredError(
Expand Down Expand Up @@ -296,6 +303,8 @@ def get_search_results(
query_vector: Optional[list[float]] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
threshold_vector_index: float = 0.0,
threshold_fulltext_index: float = 0.0,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -313,6 +322,8 @@ def get_search_results(
query_vector (Optional[list[float]]): The vector embeddings to get the closest neighbors of. Defaults to None.
top_k (int): The number of neighbors to return. Defaults to 5.
query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
threshold_vector_index (float, optional): The minimum normalized score from the vector index to include in the top k search.
threshold_fulltext_index (float, optional): The minimum normalized score from the fulltext index to include in the top k search.

Raises:
SearchValidationError: If validation of the input arguments fail.
Expand All @@ -335,6 +346,9 @@ def get_search_results(
parameters["vector_index_name"] = self.vector_index_name
parameters["fulltext_index_name"] = self.fulltext_index_name

parameters["threshold_vector_index"] = threshold_vector_index
parameters["threshold_fulltext_index"] = threshold_fulltext_index

if query_text and not query_vector:
if not self.embedder:
raise EmbeddingRequiredError(
Expand Down
1 change: 0 additions & 1 deletion tests/unit/llm/test_anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import anthropic

import pytest
from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/retrievers/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def test_hybrid_search_text_happy_path(
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": embed_query_vector,
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
Expand Down Expand Up @@ -262,6 +264,8 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector(
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": query_vector,
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=database,
routing_=neo4j.RoutingControl.READ,
Expand Down Expand Up @@ -345,6 +349,8 @@ def test_hybrid_retriever_return_properties(
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": embed_query_vector,
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
Expand Down
24 changes: 16 additions & 8 deletions tests/unit/test_neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ def test_hybrid_search_basic() -> None:
"YIELD node, score "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be worth adding a test to check that the threshold process is working as expected (i.e. that the correct scores are set to zero, etc.)?

"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / vector_index_max_score) AS score UNION "
"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index "
"THEN (n.score / vector_index_max_score) ELSE 0 END AS score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index "
"THEN (n.score / ft_index_max_score) ELSE 0 END AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
Expand Down Expand Up @@ -129,17 +131,20 @@ def test_hybrid_search_with_retrieval_query() -> None:
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / vector_index_max_score) AS score UNION "
"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index THEN (n.score / vector_index_max_score) ELSE 0 END AS score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index THEN (n.score / ft_index_max_score) ELSE 0 END AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+ retrieval_query
)
result, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query)
result, _ = get_search_query(
SearchType.HYBRID,
retrieval_query=retrieval_query,
)
assert result.strip() == expected.strip()


Expand All @@ -151,17 +156,20 @@ def test_hybrid_search_with_properties() -> None:
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / vector_index_max_score) AS score UNION "
"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index THEN (n.score / vector_index_max_score) ELSE 0 END AS score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index THEN (n.score / ft_index_max_score) ELSE 0 END AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
result, _ = get_search_query(SearchType.HYBRID, return_properties=properties)
result, _ = get_search_query(
SearchType.HYBRID,
return_properties=properties,
)
assert result.strip() == expected.strip()


Expand Down
Loading