diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml
new file mode 100644
index 00000000..7c12e0a2
--- /dev/null
+++ b/.github/workflows/linting.yaml
@@ -0,0 +1,30 @@
+name: Linting and Formatting
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ lint-and-format:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.x'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install pre-commit
+
+ - name: Run pre-commit
+ run: pre-commit run --all-files
diff --git a/.gitignore b/.gitignore
index 39fa6515..def738b2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,5 @@ dist/
env/
local_neo4jWorkDir/
neo4jWorkDir/
-ignore_this.txt
\ No newline at end of file
+ignore_this.txt
+.venv/
diff --git a/README.md b/README.md
index 33abb13b..96df79ba 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
-
+
@@ -22,11 +22,17 @@ This repository hosts the code of LightRAG. The structure of this code is based
## π News
-- [x] [2024.10.20]π―π―π’π’Weβve added a new feature to LightRAG: Graph Visualization.
-- [x] [2024.10.18]π―π―π’π’Weβve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
-- [x] [2024.10.17]π―π―π’π’We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! ππ
-- [x] [2024.10.16]π―π―π’π’LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
-- [x] [2024.10.15]π―π―π’π’LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+- [x] [2024.10.29]π―π’LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
+- [x] [2024.10.20]π―π’Weβve added a new feature to LightRAG: Graph Visualization.
+- [x] [2024.10.18]π―π’Weβve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
+- [x] [2024.10.17]π―π’We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! ππ
+- [x] [2024.10.16]π―π’LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+- [x] [2024.10.15]π―π’LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+
+## Algorithm Flowchart
+
+![LightRAG_Self excalidraw](https://github.com/user-attachments/assets/aa5c4892-2e44-49e6-a116-2403ed80a1a3)
+
## Install
@@ -58,8 +64,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
-# import nest_asyncio
-# nest_asyncio.apply()
+# import nest_asyncio
+# nest_asyncio.apply()
#########
WORKING_DIR = "./dickens"
@@ -190,8 +196,11 @@ see test_neo4j.py for a working example.
Using Ollama Models
-
-* If you want to use Ollama models, you only need to set LightRAG as follows:
+
+### Overview
+If you want to use Ollama models, you need to pull model you plan to use and embedding model, for example `nomic-embed-text`.
+
+Then you only need to set LightRAG as follows:
```python
from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -213,28 +222,59 @@ rag = LightRAG(
)
```
-* Increasing the `num_ctx` parameter:
+### Increasing context size
+In order for LightRAG to work context should be at least 32k tokens. By default Ollama models have context size of 8k. You can achieve this using one of two ways:
+
+#### Increasing the `num_ctx` parameter in Modelfile.
1. Pull the model:
-```python
+```bash
ollama pull qwen2
```
2. Display the model file:
-```python
+```bash
ollama show --modelfile qwen2 > Modelfile
```
3. Edit the Modelfile by adding the following line:
-```python
+```bash
PARAMETER num_ctx 32768
```
4. Create the modified model:
-```python
+```bash
ollama create -f Modelfile qwen2m
```
+#### Setup `num_ctx` via Ollama API.
+Tiy can use `llm_model_kwargs` param to configure ollama:
+
+```python
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=ollama_model_complete, # Use Ollama model for text generation
+ llm_model_name='your_model_name', # Your model name
+ llm_model_kwargs={"options": {"num_ctx": 32768}},
+ # Use Ollama embedding function
+ embedding_func=EmbeddingFunc(
+ embedding_dim=768,
+ max_token_size=8192,
+ func=lambda texts: ollama_embedding(
+ texts,
+ embed_model="nomic-embed-text"
+ )
+ ),
+)
+```
+#### Fully functional example
+
+There fully functional example `examples/lightrag_ollama_demo.py` that utilizes `gemma2:2b` model, runs only 4 requests in parallel and set context size to 32k.
+
+#### Low RAM GPUs
+
+In order to run this experiment on low RAM GPU you should select small model and tune context window (increasing context increase memory consumption). For example, running this ollama example on repurposed mining GPU with 6Gb of RAM required to set context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations on `book.txt`.
+
### Query Param
@@ -265,12 +305,33 @@ rag.insert(["TEXT1", "TEXT2",...])
```python
# Incremental Insert: Insert new documents into an existing LightRAG instance
-rag = LightRAG(working_dir="./dickens")
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=embedding_dimension,
+ max_token_size=8192,
+ func=embedding_func,
+ ),
+)
with open("./newText.txt") as f:
rag.insert(f.read())
```
+### Multi-file Type Support
+
+The `testract` supports reading file types such as TXT, DOCX, PPTX, CSV, and PDF.
+
+```python
+import textract
+
+file_path = 'TEXT.pdf'
+text_content = textract.process(file_path)
+
+rag.insert(text_content.decode('utf-8'))
+```
+
### Graph Visualization
@@ -361,8 +422,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
- e.displayName = node.id
- REMOVE e:Entity
+ e.displayName = node.id
+ REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -415,7 +476,7 @@ def main():
except Exception as e:
print(f"Error occurred: {e}")
-
+
finally:
driver.close()
@@ -425,6 +486,125 @@ if __name__ == "__main__":
+## API Server Implementation
+
+LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.
+
+### Setting up the API Server
+
+Click to expand setup instructions
+
+1. First, ensure you have the required dependencies:
+```bash
+pip install fastapi uvicorn pydantic
+```
+
+2. Set up your environment variables:
+```bash
+export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
+```
+
+3. Run the API server:
+```bash
+python examples/lightrag_api_openai_compatible_demo.py
+```
+
+The server will start on `http://0.0.0.0:8020`.
+
+
+### API Endpoints
+
+The API server provides the following endpoints:
+
+#### 1. Query Endpoint
+
+Click to view Query endpoint details
+
+- **URL:** `/query`
+- **Method:** POST
+- **Body:**
+```json
+{
+ "query": "Your question here",
+ "mode": "hybrid" // Can be "naive", "local", "global", or "hybrid"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/query" \
+ -H "Content-Type: application/json" \
+ -d '{"query": "What are the main themes?", "mode": "hybrid"}'
+```
+
+
+#### 2. Insert Text Endpoint
+
+Click to view Insert Text endpoint details
+
+- **URL:** `/insert`
+- **Method:** POST
+- **Body:**
+```json
+{
+ "text": "Your text content here"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert" \
+ -H "Content-Type: application/json" \
+ -d '{"text": "Content to be inserted into RAG"}'
+```
+
+
+#### 3. Insert File Endpoint
+
+Click to view Insert File endpoint details
+
+- **URL:** `/insert_file`
+- **Method:** POST
+- **Body:**
+```json
+{
+ "file_path": "path/to/your/file.txt"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert_file" \
+ -H "Content-Type: application/json" \
+ -d '{"file_path": "./book.txt"}'
+```
+
+
+#### 4. Health Check Endpoint
+
+Click to view Health Check endpoint details
+
+- **URL:** `/health`
+- **Method:** GET
+- **Example:**
+```bash
+curl -X GET "http://127.0.0.1:8020/health"
+```
+
+
+### Configuration
+
+The API server can be configured using environment variables:
+- `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
+- API keys and base URLs should be configured in the code for your specific LLM and embedding model providers
+
+### Error Handling
+
+Click to view error handling details
+
+The API includes comprehensive error handling:
+- File not found errors (404)
+- Processing errors (500)
+- Supports multiple file encodings (UTF-8 and GBK)
+
+
## Evaluation
### Dataset
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -671,12 +851,14 @@ def extract_queries(file_path):
.
βββ examples
β βββ batch_eval.py
+β βββ generate_query.py
β βββ graph_visual_with_html.py
β βββ graph_visual_with_neo4j.py
-β βββ generate_query.py
+β βββ lightrag_api_openai_compatible_demo.py
β βββ lightrag_azure_openai_demo.py
β βββ lightrag_bedrock_demo.py
β βββ lightrag_hf_demo.py
+β βββ lightrag_lmdeploy_demo.py
β βββ lightrag_ollama_demo.py
β βββ lightrag_openai_compatible_demo.py
β βββ lightrag_openai_demo.py
@@ -693,8 +875,10 @@ def extract_queries(file_path):
β βββ utils.py
βββ reproduce
β βββ Step_0.py
+β βββ Step_1_openai_compatible.py
β βββ Step_1.py
β βββ Step_2.py
+β βββ Step_3_openai_compatible.py
β βββ Step_3.py
βββ .gitignore
βββ .pre-commit-config.yaml
@@ -726,3 +910,6 @@ archivePrefix={arXiv},
primaryClass={cs.IR}
}
```
+
+
+
diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py
index b455e6de..11279b3a 100644
--- a/examples/graph_visual_with_html.py
+++ b/examples/graph_visual_with_html.py
@@ -3,17 +3,17 @@
import random
# Load the GraphML file
-G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
+G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
# Create a Pyvis network
-net = Network(notebook=True)
+net = Network(height="100vh", notebook=True)
# Convert NetworkX graph to Pyvis network
net.from_nx(G)
# Add colors to nodes
for node in net.nodes:
- node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
+ node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
# Save and display the network
-net.show('knowledge_graph.html')
\ No newline at end of file
+net.show("knowledge_graph.html")
diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py
index 22dde368..7377f21c 100644
--- a/examples/graph_visual_with_neo4j.py
+++ b/examples/graph_visual_with_neo4j.py
@@ -13,6 +13,7 @@
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password"
+
def convert_xml_to_json(xml_path, output_path):
"""Converts XML file to JSON and saves the output."""
if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):
json_data = xml_to_json(xml_path)
if json_data:
- with open(output_path, 'w', encoding='utf-8') as f:
+ with open(output_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"JSON file created: {output_path}")
return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
print("Failed to create JSON data")
return None
+
def process_in_batches(tx, query, data, batch_size):
"""Process data in batches and execute the given query."""
for i in range(0, len(data), batch_size):
- batch = data[i:i + batch_size]
+ batch = data[i : i + batch_size]
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
+
def main():
# Paths
- xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
- json_file = os.path.join(WORKING_DIR, 'graph_data.json')
+ xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
+ json_file = os.path.join(WORKING_DIR, "graph_data.json")
# Convert XML to JSON
json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
return
# Load nodes and edges
- nodes = json_data.get('nodes', [])
- edges = json_data.get('edges', [])
+ nodes = json_data.get("nodes", [])
+ edges = json_data.get("edges", [])
# Neo4j queries
create_nodes_query = """
@@ -56,8 +59,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
- e.displayName = node.id
- REMOVE e:Entity
+ e.displayName = node.id
+ REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -100,19 +103,24 @@ def main():
# Execute queries in batches
with driver.session() as session:
# Insert nodes in batches
- session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
+ session.execute_write(
+ process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
+ )
# Insert edges in batches
- session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
+ session.execute_write(
+ process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
+ )
# Set displayName and labels
session.run(set_displayname_and_labels_query)
except Exception as e:
print(f"Error occurred: {e}")
-
+
finally:
driver.close()
+
if __name__ == "__main__":
main()
diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py
new file mode 100644
index 00000000..2cd262bb
--- /dev/null
+++ b/examples/lightrag_api_openai_compatible_demo.py
@@ -0,0 +1,164 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+from typing import Optional
+import asyncio
+import nest_asyncio
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+DEFAULT_RAG_DIR = "index_default"
+app = FastAPI(title="LightRAG API", description="API for RAG operations")
+
+# Configure working directory
+WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
+print(f"WORKING_DIR: {WORKING_DIR}")
+if not os.path.exists(WORKING_DIR):
+ os.mkdir(WORKING_DIR)
+
+# LLM model function
+
+
+async def llm_model_func(
+ prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ return await openai_complete_if_cache(
+ "gpt-4o-mini",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ api_key="YOUR_API_KEY",
+ base_url="YourURL/v1",
+ **kwargs,
+ )
+
+
+# Embedding function
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+ return await openai_embedding(
+ texts,
+ model="text-embedding-3-large",
+ api_key="YOUR_API_KEY",
+ base_url="YourURL/v1",
+ )
+
+
+# Initialize RAG instance
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=3072, max_token_size=8192, func=embedding_func
+ ),
+)
+
+# Data models
+
+
+class QueryRequest(BaseModel):
+ query: str
+ mode: str = "hybrid"
+
+
+class InsertRequest(BaseModel):
+ text: str
+
+
+class InsertFileRequest(BaseModel):
+ file_path: str
+
+
+class Response(BaseModel):
+ status: str
+ data: Optional[str] = None
+ message: Optional[str] = None
+
+
+# API routes
+
+
+@app.post("/query", response_model=Response)
+async def query_endpoint(request: QueryRequest):
+ try:
+ loop = asyncio.get_event_loop()
+ result = await loop.run_in_executor(
+ None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
+ )
+ return Response(status="success", data=result)
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/insert", response_model=Response)
+async def insert_endpoint(request: InsertRequest):
+ try:
+ loop = asyncio.get_event_loop()
+ await loop.run_in_executor(None, lambda: rag.insert(request.text))
+ return Response(status="success", message="Text inserted successfully")
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/insert_file", response_model=Response)
+async def insert_file(request: InsertFileRequest):
+ try:
+ # Check if file exists
+ if not os.path.exists(request.file_path):
+ raise HTTPException(
+ status_code=404, detail=f"File not found: {request.file_path}"
+ )
+
+ # Read file content
+ try:
+ with open(request.file_path, "r", encoding="utf-8") as f:
+ content = f.read()
+ except UnicodeDecodeError:
+ # If UTF-8 decoding fails, try other encodings
+ with open(request.file_path, "r", encoding="gbk") as f:
+ content = f.read()
+
+ # Insert file content
+ loop = asyncio.get_event_loop()
+ await loop.run_in_executor(None, lambda: rag.insert(content))
+
+ return Response(
+ status="success",
+ message=f"File content from {request.file_path} inserted successfully",
+ )
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/health")
+async def health_check():
+ return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+ import uvicorn
+
+ uvicorn.run(app, host="0.0.0.0", port=8020)
+
+# Usage example
+# To run the server, use the following command in your terminal:
+# python lightrag_api_openai_compatible_demo.py
+
+# Example requests:
+# 1. Query:
+# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
+
+# 2. Insert text:
+# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
+
+# 3. Insert file:
+# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
+
+# 4. Health check:
+# curl -X GET "http://127.0.0.1:8020/health"
diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py
new file mode 100644
index 00000000..aeb96f71
--- /dev/null
+++ b/examples/lightrag_lmdeploy_demo.py
@@ -0,0 +1,75 @@
+import os
+
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
+from lightrag.utils import EmbeddingFunc
+from transformers import AutoModel, AutoTokenizer
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+ os.mkdir(WORKING_DIR)
+
+
+async def lmdeploy_model_complete(
+ prompt=None, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+ return await lmdeploy_model_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ ## please specify chat_template if your local path does not follow original HF file name,
+ ## or model_name is a pytorch model on huggingface.co,
+ ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
+ ## for a list of chat_template available in lmdeploy.
+ chat_template="llama3",
+ # model_format ='awq', # if you are using awq quantization model.
+ # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8.
+ **kwargs,
+ )
+
+
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=lmdeploy_model_complete,
+ llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model
+ embedding_func=EmbeddingFunc(
+ embedding_dim=384,
+ max_token_size=5000,
+ func=lambda texts: hf_embedding(
+ texts,
+ tokenizer=AutoTokenizer.from_pretrained(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ ),
+ embed_model=AutoModel.from_pretrained(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ ),
+ ),
+ ),
+)
+
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+ rag.insert(f.read())
+
+# Perform naive search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py
index 98f1521c..1a320d13 100644
--- a/examples/lightrag_ollama_demo.py
+++ b/examples/lightrag_ollama_demo.py
@@ -1,26 +1,32 @@
import os
-
+import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens"
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete,
- llm_model_name="your_model_name",
+ llm_model_name="gemma2:2b",
+ llm_model_max_async=4,
+ llm_model_max_token_size=32768,
+ llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
- func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
+ func=lambda texts: ollama_embedding(
+ texts, embed_model="nomic-embed-text", host="http://localhost:11434"
+ ),
),
)
-
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py
index aae56821..1422e2c2 100644
--- a/examples/lightrag_openai_compatible_demo.py
+++ b/examples/lightrag_openai_compatible_demo.py
@@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
)
+async def get_embedding_dim():
+ test_text = ["This is a test sentence."]
+ embedding = await embedding_func(test_text)
+ embedding_dim = embedding.shape[1]
+ return embedding_dim
+
+
# function test
async def test_funcs():
result = await llm_model_func("How are you?")
@@ -43,37 +50,59 @@ async def test_funcs():
print("embedding_func: ", result)
-asyncio.run(test_funcs())
-
-
-rag = LightRAG(
- working_dir=WORKING_DIR,
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=4096, max_token_size=8192, func=embedding_func
- ),
-)
-
-
-with open("./book.txt", "r", encoding="utf-8") as f:
- rag.insert(f.read())
-
-# Perform naive search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
-)
-
-# Perform local search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
-)
-
-# Perform global search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
-)
-
-# Perform hybrid search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
-)
+# asyncio.run(test_funcs())
+
+
+async def main():
+ try:
+ embedding_dimension = await get_embedding_dim()
+ print(f"Detected embedding dimension: {embedding_dimension}")
+
+ rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=embedding_dimension,
+ max_token_size=8192,
+ func=embedding_func,
+ ),
+ )
+
+ with open("./book.txt", "r", encoding="utf-8") as f:
+ await rag.ainsert(f.read())
+
+ # Perform naive search
+ print(
+ await rag.aquery(
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
+ )
+ )
+
+ # Perform local search
+ print(
+ await rag.aquery(
+ "What are the top themes in this story?", param=QueryParam(mode="local")
+ )
+ )
+
+ # Perform global search
+ print(
+ await rag.aquery(
+ "What are the top themes in this story?",
+ param=QueryParam(mode="global"),
+ )
+ )
+
+ # Perform hybrid search
+ print(
+ await rag.aquery(
+ "What are the top themes in this story?",
+ param=QueryParam(mode="hybrid"),
+ )
+ )
+ except Exception as e:
+ print(f"An error occurred: {e}")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py
index 82cab228..a73f16c5 100644
--- a/examples/lightrag_siliconcloud_demo.py
+++ b/examples/lightrag_siliconcloud_demo.py
@@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
texts,
model="netease-youdao/bce-embedding-base_v1",
api_key=os.getenv("SILICONFLOW_API_KEY"),
- max_token_size=512
+ max_token_size=512,
)
diff --git a/examples/vram_management_demo.py b/examples/vram_management_demo.py
index ec750254..c173b913 100644
--- a/examples/vram_management_demo.py
+++ b/examples/vram_management_demo.py
@@ -27,11 +27,12 @@
# Read all .txt files from the TEXT_FILES_DIR directory
texts = []
for filename in os.listdir(TEXT_FILES_DIR):
- if filename.endswith('.txt'):
+ if filename.endswith(".txt"):
file_path = os.path.join(TEXT_FILES_DIR, filename)
- with open(file_path, 'r', encoding='utf-8') as file:
+ with open(file_path, "r", encoding="utf-8") as file:
texts.append(file.read())
+
# Batch insert texts into LightRAG with a retry mechanism
def insert_texts_with_retry(rag, texts, retries=3, delay=5):
for _ in range(retries):
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
rag.insert(texts)
return
except Exception as e:
- print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
+ print(
+ f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
+ )
time.sleep(delay)
raise RuntimeError("Failed to insert texts after multiple retries.")
+
insert_texts_with_retry(rag, texts)
# Perform different types of queries and handle potential errors
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
+ )
+ )
except Exception as e:
print(f"Error performing naive search: {e}")
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="local")
+ )
+ )
except Exception as e:
print(f"Error performing local search: {e}")
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="global")
+ )
+ )
except Exception as e:
print(f"Error performing global search: {e}")
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+ )
+ )
except Exception as e:
print(f"Error performing hybrid search: {e}")
+
# Function to clear VRAM resources
def clear_vram():
os.system("sudo nvidia-smi --gpu-reset")
+
# Regularly clear VRAM to prevent overflow
clear_vram_interval = 3600 # Clear once every hour
start_time = time.time()
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index db81e005..8e76a260 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
-__version__ = "0.0.7"
+__version__ = "0.0.8"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index e3e7cce1..a42b806e 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -109,6 +109,7 @@ class LightRAG:
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
+ llm_model_kwargs: dict = field(default_factory=dict)
# storage
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
@@ -179,7 +180,11 @@ def __post_init__(self):
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
- partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
+ partial(
+ self.llm_model_func,
+ hashing_kv=self.llm_response_cache,
+ **self.llm_model_kwargs,
+ )
)
def _get_storage_class(self) -> Type[BaseGraphStorage]:
return {
@@ -239,7 +244,7 @@ async def ainsert(self, string_or_strings):
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
- knwoledge_graph_inst=self.chunk_entity_relation_graph,
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 4dcf535c..f4045e80 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -7,7 +7,13 @@
import numpy as np
import ollama
-from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
+from openai import (
+ AsyncOpenAI,
+ APIConnectionError,
+ RateLimitError,
+ Timeout,
+ AsyncAzureOpenAI,
+)
import base64
import struct
@@ -70,26 +76,31 @@ async def openai_complete_if_cache(
)
return response.choices[0].message.content
+
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
-async def azure_openai_complete_if_cache(model,
+async def azure_openai_complete_if_cache(
+ model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
- **kwargs):
+ **kwargs,
+):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
- openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
+ openai_async_client = AsyncAzureOpenAI(
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+ )
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
@@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model,
)
return response.choices[0].message.content
+
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@@ -205,8 +217,12 @@ async def bedrock_complete_if_cache(
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
- hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+ hf_tokenizer = AutoTokenizer.from_pretrained(
+ model_name, device_map="auto", trust_remote_code=True
+ )
+ hf_model = AutoModelForCausalLM.from_pretrained(
+ model_name, device_map="auto", trust_remote_code=True
+ )
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
@@ -266,10 +282,13 @@ async def hf_model_if_cache(
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
+ inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
- **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
+ **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
+ )
+ response_text = hf_tokenizer.decode(
+ output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
- response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
@@ -280,8 +299,10 @@ async def ollama_model_if_cache(
) -> str:
kwargs.pop("max_tokens", None)
kwargs.pop("response_format", None)
+ host = kwargs.pop("host", None)
+ timeout = kwargs.pop("timeout", None)
- ollama_client = ollama.AsyncClient()
+ ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
@@ -305,6 +326,135 @@ async def ollama_model_if_cache(
return result
+@lru_cache(maxsize=1)
+def initialize_lmdeploy_pipeline(
+ model,
+ tp=1,
+ chat_template=None,
+ log_level="WARNING",
+ model_format="hf",
+ quant_policy=0,
+):
+ from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
+
+ lmdeploy_pipe = pipeline(
+ model_path=model,
+ backend_config=TurbomindEngineConfig(
+ tp=tp, model_format=model_format, quant_policy=quant_policy
+ ),
+ chat_template_config=ChatTemplateConfig(model_name=chat_template)
+ if chat_template
+ else None,
+ log_level="WARNING",
+ )
+ return lmdeploy_pipe
+
+
+async def lmdeploy_model_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ chat_template=None,
+ model_format="hf",
+ quant_policy=0,
+ **kwargs,
+) -> str:
+ """
+ Args:
+ model (str): The path to the model.
+ It could be one of the following options:
+ - i) A local directory path of a turbomind model which is
+ converted by `lmdeploy convert` command or download
+ from ii) and iii).
+ - ii) The model_id of a lmdeploy-quantized model hosted
+ inside a model repo on huggingface.co, such as
+ "InternLM/internlm-chat-20b-4bit",
+ "lmdeploy/llama2-chat-70b-4bit", etc.
+ - iii) The model_id of a model hosted inside a model repo
+ on huggingface.co, such as "internlm/internlm-chat-7b",
+ "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
+ and so on.
+ chat_template (str): needed when model is a pytorch model on
+ huggingface.co, such as "internlm-chat-7b",
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
+ and when the model name of local path did not match the original model name in HF.
+ tp (int): tensor parallel
+ prompt (Union[str, List[str]]): input texts to be completed.
+ do_preprocess (bool): whether pre-process the messages. Default to
+ True, which means chat_template will be applied.
+ skip_special_tokens (bool): Whether or not to remove special tokens
+ in the decoding. Default to be True.
+ do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
+ Default to be False, which means greedy decoding will be applied.
+ """
+ try:
+ import lmdeploy
+ from lmdeploy import version_info, GenerationConfig
+ except Exception:
+ raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
+
+ kwargs.pop("response_format", None)
+ max_new_tokens = kwargs.pop("max_tokens", 512)
+ tp = kwargs.pop("tp", 1)
+ skip_special_tokens = kwargs.pop("skip_special_tokens", True)
+ do_preprocess = kwargs.pop("do_preprocess", True)
+ do_sample = kwargs.pop("do_sample", False)
+ gen_params = kwargs
+
+ version = version_info
+ if do_sample is not None and version < (0, 6, 0):
+ raise RuntimeError(
+ "`do_sample` parameter is not supported by lmdeploy until "
+ f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
+ )
+ else:
+ do_sample = True
+ gen_params.update(do_sample=do_sample)
+
+ lmdeploy_pipe = initialize_lmdeploy_pipeline(
+ model=model,
+ tp=tp,
+ chat_template=chat_template,
+ model_format=model_format,
+ quant_policy=quant_policy,
+ log_level="WARNING",
+ )
+
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+ if hashing_kv is not None:
+ args_hash = compute_args_hash(model, messages)
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
+ if if_cache_return is not None:
+ return if_cache_return["return"]
+
+ gen_config = GenerationConfig(
+ skip_special_tokens=skip_special_tokens,
+ max_new_tokens=max_new_tokens,
+ **gen_params,
+ )
+
+ response = ""
+ async for res in lmdeploy_pipe.generate(
+ messages,
+ gen_config=gen_config,
+ do_preprocess=do_preprocess,
+ stream_response=False,
+ session_id=1,
+ ):
+ response += res.response
+
+ if hashing_kv is not None:
+ await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
+ return response
+
+
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -328,8 +478,9 @@ async def gpt_4o_mini_complete(
**kwargs,
)
+
async def azure_openai_complete(
- prompt, system_prompt=None, history_messages=[], **kwargs
+ prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"conversation-4o-mini",
@@ -339,6 +490,7 @@ async def azure_openai_complete(
**kwargs,
)
+
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -418,9 +570,11 @@ async def azure_openai_embedding(
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
- openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
+ openai_async_client = AsyncAzureOpenAI(
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+ )
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
@@ -440,35 +594,28 @@ async def siliconcloud_embedding(
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
- if api_key and not api_key.startswith('Bearer '):
- api_key = 'Bearer ' + api_key
+ if api_key and not api_key.startswith("Bearer "):
+ api_key = "Bearer " + api_key
- headers = {
- "Authorization": api_key,
- "Content-Type": "application/json"
- }
+ headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
- payload = {
- "model": model,
- "input": truncate_texts,
- "encoding_format": "base64"
- }
+ payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
- if 'code' in content:
+ if "code" in content:
raise ValueError(content)
- base64_strings = [item['embedding'] for item in content['data']]
-
+ base64_strings = [item["embedding"] for item in content["data"]]
+
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
- float_array = struct.unpack('<' + 'f' * n, decode_bytes)
+ float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)
@@ -555,14 +702,16 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
return embeddings.detach().numpy()
-async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
+async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
embed_text = []
+ ollama_client = ollama.Client(**kwargs)
for text in texts:
- data = ollama.embeddings(model=embed_model, prompt=text)
+ data = ollama_client.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"])
return embed_text
+
class Model(BaseModel):
"""
This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -580,14 +729,20 @@ class Model(BaseModel):
The 'kwargs' dictionary contains the model name and API key to be passed to the function.
"""
- gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string")
- kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc")
+ gen_func: Callable[[Any], str] = Field(
+ ...,
+ description="A function that generates the response from the llm. The response must be a string",
+ )
+ kwargs: Dict[str, Any] = Field(
+ ...,
+ description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
+ )
class Config:
arbitrary_types_allowed = True
-class MultiModel():
+class MultiModel:
"""
Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
Could also be used for spliting across diffrent models or providers.
@@ -611,26 +766,31 @@ class MultiModel():
)
```
"""
+
def __init__(self, models: List[Model]):
self._models = models
self._current_model = 0
-
+
def _next_model(self):
self._current_model = (self._current_model + 1) % len(self._models)
return self._models[self._current_model]
async def llm_model_func(
- self,
- prompt, system_prompt=None, history_messages=[], **kwargs
+ self, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
- kwargs.pop("model", None) # stop from overwriting the custom model name
+ kwargs.pop("model", None) # stop from overwriting the custom model name
next_model = self._next_model()
- args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
-
- return await next_model.gen_func(
- **args
+ args = dict(
+ prompt=prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ **next_model.kwargs,
)
+ return await next_model.gen_func(**args)
+
+
if __name__ == "__main__":
import asyncio
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 14dccaf3..6b6ba563 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_entitiy_types = []
already_source_ids = []
already_description = []
- already_node = await knwoledge_graph_inst.get_node(entity_name)
+ already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node is not None:
already_entitiy_types.append(already_node["entity_type"])
already_source_ids.extend(
@@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert(
description=description,
source_id=source_id,
)
- await knwoledge_graph_inst.upsert_node(
+ await knowledge_graph_inst.upsert_node(
entity_name,
node_data=node_data,
)
@@ -172,7 +172,7 @@ async def _merge_edges_then_upsert(
src_id: str,
tgt_id: str,
edges_data: list[dict],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_weights = []
@@ -180,8 +180,8 @@ async def _merge_edges_then_upsert(
already_description = []
already_keywords = []
- if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
- already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
+ if await knowledge_graph_inst.has_edge(src_id, tgt_id):
+ already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"])
already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
@@ -202,8 +202,8 @@ async def _merge_edges_then_upsert(
set([dp["source_id"] for dp in edges_data] + already_source_ids)
)
for need_insert_id in [src_id, tgt_id]:
- if not (await knwoledge_graph_inst.has_node(need_insert_id)):
- await knwoledge_graph_inst.upsert_node(
+ if not (await knowledge_graph_inst.has_node(need_insert_id)):
+ await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
@@ -214,7 +214,7 @@ async def _merge_edges_then_upsert(
description = await _handle_entity_relation_summary(
(src_id, tgt_id), description, global_config
)
- await knwoledge_graph_inst.upsert_edge(
+ await knowledge_graph_inst.upsert_edge(
src_id,
tgt_id,
edge_data=dict(
@@ -237,7 +237,7 @@ async def _merge_edges_then_upsert(
async def extract_entities(
chunks: dict[str, TextChunkSchema],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
@@ -341,13 +341,13 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
maybe_edges[tuple(sorted(k))].extend(v)
all_entities_data = await asyncio.gather(
*[
- _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
+ _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
all_relationships_data = await asyncio.gather(
*[
- _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
+ _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
for k, v in maybe_edges.items()
]
)
@@ -384,7 +384,7 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
}
await relationships_vdb.upsert(data_for_vdb)
- return knwoledge_graph_inst
+ return knowledge_graph_inst
async def local_query(
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9a68c16b..0da4a51a 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
+
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
- data = {
- "nodes": [],
- "edges": []
- }
+ data = {"nodes": [], "edges": []}
# Use namespace
- namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
+ namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
- for node in root.findall('.//node', namespace):
+ for node in root.findall(".//node", namespace):
node_data = {
- "id": node.get('id').strip('"'),
- "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
- "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
- "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
+ "id": node.get("id").strip('"'),
+ "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
+ if node.find("./data[@key='d0']", namespace) is not None
+ else "",
+ "description": node.find("./data[@key='d1']", namespace).text
+ if node.find("./data[@key='d1']", namespace) is not None
+ else "",
+ "source_id": node.find("./data[@key='d2']", namespace).text
+ if node.find("./data[@key='d2']", namespace) is not None
+ else "",
}
data["nodes"].append(node_data)
- for edge in root.findall('.//edge', namespace):
+ for edge in root.findall(".//edge", namespace):
edge_data = {
- "source": edge.get('source').strip('"'),
- "target": edge.get('target').strip('"'),
- "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
- "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
- "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
- "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
+ "source": edge.get("source").strip('"'),
+ "target": edge.get("target").strip('"'),
+ "weight": float(edge.find("./data[@key='d3']", namespace).text)
+ if edge.find("./data[@key='d3']", namespace) is not None
+ else 0.0,
+ "description": edge.find("./data[@key='d4']", namespace).text
+ if edge.find("./data[@key='d4']", namespace) is not None
+ else "",
+ "keywords": edge.find("./data[@key='d5']", namespace).text
+ if edge.find("./data[@key='d5']", namespace) is not None
+ else "",
+ "source_id": edge.find("./data[@key='d6']", namespace).text
+ if edge.find("./data[@key='d6']", namespace) is not None
+ else "",
}
data["edges"].append(edge_data)
diff --git a/reproduce/Step_3.py b/reproduce/Step_3.py
index a56190fc..2c5d699c 100644
--- a/reproduce/Step_3.py
+++ b/reproduce/Step_3.py
@@ -18,8 +18,8 @@ def extract_queries(file_path):
async def process_query(query_text, rag_instance, query_param):
try:
- result, context = await rag_instance.aquery(query_text, param=query_param)
- return {"query": query_text, "result": result, "context": context}, None
+ result = await rag_instance.aquery(query_text, param=query_param)
+ return {"query": query_text, "result": result}, None
except Exception as e:
return None, {"query": query_text, "error": str(e)}
diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py
index 2be5ea5c..5e2ef778 100644
--- a/reproduce/Step_3_openai_compatible.py
+++ b/reproduce/Step_3_openai_compatible.py
@@ -50,8 +50,8 @@ def extract_queries(file_path):
async def process_query(query_text, rag_instance, query_param):
try:
- result, context = await rag_instance.aquery(query_text, param=query_param)
- return {"query": query_text, "result": result, "context": context}, None
+ result = await rag_instance.aquery(query_text, param=query_param)
+ return {"query": query_text, "result": result}, None
except Exception as e:
return None, {"query": query_text, "error": str(e)}
diff --git a/requirements.txt b/requirements.txt
index 897c53f8..8620fe10 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,16 +1,17 @@
accelerate
aioboto3
+aiohttp
graspologic
hnswlib
nano-vectordb
+neo4j
networkx
ollama
openai
+pyvis
tenacity
tiktoken
torch
transformers
xxhash
-pyvis
-aiohttp
-neo4j
+# lmdeploy[all]
diff --git a/setup.py b/setup.py
index 47222420..1b1f65f0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,39 +1,88 @@
import setuptools
+from pathlib import Path
-with open("README.md", "r", encoding="utf-8") as fh:
- long_description = fh.read()
+# Reading the long description from README.md
+def read_long_description():
+ try:
+ return Path("README.md").read_text(encoding="utf-8")
+ except FileNotFoundError:
+ return "A description of LightRAG is currently unavailable."
-vars2find = ["__author__", "__version__", "__url__"]
-vars2readme = {}
-with open("./lightrag/__init__.py") as f:
- for line in f.readlines():
- for v in vars2find:
- if line.startswith(v):
- line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
- vars2readme[v] = line.split("=")[1]
-deps = []
-with open("./requirements.txt") as f:
- for line in f.readlines():
- if not line.strip():
- continue
- deps.append(line.strip())
+# Retrieving metadata from __init__.py
+def retrieve_metadata():
+ vars2find = ["__author__", "__version__", "__url__"]
+ vars2readme = {}
+ try:
+ with open("./lightrag/__init__.py") as f:
+ for line in f.readlines():
+ for v in vars2find:
+ if line.startswith(v):
+ line = (
+ line.replace(" ", "")
+ .replace('"', "")
+ .replace("'", "")
+ .strip()
+ )
+ vars2readme[v] = line.split("=")[1]
+ except FileNotFoundError:
+ raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.")
+
+ # Checking if all required variables are found
+ missing_vars = [v for v in vars2find if v not in vars2readme]
+ if missing_vars:
+ raise ValueError(
+ f"Missing required metadata variables in __init__.py: {missing_vars}"
+ )
+
+ return vars2readme
+
+
+# Reading dependencies from requirements.txt
+def read_requirements():
+ deps = []
+ try:
+ with open("./requirements.txt") as f:
+ deps = [line.strip() for line in f if line.strip()]
+ except FileNotFoundError:
+ print(
+ "Warning: 'requirements.txt' not found. No dependencies will be installed."
+ )
+ return deps
+
+
+metadata = retrieve_metadata()
+long_description = read_long_description()
+requirements = read_requirements()
setuptools.setup(
name="lightrag-hku",
- url=vars2readme["__url__"],
- version=vars2readme["__version__"],
- author=vars2readme["__author__"],
+ url=metadata["__url__"],
+ version=metadata["__version__"],
+ author=metadata["__author__"],
description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
long_description=long_description,
long_description_content_type="text/markdown",
- packages=["lightrag"],
+ packages=setuptools.find_packages(
+ exclude=("tests*", "docs*")
+ ), # Automatically find packages
classifiers=[
+ "Development Status :: 4 - Beta",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
+ "Intended Audience :: Developers",
+ "Topic :: Software Development :: Libraries :: Python Modules",
],
python_requires=">=3.9",
- install_requires=deps,
+ install_requires=requirements,
+ include_package_data=True, # Includes non-code files from MANIFEST.in
+ project_urls={ # Additional project metadata
+ "Documentation": metadata.get("__url__", ""),
+ "Source": metadata.get("__url__", ""),
+ "Tracker": f"{metadata.get('__url__', '')}/issues"
+ if metadata.get("__url__")
+ else "",
+ },
)