feat: add Neo4j graph backend (#61)
* refactor: make _storage a folder

* feat: add neo4j backend

* fix: remove test coverage for neo4j

* refactor: dspy extraction

* docs: update neo4j

* fix: neo4j return clusters in node_data

* tests: fix test that was wrong about clusters in node data

* improve coverage of llm
gusye1234 authored Sep 25, 2024
1 parent b33b2b8 commit 9ad71cf
Showing 26 changed files with 1,111 additions and 360 deletions.
5 changes: 4 additions & 1 deletion .coveragerc
@@ -5,4 +5,7 @@ exclude_lines =

    # Don't complain if tests don't hit defensive assertion code:
    raise NotImplementedError
    logger.

omit =
    # Don't have a nice github action for neo4j now, so skip this file:
    nano_graphrag/_storage/gdb_neo4j.py
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
@@ -42,6 +42,8 @@ jobs:
        run: |
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
      - name: Build and Test
        env:
          NANO_GRAPHRAG_TEST_IGNORE_NEO4J: true
        run: |
          python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./
      - name: Check codecov file
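The `NANO_GRAPHRAG_TEST_IGNORE_NEO4J` flag lets CI run the suite without a Neo4j service. The commit doesn't show the test-side guard, so here is only a minimal sketch of how a test module can honor such a flag with pytest (the skip message is illustrative):

```python
import os

import pytest

# Skip every test in this module when CI sets the flag (no Neo4j service there).
if os.environ.get("NANO_GRAPHRAG_TEST_IGNORE_NEO4J"):
    pytest.skip("Neo4j is not available in this environment", allow_module_level=True)
```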
3 changes: 1 addition & 2 deletions .gitignore
@@ -1,8 +1,7 @@
# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python
test_cache.json
run_test.py
run_test_zh.py
run_test*.py
nano_graphrag_cache*/
*.txt
examples/benchmarks/fixtures/
27 changes: 27 additions & 0 deletions docs/use_neo4j_for_graphrag.md
@@ -0,0 +1,27 @@
1. Install [Neo4j](https://neo4j.com/docs/operations-manual/current/installation/)
2. Install the Neo4j GDS (Graph Data Science) [plugin](https://neo4j.com/docs/graph-data-science/current/installation/neo4j-server/)
3. Start the Neo4j server
4. Get the `NEO4J_URL`, `NEO4J_USER`, and `NEO4J_PASSWORD`
   - By default, `NEO4J_URL` is `neo4j://localhost:7687`, `NEO4J_USER` is `neo4j`, and `NEO4J_PASSWORD` is `neo4j`

Pass your Neo4j connection settings to `GraphRAG`:

```python
import os

from nano_graphrag import GraphRAG
from nano_graphrag._storage import Neo4jStorage

neo4j_config = {
    "neo4j_url": os.environ.get("NEO4J_URL", "neo4j://localhost:7687"),
    "neo4j_auth": (
        os.environ.get("NEO4J_USER", "neo4j"),
        os.environ.get("NEO4J_PASSWORD", "neo4j"),
    ),
}
GraphRAG(
    graph_storage_cls=Neo4jStorage,
    addon_params=neo4j_config,
)
```
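Once constructed, the instance is used like any other `GraphRAG` object through its `insert` and `query` entry points. A minimal usage sketch, assuming the server from step 3 is reachable (the sample text and question are illustrative):

```python
graph_func = GraphRAG(
    graph_storage_cls=Neo4jStorage,
    addon_params=neo4j_config,
)
graph_func.insert("<a long document; its entities and relations are stored in Neo4j>")
print(graph_func.query("What are the top themes in this document?"))
```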



1 change: 1 addition & 0 deletions examples/no_openai_key_at_all.py
@@ -34,6 +34,7 @@ async def ollama_model_if_cache(
) -> str:
    # remove kwargs that are not supported by ollama
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)

    ollama_client = ollama.AsyncClient()
    messages = []
1 change: 1 addition & 0 deletions examples/using_ollama_as_llm.py
@@ -18,6 +18,7 @@ async def ollama_model_if_cache(
) -> str:
    # remove kwargs that are not supported by ollama
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)

    ollama_client = ollama.AsyncClient()
    messages = []
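Both Ollama examples pop `max_tokens` and `response_format` because Ollama's chat endpoint does not accept these OpenAI-specific options. A hedged sketch of the surrounding pattern, assuming the `ollama` Python client and a locally pulled model (the model name here is an assumption):

```python
import ollama

async def ollama_chat(prompt: str, model: str = "llama3", **kwargs) -> str:
    # Drop OpenAI-only options before forwarding the rest to Ollama.
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)
    client = ollama.AsyncClient()
    response = await client.chat(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        **kwargs,
    )
    return response["message"]["content"]
```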
17 changes: 10 additions & 7 deletions examples/using_ollama_as_llm_and_embedding.py
@@ -20,11 +20,13 @@
EMBEDDING_MODEL_DIM = 768
EMBEDDING_MODEL_MAX_TOKENS = 8192


async def ollama_model_if_cache(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    # remove kwargs that are not supported by ollama
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)

    ollama_client = ollama.AsyncClient()
    messages = []
@@ -98,20 +98,21 @@ def insert():
# rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
# rag.insert(FAKE_TEXT[half_len:])


# We're using Ollama to generate embeddings for the BGE model
@wrap_embedding_func_with_attrs(
-    embedding_dim= EMBEDDING_MODEL_DIM,
-    max_token_size= EMBEDDING_MODEL_MAX_TOKENS,
+    embedding_dim=EMBEDDING_MODEL_DIM,
+    max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
)
-async def ollama_embedding(texts :list[str]) -> np.ndarray:
+async def ollama_embedding(texts: list[str]) -> np.ndarray:
    embed_text = []
    for text in texts:
        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
        embed_text.append(data["embedding"])

    return embed_text


if __name__ == "__main__":
    insert()
    query()
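A quick way to sanity-check the embedding wrapper above: each returned vector should have `EMBEDDING_MODEL_DIM` (768) entries. A minimal sketch, assuming an Ollama server with the example's embedding model already pulled:

```python
import asyncio

async def main():
    # ollama_embedding is the wrapper defined earlier in this example file.
    vectors = await ollama_embedding(["nano-graphrag stores this text as a vector"])
    print(len(vectors), len(vectors[0]))  # expect: 1 768

asyncio.run(main())
```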
29 changes: 21 additions & 8 deletions nano_graphrag/_llm.py
@@ -13,6 +13,23 @@
from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
from .base import BaseKVStorage

global_openai_async_client = None
global_azure_openai_async_client = None


def get_openai_async_client_instance():
    global global_openai_async_client
    if global_openai_async_client is None:
        global_openai_async_client = AsyncOpenAI()
    return global_openai_async_client


def get_azure_openai_async_client_instance():
    global global_azure_openai_async_client
    if global_azure_openai_async_client is None:
        global_azure_openai_async_client = AsyncAzureOpenAI()
    return global_azure_openai_async_client


@retry(
stop=stop_after_attempt(5),
@@ -22,7 +39,7 @@
async def openai_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
-    openai_async_client = AsyncOpenAI()
+    openai_async_client = get_openai_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
@@ -78,7 +95,7 @@ async def gpt_4o_mini_complete(
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
-    openai_async_client = AsyncOpenAI()
+    openai_async_client = get_openai_async_client_instance()
    response = await openai_async_client.embeddings.create(
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
@@ -93,7 +110,7 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
async def azure_openai_complete_if_cache(
    deployment_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
-    azure_openai_client = AsyncAzureOpenAI()
+    azure_openai_client = get_azure_openai_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
@@ -154,11 +171,7 @@ async def azure_gpt_4o_mini_complete(
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
-    azure_openai_client = AsyncAzureOpenAI(
-        api_key=os.environ.get("API_KEY_EMB"),
-        api_version=os.environ.get("API_VERSION_EMB"),
-        azure_endpoint=os.environ.get("AZURE_ENDPOINT_EMB"),
-    )
+    azure_openai_client = get_azure_openai_async_client_instance()
    response = await azure_openai_client.embeddings.create(
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
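With the getters above, every completion and embedding call in `_llm.py` now reuses one shared client (and its HTTP connection pool) instead of constructing a fresh `AsyncOpenAI`/`AsyncAzureOpenAI` per request; the Azure embedding path also drops its separate `API_KEY_EMB`-style credentials in favor of the client's standard environment variables. The pattern in isolation, as a small self-contained sketch (assumes `OPENAI_API_KEY` is set, since the client reads it at construction):

```python
from openai import AsyncOpenAI

_client = None

def get_client() -> AsyncOpenAI:
    # Create the client lazily on first use; every later call reuses the same instance.
    global _client
    if _client is None:
        _client = AsyncOpenAI()
    return _client

assert get_client() is get_client()  # one client (and connection pool) per process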
5 changes: 5 additions & 0 deletions nano_graphrag/_storage/__init__.py
@@ -0,0 +1,5 @@
from .gdb_networkx import NetworkXStorage
from .gdb_neo4j import Neo4jStorage
from .vdb_hnswlib import HNSWVectorStorage
from .vdb_nanovectordb import NanoVectorDBStorage
from .kv_json import JsonKVStorage
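After the `_storage` refactor, all backends import from this single package. A hedged sketch of selecting the graph backend at construction time (treating `NetworkXStorage` as the file-backed default, which is an assumption based on the examples above):

```python
from nano_graphrag import GraphRAG
from nano_graphrag._storage import NetworkXStorage, Neo4jStorage

# File-backed NetworkX graph:
rag = GraphRAG(graph_storage_cls=NetworkXStorage)

# Neo4j-backed graph; needs the addon_params shown in docs/use_neo4j_for_graphrag.md:
# rag = GraphRAG(graph_storage_cls=Neo4jStorage, addon_params={...})
```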