diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 9ad725b7..4905caf2 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -2,14 +2,16 @@ import html import os from dataclasses import dataclass -from typing import Any, Union, cast +from typing import Any, Union, cast, Tuple, List, Dict import numpy as np import inspect from lightrag.utils import load_json, logger, write_json from ..base import ( BaseGraphStorage ) -from neo4j import GraphDatabase, exceptions as neo4jExceptions +from neo4j import AsyncGraphDatabase,exceptions as neo4jExceptions,AsyncDriver,AsyncSession, AsyncManagedTransaction + +from contextlib import asynccontextmanager from tenacity import ( @@ -20,126 +22,135 @@ ) - @dataclass -class GraphStorage(BaseGraphStorage): +class Neo4JStorage(BaseGraphStorage): @staticmethod def load_nx_graph(file_name): print ("no preloading of graph with neo4j in production") + def __init__(self, namespace, global_config): + super().__init__(namespace=namespace, global_config=global_config) + self._driver = None + self._driver_lock = asyncio.Lock() + URI = os.environ["NEO4J_URI"] + USERNAME = os.environ["NEO4J_USERNAME"] + PASSWORD = os.environ["NEO4J_PASSWORD"] + self._driver: AsyncDriver = AsyncGraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) + return None + def __post_init__(self): # self._graph = preloaded_graph or nx.Graph() + print("is this ever run") credetial_parts = ['URI', 'USERNAME','PASSWORD'] credentials_set = all(x in os.environ for x in credetial_parts ) - if credentials_set: - URI = os.environ["NEO4J_URI"] - USERNAME = os.environ["NEO4J_USERNAME"] - PASSWORD = os.environ["NEO4J_PASSWORD"] - else: - raise Exception (f"One or more Neo4J Credentials, {credetial_parts}, not found in the environment") - - self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } + + async def close(self): + if self._driver: + await self._driver.close() + self._driver = None + + + + async def __aexit__(self, exc_type, exc, tb): + if self._driver: + await self._driver.close() + async def index_done_callback(self): print ("KG successfully indexed.") + + async def has_node(self, node_id: str) -> bool: entity_name_label = node_id.strip('\"') - def _check_node_exists(tx, label): - query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists" - result = tx.run(query) - single_result = result.single() + async with self._driver.session() as session: + query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" + result = await session.run(query) + single_result = await result.single() logger.debug( f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}' ) - return single_result["node_exists"] - - with self._driver.session() as session: - return session.read_transaction(_check_node_exists, entity_name_label) - - + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: entity_name_label_source = source_node_id.strip('\"') entity_name_label_target = target_node_id.strip('\"') - - def _check_edge_existence(tx, label1, label2): + async with self._driver.session() as session: query = ( - f"MATCH (a:`{label1}`)-[r]-(b:`{label2}`) " + f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " "RETURN COUNT(r) > 0 AS edgeExists" ) - result = tx.run(query) - single_result = result.single() + result = await session.run(query) + single_result = await result.single() logger.debug( f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}' ) - return single_result["edgeExists"] + def close(self): self._driver.close() - #hard code relaitionship type, directed. - with self._driver.session() as session: - result = session.read_transaction(_check_edge_existence, entity_name_label_source, entity_name_label_target) - return result + async def get_node(self, node_id: str) -> Union[dict, None]: - entity_name_label = node_id.strip('\"') - with self._driver.session() as session: - query = "MATCH (n:`{entity_name_label}`) RETURN n".format(entity_name_label=entity_name_label) - result = session.run(query) - for record in result: - result = record["n"] + async with self._driver.session() as session: + entity_name_label = node_id.strip('\"') + query = f"MATCH (n:`{entity_name_label}`) RETURN n" + result = await session.run(query) + record = await result.single() + if record: + node = record["n"] + node_dict = dict(node) logger.debug( - f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}' - ) - return result - + f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}' + ) + return node_dict + return None + async def node_degree(self, node_id: str) -> int: entity_name_label = node_id.strip('\"') - - def _find_node_degree(session, label): - with session.begin_transaction() as tx: - query = f""" - MATCH (n:`{label}`) - RETURN COUNT{{ (n)--() }} AS totalEdgeCount - """ - result = tx.run(query) - record = result.single() - if record: - edge_count = record["totalEdgeCount"] - logger.debug( - f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}' - ) - return edge_count - else: - return None + async with self._driver.session() as session: + query = f""" + MATCH (n:`{entity_name_label}`) + RETURN COUNT{{ (n)--() }} AS totalEdgeCount + """ + result = await session.run(query) + record = await result.single() + if record: + edge_count = record["totalEdgeCount"] + logger.debug( + f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}' + ) + return edge_count + else: + return None - with self._driver.session() as session: - degree = _find_node_degree(session, entity_name_label) - return degree - async def edge_degree(self, src_id: str, tgt_id: str) -> int: entity_name_label_source = src_id.strip('\"') entity_name_label_target = tgt_id.strip('\"') - with self._driver.session() as session: - query = f"""MATCH (n1:`{entity_name_label_source}`)-[r]-(n2:`{entity_name_label_target}`) - RETURN count(r) AS degree""" - result = session.run(query) - record = result.single() - logger.debug( - f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{record["degree"]}' - ) - return record["degree"] + src_degree = await self.node_degree(entity_name_label_source) + trg_degree = await self.node_degree(entity_name_label_target) + + # Convert None to 0 for addition + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + + degrees = int(src_degree) + int(trg_degree) + logger.debug( + f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}' + ) + return degrees + + async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: entity_name_label_source = source_node_id.strip('\"') @@ -154,15 +165,15 @@ async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict Returns: list: List of all relationships/edges found """ - with self._driver.session() as session: + async with self._driver.session() as session: query = f""" MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) RETURN properties(r) as edge_properties LIMIT 1 """.format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target) - result = session.run(query) - record = result.single() + result = await session.run(query) + record = await result.single() if record: result = dict(record["edge_properties"]) logger.debug( @@ -173,29 +184,20 @@ async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict return None - async def get_node_edges(self, source_node_id: str): + async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]: node_label = source_node_id.strip('\"') """ - Retrieves all edges (relationships) for a particular node identified by its label and ID. - - :param uri: Neo4j database URI - :param username: Neo4j username - :param password: Neo4j password - :param node_label: Label of the node - :param node_id: ID property of the node + Retrieves all edges (relationships) for a particular node identified by its label. :return: List of dictionaries containing edge information """ - - def fetch_edges(tx, label): - query = f"""MATCH (n:`{label}`) + query = f"""MATCH (n:`{node_label}`) OPTIONAL MATCH (n)-[r]-(connected) RETURN n, r, connected""" - - results = tx.run(query) - + async with self._driver.session() as session: + results = await session.run(query) edges = [] - for record in results: + async for record in results: source_node = record['n'] connected_node = record['connected'] @@ -207,7 +209,7 @@ def fetch_edges(tx, label): return edges - with self._driver.session() as session: + async with self._driver.session() as session: edges = session.read_transaction(fetch_edges,node_label) return edges @@ -217,86 +219,51 @@ def fetch_edges(tx, label): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)), ) - async def upsert_node(self, node_id: str, node_data: dict[str, str]): - label = node_id.strip('\"') - properties = node_data + async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): """ - Upsert a node with the given label and properties within a transaction. + Upsert a node in the Neo4j database. + Args: - label: The node label to search for and apply - properties: Dictionary of node properties - - Returns: - Dictionary containing the node's properties after upsert, or None if operation fails + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties """ - def _do_upsert(tx, label: str, properties: dict[str, Any]): - - """ - Args: - tx: Neo4j transaction object - label: The node label to search for and apply - properties: Dictionary of node properties - - Returns: - Dictionary containing the node's properties after upsert, or None if operation fails - """ + label = node_id.strip('\"') + properties = node_data + async def _do_upsert(tx: AsyncManagedTransaction): query = f""" MERGE (n:`{label}`) SET n += $properties - RETURN n """ - # Execute the query with properties as parameters - # with session.begin_transaction() as tx: - result = tx.run(query, properties=properties) - record = result.single() - if record: - logger.debug( - f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}' - ) - return dict(record["n"]) - return None - - - with self._driver.session() as session: - with session.begin_transaction() as tx: - try: - result = _do_upsert(tx,label,properties) - tx.commit() - return result - except Exception as e: - raise # roll back - + await tx.run(query, properties=properties) + logger.debug(f"Upserted node with label '{label}' and properties: {properties}") + + try: + async with self._driver.session() as session: + await session.execute_write(_do_upsert) + except Exception as e: + logger.error(f"Error during upsert: {str(e)}") + raise - - async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None: - source_node_label = source_node_id.strip('\"') - target_node_label = target_node_id.strip('\"') - edge_properties = edge_data + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)), + ) + async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]): """ Upsert an edge and its properties between two nodes identified by their labels. - + Args: - source_node_label (str): Label of the source node (used as identifier) - target_node_label (str): Label of the target node (used as identifier) - edge_properties (dict): Dictionary of properties to set on the edge + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): Dictionary of properties to set on the edge """ - - - - def _do_upsert_edge(tx, source_node_label: str, target_node_label: str, edge_properties: dict[str, Any]) -> None: - """ - Static method to perform the edge upsert within a transaction. - - The query will: - 1. Match the source and target nodes by their labels - 2. Merge the DIRECTED relationship - 3. Set all properties on the relationship, updating existing ones and adding new ones - """ - # Convert edge properties to Cypher parameter string - # props_string = ", ".join(f"r.{key} = ${key}" for key in edge_properties.keys()) + source_node_label = source_node_id.strip('\"') + target_node_label = target_node_id.strip('\"') + edge_properties = edge_data - # """.format(props_string) + async def _do_upsert_edge(tx: AsyncManagedTransaction): query = f""" MATCH (source:`{source_node_label}`) WITH source @@ -305,22 +272,15 @@ def _do_upsert_edge(tx, source_node_label: str, target_node_label: str, edge_pro SET r += $properties RETURN r """ - - result = tx.run(query, properties=edge_properties) - logger.debug( - f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}' - ) - return result.single() - - with self._driver.session() as session: - session.execute_write( - _do_upsert_edge, - source_node_label, - target_node_label, - edge_properties - ) - # return result - + await tx.run(query, properties=edge_properties) + logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}") + + try: + async with self._driver.session() as session: + await session.execute_write(_do_upsert_edge) + except Exception as e: + logger.error(f"Error during edge upsert: {str(e)}") + raise async def _node2vec_embed(self): print ("Implemented but never called.") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index a42b806e..5d271860 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -26,7 +26,7 @@ ) from .kg.neo4j_impl import ( - GraphStorage as Neo4JStorage + Neo4JStorage ) #future KG integrations @@ -57,9 +57,10 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop: try: loop = asyncio.get_running_loop() except RuntimeError: - logger.info("Creating a new event loop in a sub-thread.") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + logger.info("Creating a new event loop in main thread.") + # loop = asyncio.new_event_loop() + # asyncio.set_event_loop(loop) + loop = asyncio.get_event_loop() return loop @@ -329,4 +330,4 @@ async def _query_done(self): if storage_inst is None: continue tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) - await asyncio.gather(*tasks) + await asyncio.gather(*tasks) \ No newline at end of file diff --git a/lightrag/llm.py b/lightrag/llm.py index f4045e80..e93afa03 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -798,4 +798,4 @@ async def main(): result = await gpt_4o_mini_complete("How are you?") print(result) - asyncio.run(main()) + asyncio.run(main()) \ No newline at end of file diff --git a/lightrag/operate.py b/lightrag/operate.py index 6b6ba563..518bd68a 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1083,4 +1083,4 @@ async def naive_query( .strip() ) - return response + return response \ No newline at end of file diff --git a/test_neo4j.py b/test_neo4j.py index 044c12e9..18195c32 100644 --- a/test_neo4j.py +++ b/test_neo4j.py @@ -2,6 +2,7 @@ from lightrag import LightRAG, QueryParam from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete + ######### # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() # import nest_asyncio