Change entity and relation extraction into a two-step process: first extract entities, then extract relations based on those entities. #401

Closed
wants to merge 39 commits into from
Changes from all commits (39 commits)
284617f
1. Switch logging to loguru to avoid concurrent file-write errors.
Nov 25, 2024
1397cfd
1. Add prompt_cn.py and switch to Chinese prompts for entity and relation extraction;
Dec 5, 2024
080204f
Merge branch 'main' into main
bumaple Dec 5, 2024
2103dd5
Merge branch 'HKUDS:main' into main
bumaple Dec 6, 2024
bea6290
1. Update prompt_cn.py and operate.py to adapt to the latest features of the forked repository.
Dec 6, 2024
285aa7d
Merge remote-tracking branch 'origin/main'
Dec 6, 2024
14a37c9
Merge branch 'HKUDS:main' into main
bumaple Dec 6, 2024
e520be5
Merge branch 'main' of github.com:HKUDS/LightRAG
Dec 9, 2024
4f6eee6
Merge branch 'main' of github.com:HKUDS/LightRAG
Dec 10, 2024
ad91bd5
Merge branch 'main' of github.com:bumaple/LightRAG_Mod
Dec 10, 2024
d613eff
1. Improve the Chinese extraction prompts;
Dec 11, 2024
b9e9844
Merge branch 'main' of github.com:HKUDS/LightRAG
Dec 12, 2024
42e446e
1. Restore the original combined entity-and-relation extraction; add a new step-by-step extraction mode, selectable via a parameter.
Dec 14, 2024
1af5097
Merge branch 'main' of github.com:HKUDS/LightRAG
Dec 14, 2024
dbb8b3e
1. Fix bugs and optimize functionality.
Dec 23, 2024
0127f9c
Merge branch 'main' of github.com:HKUDS/LightRAG
Dec 23, 2024
933d02d
1. Improve the markdown-header chunking algorithm; add a timeout parameter for OpenAI-style API calls.
Dec 25, 2024
4efecbe
Merge branch 'main' of github.com:HKUDS/LightRAG
Dec 25, 2024
5c75aa2
Merge branch 'main' of github.com:HKUDS/LightRAG
Dec 26, 2024
eefc0cc
Merge branch 'main' of github.com:HKUDS/LightRAG
Jan 3, 2025
3c0a0fd
1. Improve the standard-number normalization in text_utils.py so that two letters followed by a line break and then digits are no longer recognized as a standard number;
Jan 3, 2025
0e7cad8
1. Fix a MarkdownHeader chunking bug.
Jan 4, 2025
25b00d6
1. Drop the separate entity/relation extraction (results were poor).
Jan 7, 2025
2df1c5a
Merge branch 'main' of github.com:HKUDS/LightRAG
Jan 7, 2025
5f55e8e
1. Drop the separate entity/relation extraction (results were poor).
Jan 7, 2025
12d09ed
1. Shorten the entity-type text descriptions.
Jan 9, 2025
8b42ab1
1. Handle streaming output.
Jan 9, 2025
4bb69ba
1. Handle streaming output.
Jan 9, 2025
d8ac0db
1. Add BadRequestError retry to the OpenAI-style interface.
Jan 9, 2025
5378b6e
1. Roll back to before the streaming changes. Add a BadRequestError retry mechanism to the openai_complete_if_cache function.
Jan 10, 2025
2e34d8f
1. Add default timeout parameters to neo4j_impl.py.
Jan 10, 2025
93a72b9
1. Add mix-mode answering.
Jan 11, 2025
7088cb6
1. Add support for exporting chunking logs.
Jan 13, 2025
175cd02
1. Add support for exporting chunking logs.
Jan 13, 2025
0fa7f64
1. Add support for exporting chunking logs.
Jan 13, 2025
c8f3d55
1. Add default connection parameters to neo4j_impl.py.
Jan 14, 2025
54cc4df
1. Improve the Chinese query prompts.
Jan 15, 2025
e29848f
1. Improve the Chinese query prompts.
Jan 16, 2025
0381874
1. Improve the Chinese query prompts.
Jan 16, 2025
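The commits above revolve around an optional two-step extraction mode (extract entities first, then extract relations constrained to those entities), which commit 42e446e makes switchable via a parameter and commits 25b00d6/5f55e8e later disable because results were poor. As a rough illustration only — the helper and prompts below are hypothetical, not the actual prompt_cn.py/operate.py code in this PR — such a two-step flow could look like this:

# Illustrative sketch of a two-step extraction pass (hypothetical helper,
# not the implementation shipped in this PR). `llm` is any async completion
# function that takes a prompt string and returns the model's text.
async def extract_two_step(chunk_text: str, llm):
    # Step 1: ask the model for entities only.
    entity_prompt = (
        "Extract all named entities from the text below. "
        "Return one entity per line as: name|type|description\n\n" + chunk_text
    )
    entities = [
        dict(zip(("name", "type", "description"), line.split("|", 2)))
        for line in (await llm(entity_prompt)).splitlines()
        if line.count("|") >= 2
    ]

    # Step 2: ask for relations, restricted to the entities found in step 1.
    names = ", ".join(e["name"] for e in entities)
    relation_prompt = (
        f"Given these entities: {names}\n"
        "List the relations between them in the text below, one per line, "
        "as: source|target|description\n\n" + chunk_text
    )
    relations = [
        dict(zip(("src_id", "tgt_id", "description"), line.split("|", 2)))
        for line in (await llm(relation_prompt)).splitlines()
        if line.count("|") >= 2
    ]
    return entities, relations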
2 changes: 1 addition & 1 deletion examples/lightrag_oracle_demo.py
@@ -105,7 +105,7 @@ async def main():
rag.key_string_value_json_storage_cls.db = oracle_db
rag.vector_db_storage_cls.db = oracle_db
# add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
rag.chunk_entity_relation_graph._embedding_func = rag.embedding_func

# Extract and Insert into LightRAG storage
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
2 changes: 1 addition & 1 deletion lightrag/kg/milvus_impl.py
@@ -83,7 +83,7 @@ async def query(self, query, top_k=5):
output_fields=list(self.meta_fields),
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
)
print(results)
print(f"Query VectorDB Results: {len(results[0])}\n * * * * * {results}\n* * * * *")
return [
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
for dp in results[0]
38 changes: 27 additions & 11 deletions lightrag/kg/neo4j_impl.py
@@ -43,13 +43,29 @@ def __init__(self, namespace, global_config, embedding_func):
"NEO4J_DATABASE"
) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
self._DATABASE = DATABASE
# Add default parameters by bumaple 2025-01-10
self._timeout = 600
self._check_timeout = 30
self._conn_pool_size = 50

self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
URI, auth=(USERNAME, PASSWORD),
connection_acquisition_timeout=self._timeout,
max_connection_lifetime=self._timeout * 6,
max_connection_pool_size=self._conn_pool_size,
connection_timeout=self._check_timeout,
liveness_check_timeout=self._check_timeout,
)
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver:
with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD),
connection_acquisition_timeout=self._timeout,
max_connection_lifetime=self._timeout * 6,
max_connection_pool_size=self._conn_pool_size,
connection_timeout=self._check_timeout,
liveness_check_timeout=self._check_timeout,
) as _sync_driver:
try:
with _sync_driver.session(database=DATABASE) as session:
with _sync_driver.session(database=DATABASE, connection_acquisition_timeout=self._timeout) as session:
try:
session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {DATABASE} at {URI}")
@@ -101,7 +117,7 @@ async def index_done_callback(self):
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')

async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE, connection_acquisition_timeout=self._timeout) as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
@@ -116,7 +132,7 @@ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')

async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE, connection_acquisition_timeout=self._timeout) as session:
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
@@ -129,7 +145,7 @@ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return single_result["edgeExists"]

async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE, connection_acquisition_timeout=self._timeout) as session:
entity_name_label = node_id.strip('"')
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
@@ -146,7 +162,7 @@ async def get_node(self, node_id: str) -> Union[dict, None]:
async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('"')

async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE, connection_acquisition_timeout=self._timeout) as session:
query = f"""
MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
@@ -193,7 +209,7 @@ async def get_edge(
Returns:
list: List of all relationships/edges found
"""
async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE, connection_acquisition_timeout=self._timeout) as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
@@ -224,7 +240,7 @@ async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE, connection_acquisition_timeout=self._timeout) as session:
results = await session.run(query)
edges = []
async for record in results:
@@ -279,7 +295,7 @@ async def _do_upsert(tx: AsyncManagedTransaction):
)

try:
async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE, connection_acquisition_timeout=self._timeout) as session:
await session.execute_write(_do_upsert)
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
@@ -326,7 +342,7 @@ async def _do_upsert_edge(tx: AsyncManagedTransaction):
)

try:
async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE, connection_acquisition_timeout=self._timeout) as session:
await session.execute_write(_do_upsert_edge)
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")
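The neo4j_impl.py changes above hard-code default timeouts and a connection-pool size and pass them to both the async and sync drivers, plus a connection-acquisition timeout on every session. A minimal driver-side sketch of the same defaults, assuming the neo4j Python driver 5.x and placeholder credentials:

# Sketch: a Neo4j async driver configured with the defaults added in this PR.
# URI and credentials are placeholders; adjust them for your deployment.
from neo4j import AsyncGraphDatabase

TIMEOUT = 600        # seconds; connection acquisition, 6x for max lifetime
CHECK_TIMEOUT = 30   # seconds; connect and liveness checks
POOL_SIZE = 50       # maximum pooled connections

driver = AsyncGraphDatabase.driver(
    "neo4j://localhost:7687",
    auth=("neo4j", "password"),
    connection_acquisition_timeout=TIMEOUT,
    max_connection_lifetime=TIMEOUT * 6,
    max_connection_pool_size=POOL_SIZE,
    connection_timeout=CHECK_TIMEOUT,
    liveness_check_timeout=CHECK_TIMEOUT,
)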
6 changes: 3 additions & 3 deletions lightrag/kg/tidb_impl.py
@@ -183,7 +183,7 @@ async def upsert(self, data: dict[str, dict]):
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": f"{item["__vector__"].tolist()}",
"content_vector": f"{item['__vector__'].tolist()}",
"workspace": self.db.workspace,
}
)
@@ -286,7 +286,7 @@ async def upsert(self, data: dict[str, dict]):
"id": item["id"],
"name": item["entity_name"],
"content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}",
"content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace,
}
# update entity_id if node inserted by graph_storage_instance before
@@ -308,7 +308,7 @@ async def upsert(self, data: dict[str, dict]):
"source_name": item["src_id"],
"target_name": item["tgt_id"],
"content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}",
"content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace,
}
# update relation_id if node inserted by graph_storage_instance before
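The three tidb_impl.py changes are the same one-character fix: before Python 3.12, an f-string replacement field may not reuse the quote character that delimits the string, so the original nested double quotes raise a SyntaxError. Switching to single quotes inside the braces produces identical output, for example:

# Minimal reproduction of the quoting fix (values are made up for illustration).
import numpy as np

item = {"__vector__": np.array([0.1, 0.2])}

# SyntaxError on Python < 3.12 (same quote reused inside the f-string):
#   f"{item["__vector__"].tolist()}"

# Valid on all versions:
content_vector = f"{item['__vector__'].tolist()}"
print(content_vector)  # [0.1, 0.2]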
72 changes: 57 additions & 15 deletions lightrag/lightrag.py
@@ -1,5 +1,7 @@
import asyncio
import os

from lightrag.operate import chunking_by_markdown_header
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
@@ -12,6 +14,8 @@
)
from .operate import (
chunking_by_token_size,
chunking_by_markdown_header,
chunking_by_markdown_text,
extract_entities,
# local_query,global_query,hybrid_query,
kg_query,
@@ -43,7 +47,7 @@
JsonDocStatusStorage,
)

from .prompt import GRAPH_FIELD_SEP
from .prompt_cn import GRAPH_FIELD_SEP

# future KG integrations

@@ -183,13 +187,21 @@ class LightRAG:
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json

# Custom addition: main entity SN and title, by bumaple 2024-12-03
extend_entity_title: str = ''
extend_entity_sn: str = ''
# Custom addition: chunk type, by bumaple 2024-12-11
chunk_type: str = 'token_size'
# Custom addition: chunk header level, by bumaple 2024-12-11
chunk_header_level: int = 2

# Add new field for document status storage type
doc_status_storage: str = field(default="JsonDocStatusStorage")

def __post_init__(self):
log_file = os.path.join("lightrag.log")
set_logger(log_file)
logger.setLevel(self.log_level)
log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file, self.log_level)
# logger.setLevel(self.log_level)

logger.info(f"Logger initialized for working directory: {self.working_dir}")

@@ -372,18 +384,48 @@ async def ainsert(self, string_or_strings):
await self.doc_status.upsert({doc_id: doc_status})

# Generate chunks from document
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
if self.chunk_type == "markdown_header":
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in chunking_by_markdown_header(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
extend_entity_title=self.extend_entity_title,
extend_entity_sn=self.extend_entity_sn,
chunk_header_level=self.chunk_header_level,
)
}
elif self.chunk_type == "markdown_text":
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in chunking_by_markdown_text(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
extend_entity_title=self.extend_entity_title,
extend_entity_sn=self.extend_entity_sn,
)
}
else:
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}

# Update status with chunks information
doc_status.update(
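The lightrag.py changes add chunk_type, chunk_header_level, extend_entity_title, and extend_entity_sn fields to the LightRAG dataclass and dispatch chunking to chunking_by_markdown_header, chunking_by_markdown_text, or chunking_by_token_size accordingly. A hedged usage sketch built on those new fields — the working directory, model function, and entity metadata values are placeholders, not part of this PR:

# Sketch: selecting the markdown-header chunking mode introduced in this PR.
from lightrag import LightRAG
from lightrag.llm import gpt_4o_mini_complete  # any supported completion func

rag = LightRAG(
    working_dir="./dickens",
    llm_model_func=gpt_4o_mini_complete,
    chunk_type="markdown_header",    # or "markdown_text"; default "token_size"
    chunk_header_level=2,            # split on headings up to "##"
    extend_entity_title="Example Standard",   # placeholder main-entity title
    extend_entity_sn="GB/T 0000-2024",        # placeholder main-entity number
)

with open("./dickens/demo.txt", encoding="utf-8") as f:
    rag.insert(f.read())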
6 changes: 4 additions & 2 deletions lightrag/llm.py
@@ -17,6 +17,7 @@
RateLimitError,
APITimeoutError,
AsyncAzureOpenAI,
BadRequestError
)
from pydantic import BaseModel, Field
from tenacity import (
@@ -48,7 +49,7 @@
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
(RateLimitError, APIConnectionError, APITimeoutError, BadRequestError)
),
)
async def openai_complete_if_cache(
@@ -893,6 +894,7 @@ async def openai_embedding(
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
timeout: float = 60,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
@@ -901,7 +903,7 @@
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
model=model, input=texts, encoding_format="float", timeout=timeout
)
return np.array([dp.embedding for dp in response.data])

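The llm.py changes widen the tenacity retry policy to also cover BadRequestError and thread a per-request timeout through to the embeddings call. A minimal sketch combining both, assuming the OpenAI Python SDK v1 and tenacity are installed:

# Sketch of the retry policy from this PR: BadRequestError now triggers a retry
# alongside the transient connection, timeout, and rate-limit errors, and the
# embeddings request gets an explicit per-call timeout.
import numpy as np
from openai import (
    APIConnectionError,
    APITimeoutError,
    AsyncOpenAI,
    BadRequestError,
    RateLimitError,
)
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError, BadRequestError)
    ),
)
async def embed(texts: list[str], timeout: float = 60) -> np.ndarray:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    response = await client.embeddings.create(
        model="text-embedding-3-small",
        input=texts,
        encoding_format="float",
        timeout=timeout,  # the new per-request timeout
    )
    return np.array([dp.embedding for dp in response.data])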