Skip to content

Commit

Permalink
Merge pull request #393 from partoneplay/main
Browse files Browse the repository at this point in the history
Add Milvus as vector storage
  • Loading branch information
LarFii authored Dec 5, 2024
2 parents dda91a7 + cfb8ca6 commit c352eb6
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 0 deletions.
51 changes: 51 additions & 0 deletions examples/lightrag_ollama_neo4j_milvus_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc

# Demo: build a LightRAG knowledge graph with Ollama as the LLM/embedding
# backend, Neo4j as graph storage and Milvus as vector storage, then run a
# single hybrid query against it.

# WorkingDir: a "myKG" folder next to this script.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
WORKING_DIR = os.path.join(ROOT_DIR, "myKG")
# exist_ok avoids the check-then-create race of `exists()` + `mkdir()`.
os.makedirs(WORKING_DIR, exist_ok=True)
print(f"WorkingDir: {WORKING_DIR}")

# neo4j
BATCH_SIZE_NODES = 500
BATCH_SIZE_EDGES = 100
os.environ["NEO4J_URI"] = "bolt://localhost:7687"
os.environ["NEO4J_USERNAME"] = "neo4j"
os.environ["NEO4J_PASSWORD"] = "neo4j"

# milvus
os.environ["MILVUS_URI"] = "http://localhost:19530"
os.environ["MILVUS_USER"] = "root"
os.environ["MILVUS_PASSWORD"] = "root"
os.environ["MILVUS_DB_NAME"] = "lightrag"


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2.5:14b",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"host": "http://127.0.0.1:11434", "options": {"num_ctx": 32768}},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: ollama_embed(
            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
        ),
    ),
    graph_storage="Neo4JStorage",
    vector_storage="MilvusVectorDBStorge",
)

file = "./book.txt"
# Read as UTF-8 explicitly so the result does not depend on the platform's
# default text encoding.
with open(file, "r", encoding="utf-8") as f:
    rag.insert(f.read())

print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
88 changes: 88 additions & 0 deletions lightrag/kg/milvus_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
from lightrag.utils import logger
from ..base import BaseVectorStorage

from pymilvus import MilvusClient


@dataclass
class MilvusVectorDBStorge(BaseVectorStorage):
    """Vector storage backend that keeps embeddings in a Milvus collection.

    The collection is named after ``self.namespace``. Connection settings are
    read from the ``MILVUS_URI``, ``MILVUS_USER``, ``MILVUS_PASSWORD``,
    ``MILVUS_TOKEN`` and ``MILVUS_DB_NAME`` environment variables.
    """

    @staticmethod
    def create_collection_if_not_exist(
        client: MilvusClient, collection_name: str, **kwargs
    ):
        """Create ``collection_name`` on ``client`` unless it already exists.

        Args:
            client: Connected Milvus client.
            collection_name: Name of the collection to create.
            **kwargs: Forwarded unchanged to ``client.create_collection``
                (``__post_init__`` passes ``dimension``).
        """
        if client.has_collection(collection_name):
            return
        client.create_collection(
            collection_name, max_length=64, id_type="string", **kwargs
        )

    def __post_init__(self):
        """Connect to Milvus and make sure the namespace collection exists."""
        self._client = MilvusClient(
            # When MILVUS_URI is unset, fall back to a "milvus_lite.db" path
            # inside the configured working directory.
            uri=os.environ.get(
                "MILVUS_URI",
                os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
            ),
            user=os.environ.get("MILVUS_USER", ""),
            password=os.environ.get("MILVUS_PASSWORD", ""),
            token=os.environ.get("MILVUS_TOKEN", ""),
            db_name=os.environ.get("MILVUS_DB_NAME", ""),
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        MilvusVectorDBStorge.create_collection_if_not_exist(
            self._client,
            self.namespace,
            dimension=self.embedding_func.embedding_dim,
        )

    async def upsert(self, data: dict[str, dict]):
        """Embed the ``content`` of every item and upsert it into Milvus.

        Args:
            data: Mapping of record id to a record dict. Each record must have
                a ``"content"`` key; keys listed in ``self.meta_fields`` are
                stored alongside the vector.

        Returns:
            The result of ``MilvusClient.upsert``, or ``[]`` when ``data`` is
            empty.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            return []
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def _embed_batch(index, batch):
            # Tag each result with its batch index: as_completed yields in
            # completion order, not submission order.
            return index, await self.embedding_func(batch)

        embedding_tasks = [
            _embed_batch(index, batch) for index, batch in enumerate(batches)
        ]
        embeddings_list = [None] * len(batches)
        for f in tqdm_async(
            asyncio.as_completed(embedding_tasks),
            total=len(embedding_tasks),
            desc="Generating embeddings",
            unit="batch",
        ):
            index, batch_embeddings = await f
            # Store in the original slot so row i of the concatenated array
            # still belongs to list_data[i].
            embeddings_list[index] = batch_embeddings
        embeddings = np.concatenate(embeddings_list)
        for d, vector in zip(list_data, embeddings):
            d["vector"] = vector
        results = self._client.upsert(collection_name=self.namespace, data=list_data)
        return results

    async def query(self, query, top_k=5):
        """Return up to ``top_k`` stored records most similar to ``query``.

        Args:
            query: Query text; it is embedded with ``self.embedding_func``.
            top_k: Maximum number of hits to request from Milvus.

        Returns:
            A list of dicts, one per hit, holding the stored meta fields plus
            ``"id"`` and ``"distance"``.
        """
        embedding = await self.embedding_func([query])
        results = self._client.search(
            collection_name=self.namespace,
            data=embedding,
            limit=top_k,
            output_fields=list(self.meta_fields),
            # NOTE(review): "radius" is a Milvus range-search parameter;
            # presumably it drops hits with cosine similarity below 0.2 —
            # confirm against the Milvus documentation.
            search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
        )
        # Log instead of printing so library users do not get raw search
        # results on stdout.
        logger.debug(results)
        return [
            {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
            for dp in results[0]
        ]
3 changes: 3 additions & 0 deletions lightrag/lightrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@

from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage

from .kg.milvus_impl import MilvusVectorDBStorge

# future KG integrations

# from .kg.ArangoDB_impl import (
Expand Down Expand Up @@ -228,6 +230,7 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]:
# vector storage
"NanoVectorDBStorage": NanoVectorDBStorage,
"OracleVectorDBStorage": OracleVectorDBStorage,
"MilvusVectorDBStorge": MilvusVectorDBStorge,
# graph storage
"NetworkXStorage": NetworkXStorage,
"Neo4JStorage": Neo4JStorage,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ networkx
ollama
openai
oracledb
pymilvus
pyvis
tenacity
# lmdeploy[all]
Expand Down

0 comments on commit c352eb6

Please sign in to comment.