Skip to content

Commit

Permalink
Refactor id variables in neo4j queries (#207)
Browse files Browse the repository at this point in the history
* Refactor id variables in neo4j queries

* Add warning in get_search_query

* Add duplicate elementId(node) as nodeId

* Rename to elementId

* Rename node_id and rel_id to element id

* Revert start_node_id and end_node_id

* Rename to node_element_id and rel_element_id

* Update docstring
  • Loading branch information
willtai authored Nov 8, 2024
1 parent 18e8e2a commit 0eb888f
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/experimental/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Neo4jNode(BaseModel):
"""Represents a Neo4j node.
Attributes:
id (str): The ID of the node.
id (str): The element ID of the node.
label (str): The label of the node.
properties (dict[str, Any]): A dictionary of properties attached to the node.
embedding_properties (Optional[dict[str, list[float]]]): A dictionary of embedding properties attached to the node.
Expand Down
16 changes: 8 additions & 8 deletions src/neo4j_graphrag/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def upsert_vector(
Args:
driver (neo4j.Driver): Neo4j Python driver instance.
node_id (int): The id of the node.
node_id (int): The element id of the node.
embedding_property (str): The name of the property to store the vector in.
vector (list[float]): The vector to store.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, the server's default database is used, which is named "neo4j" unless configured otherwise (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Expand All @@ -288,7 +288,7 @@ def upsert_vector(
"""
try:
parameters = {
"id": node_id,
"node_element_id": node_id,
"embedding_property": embedding_property,
"vector": vector,
}
Expand Down Expand Up @@ -334,7 +334,7 @@ def upsert_vector_on_relationship(
Args:
driver (neo4j.Driver): Neo4j Python driver instance.
rel_id (int): The id of the relationship.
rel_id (int): The element id of the relationship.
embedding_property (str): The name of the property to store the vector in.
vector (list[float]): The vector to store.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, the server's default database is used, which is named "neo4j" unless configured otherwise (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Expand All @@ -344,7 +344,7 @@ def upsert_vector_on_relationship(
"""
try:
parameters = {
"id": rel_id,
"rel_element_id": rel_id,
"embedding_property": embedding_property,
"vector": vector,
}
Expand Down Expand Up @@ -391,7 +391,7 @@ async def async_upsert_vector(
Args:
driver (neo4j.AsyncDriver): Neo4j Python asynchronous driver instance.
node_id (int): The id of the node.
node_id (int): The element id of the node.
embedding_property (str): The name of the property to store the vector in.
vector (list[float]): The vector to store.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, the server's default database is used, which is named "neo4j" unless configured otherwise (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Expand All @@ -401,7 +401,7 @@ async def async_upsert_vector(
"""
try:
parameters = {
"id": node_id,
"node_id": node_id,
"embedding_property": embedding_property,
"vector": vector,
}
Expand Down Expand Up @@ -448,7 +448,7 @@ async def async_upsert_vector_on_relationship(
Args:
driver (neo4j.AsyncDriver): Neo4j Python asynchronous driver instance.
rel_id (int): The id of the relationship.
rel_id (int): The element id of the relationship.
embedding_property (str): The name of the property to store the vector in.
vector (list[float]): The vector to store.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, the server's default database is used, which is named "neo4j" unless configured otherwise (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Expand All @@ -458,7 +458,7 @@ async def async_upsert_vector_on_relationship(
"""
try:
parameters = {
"id": rel_id,
"rel_id": rel_id,
"embedding_property": embedding_property,
"vector": vector,
}
Expand Down
14 changes: 10 additions & 4 deletions src/neo4j_graphrag/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations

import warnings
from typing import Any, Optional

from neo4j_graphrag.filters import get_metadata_filter
Expand Down Expand Up @@ -100,15 +101,15 @@

UPSERT_VECTOR_ON_NODE_QUERY = (
"MATCH (n) "
"WHERE elementId(n) = $id "
"WHERE elementId(n) = $node_element_id "
"WITH n "
"CALL db.create.setNodeVectorProperty(n, $embedding_property, $vector) "
"RETURN n"
)

UPSERT_VECTOR_ON_RELATIONSHIP_QUERY = (
"MATCH ()-[r]->() "
"WHERE elementId(r) = $id "
"WHERE elementId(r) = $rel_element_id "
"WITH r "
"CALL db.create.setRelationshipVectorProperty(r, $embedding_property, $vector) "
"RETURN r"
Expand Down Expand Up @@ -201,6 +202,11 @@ def get_search_query(
tuple[str, dict[str, Any]]: query and parameters
"""
warnings.warn(
"The default returned 'id' field in the search results will be removed. Please switch to using 'elementId' instead.",
DeprecationWarning,
stacklevel=2,
)
if search_type == SearchType.HYBRID:
if filters:
raise Exception("Filters are not supported with Hybrid Search")
Expand All @@ -227,7 +233,7 @@ def get_search_query(
query_tail = get_query_tail(
retrieval_query,
return_properties,
fallback_return=f"RETURN node {{ .*, `{embedding_node_property}`: null }} AS node, labels(node) AS nodeLabels, elementId(node) AS id, score",
fallback_return=f"RETURN node {{ .*, `{embedding_node_property}`: null }} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score",
)
return f"{query} {query_tail}", params

Expand All @@ -253,5 +259,5 @@ def get_query_tail(
return retrieval_query
if return_properties:
return_properties_cypher = ", ".join([f".{prop}" for prop in return_properties])
return f"RETURN node {{{return_properties_cypher}}} AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
return f"RETURN node {{{return_properties_cypher}}} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
return fallback_return if fallback_return else ""
2 changes: 1 addition & 1 deletion tests/unit/retrievers/external/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_match_query_with_return_properties() -> None:
"WITH match_param[0] AS match_id_value, match_param[1] AS score "
"MATCH (node) "
"WHERE node[$id_property] = match_id_value "
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
assert match_query.strip() == expected.strip()

Expand Down
16 changes: 12 additions & 4 deletions tests/unit/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,19 @@ def test_upsert_vector_happy_path(driver: MagicMock) -> None:

upsert_query = (
"MATCH (n) "
"WHERE elementId(n) = $id "
"WHERE elementId(n) = $node_element_id "
"WITH n "
"CALL db.create.setNodeVectorProperty(n, $embedding_property, $vector) "
"RETURN n"
)

driver.execute_query.assert_called_once_with(
upsert_query,
{"id": id, "embedding_property": embedding_property, "vector": vector},
{
"node_element_id": id,
"embedding_property": embedding_property,
"vector": vector,
},
database_=None,
)

Expand All @@ -241,15 +245,19 @@ def test_upsert_vector_on_relationship_happy_path(driver: MagicMock) -> None:

upsert_query = (
"MATCH ()-[r]->() "
"WHERE elementId(r) = $id "
"WHERE elementId(r) = $rel_element_id "
"WITH r "
"CALL db.create.setRelationshipVectorProperty(r, $embedding_property, $vector) "
"RETURN r"
)

driver.execute_query.assert_called_once_with(
upsert_query,
{"id": id, "embedding_property": embedding_property, "vector": vector},
{
"rel_element_id": id,
"embedding_property": embedding_property,
"vector": vector,
},
database_=None,
)

Expand Down
16 changes: 8 additions & 8 deletions tests/unit/test_neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_vector_search_basic() -> None:
expected = (
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
"YIELD node, score "
"RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
"RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
result, params = get_search_query(SearchType.VECTOR)
assert result.strip() == expected.strip()
Expand All @@ -45,7 +45,7 @@ def test_hybrid_search_basic() -> None:
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
"RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
result, _ = get_search_query(SearchType.HYBRID)
assert result.strip() == expected.strip()
Expand All @@ -56,7 +56,7 @@ def test_vector_search_with_properties() -> None:
expected = (
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
"YIELD node, score "
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
result, _ = get_search_query(SearchType.VECTOR, return_properties=properties)
assert result.strip() == expected.strip()
Expand All @@ -82,7 +82,7 @@ def test_vector_search_with_filters(_mock: Any) -> None:
"WITH node, "
"vector.similarity.cosine(node.`vector`, $query_vector) AS score "
"ORDER BY score DESC LIMIT $top_k "
"RETURN node { .*, `vector`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
"RETURN node { .*, `vector`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
result, params = get_search_query(
SearchType.VECTOR,
Expand All @@ -108,7 +108,7 @@ def test_vector_search_with_params_from_filters(_mock: Any) -> None:
"WITH node, "
"vector.similarity.cosine(node.`vector`, $query_vector) AS score "
"ORDER BY score DESC LIMIT $top_k "
"RETURN node { .*, `vector`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
"RETURN node { .*, `vector`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
result, params = get_search_query(
SearchType.VECTOR,
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_hybrid_search_with_properties() -> None:
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
result, _ = get_search_query(SearchType.HYBRID, return_properties=properties)
assert result.strip() == expected.strip()
Expand All @@ -174,7 +174,7 @@ def test_get_query_tail_with_retrieval_query() -> None:

def test_get_query_tail_with_properties() -> None:
properties = ["name", "age"]
expected = "RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
expected = "RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
result = get_query_tail(return_properties=properties)
assert result.strip() == expected.strip()

Expand Down Expand Up @@ -204,7 +204,7 @@ def test_get_query_tail_ordering_no_retrieval_query() -> None:
properties = ["name", "age"]
fallback = "HELLO"

expected = "RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
expected = "RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
result = get_query_tail(
return_properties=properties,
fallback_return=fallback,
Expand Down

0 comments on commit 0eb888f

Please sign in to comment.