From cfb8ca6f16d9ef44b7e3d29db2046c53d0bd45e5 Mon Sep 17 00:00:00 2001
From: partoneplay
Date: Wed, 4 Dec 2024 17:26:47 +0800
Subject: [PATCH] Add Milvus as vector storage

---
 examples/lightrag_ollama_neo4j_milvus_demo.py |  51 ++++++++
 lightrag/kg/milvus_impl.py                    | 125 +++++++++++++++++++
 lightrag/lightrag.py                          |   3 +
 requirements.txt                              |   1 +
 4 files changed, 180 insertions(+)
 create mode 100644 examples/lightrag_ollama_neo4j_milvus_demo.py
 create mode 100644 lightrag/kg/milvus_impl.py

diff --git a/examples/lightrag_ollama_neo4j_milvus_demo.py b/examples/lightrag_ollama_neo4j_milvus_demo.py
new file mode 100644
index 000000000..6ed6da83a
--- /dev/null
+++ b/examples/lightrag_ollama_neo4j_milvus_demo.py
@@ -0,0 +1,51 @@
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.utils import EmbeddingFunc
+
+# WorkingDir
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+WORKING_DIR = os.path.join(ROOT_DIR, "myKG")
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+print(f"WorkingDir: {WORKING_DIR}")
+
+# neo4j
+BATCH_SIZE_NODES = 500
+BATCH_SIZE_EDGES = 100
+os.environ["NEO4J_URI"] = "bolt://localhost:7687"
+os.environ["NEO4J_USERNAME"] = "neo4j"
+os.environ["NEO4J_PASSWORD"] = "neo4j"
+
+# milvus
+os.environ["MILVUS_URI"] = "http://localhost:19530"
+os.environ["MILVUS_USER"] = "root"
+os.environ["MILVUS_PASSWORD"] = "root"
+os.environ["MILVUS_DB_NAME"] = "lightrag"
+
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=ollama_model_complete,
+    llm_model_name="qwen2.5:14b",
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    llm_model_kwargs={"host": "http://127.0.0.1:11434", "options": {"num_ctx": 32768}},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=lambda texts: ollama_embed(
+            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
+        ),
+    ),
+    graph_storage="Neo4JStorage",
+    vector_storage="MilvusVectorDBStorge",
+)
+
+file = "./book.txt"
+with open(file, "r") as f:
+    rag.insert(f.read())
+
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
new file mode 100644
index 000000000..6d2520ce4
--- /dev/null
+++ b/lightrag/kg/milvus_impl.py
@@ -0,0 +1,125 @@
+import asyncio
+import os
+from tqdm.asyncio import tqdm as tqdm_async
+from dataclasses import dataclass
+import numpy as np
+from lightrag.utils import logger
+from ..base import BaseVectorStorage
+
+from pymilvus import MilvusClient
+
+
+@dataclass
+class MilvusVectorDBStorge(BaseVectorStorage):
+    """Vector storage backed by Milvus, using `namespace` as the collection.
+
+    Connection settings come from the MILVUS_URI, MILVUS_USER,
+    MILVUS_PASSWORD, MILVUS_TOKEN and MILVUS_DB_NAME environment variables.
+    Without MILVUS_URI, "milvus_lite.db" inside the working directory is
+    used as the URI.
+    """
+
+    @staticmethod
+    def create_collection_if_not_exist(
+        client: MilvusClient, collection_name: str, **kwargs
+    ):
+        """Create `collection_name` with string ids unless it already exists.
+
+        Extra keyword arguments are forwarded to `client.create_collection`.
+        """
+        if client.has_collection(collection_name):
+            return
+        client.create_collection(
+            collection_name, max_length=64, id_type="string", **kwargs
+        )
+
+    def __post_init__(self):
+        """Connect to Milvus and make sure the namespace collection exists."""
+        self._client = MilvusClient(
+            uri=os.environ.get(
+                "MILVUS_URI",
+                os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
+            ),
+            user=os.environ.get("MILVUS_USER", ""),
+            password=os.environ.get("MILVUS_PASSWORD", ""),
+            token=os.environ.get("MILVUS_TOKEN", ""),
+            db_name=os.environ.get("MILVUS_DB_NAME", ""),
+        )
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+        MilvusVectorDBStorge.create_collection_if_not_exist(
+            self._client,
+            self.namespace,
+            dimension=self.embedding_func.embedding_dim,
+        )
+
+    async def upsert(self, data: dict[str, dict]):
+        """Embed the "content" of every record and upsert it into Milvus.
+
+        Args:
+            data: Maps record id to its fields. Each record must contain
+                "content"; only fields named in `meta_fields` are stored
+                next to the id and the vector.
+
+        Returns:
+            The result of `MilvusClient.upsert`, or an empty list when
+            `data` is empty.
+        """
+        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+        if not data:
+            logger.warning("You insert an empty data to vector DB")
+            return []
+        list_data = [
+            {
+                "id": k,
+                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
+            }
+            for k, v in data.items()
+        ]
+        contents = [v["content"] for v in data.values()]
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+
+        async def embed_batch(index, batch):
+            # as_completed yields batches in completion order, so carry the
+            # batch index along to put the embeddings back in input order.
+            return index, await self.embedding_func(batch)
+
+        embedding_tasks = [
+            embed_batch(index, batch) for index, batch in enumerate(batches)
+        ]
+        embeddings_list = [None] * len(batches)
+        for f in tqdm_async(
+            asyncio.as_completed(embedding_tasks),
+            total=len(embedding_tasks),
+            desc="Generating embeddings",
+            unit="batch",
+        ):
+            index, batch_embeddings = await f
+            embeddings_list[index] = batch_embeddings
+        embeddings = np.concatenate(embeddings_list)
+        for i, d in enumerate(list_data):
+            d["vector"] = embeddings[i]
+        results = self._client.upsert(collection_name=self.namespace, data=list_data)
+        return results
+
+    async def query(self, query, top_k=5):
+        """Search the collection with the embedding of `query`.
+
+        Returns at most `top_k` hits, each a dict of the stored meta fields
+        plus "id" and "distance".
+        """
+        embedding = await self.embedding_func([query])
+        results = self._client.search(
+            collection_name=self.namespace,
+            data=embedding,
+            limit=top_k,
+            output_fields=list(self.meta_fields),
+            search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
+        )
+        logger.debug("Milvus search results: %s", results)
+        return [
+            {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
+            for dp in results[0]
+        ]
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 97b2f2565..48258f99c 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -44,6 +44,8 @@
 
 from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
 
+from .kg.milvus_impl import MilvusVectorDBStorge
+
 # future KG integrations
 
 # from .kg.ArangoDB_impl import (
@@ -228,6 +230,7 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]:
             # vector storage
             "NanoVectorDBStorage": NanoVectorDBStorage,
             "OracleVectorDBStorage": OracleVectorDBStorage,
+            "MilvusVectorDBStorge": MilvusVectorDBStorge,
             # graph storage
             "NetworkXStorage": NetworkXStorage,
             "Neo4JStorage": Neo4JStorage,
diff --git a/requirements.txt b/requirements.txt
index 6adb6929b..4ccb2bb39 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,6 +11,7 @@ networkx
 ollama
 openai
 oracledb
+pymilvus
 pyvis
 tenacity
 # lmdeploy[all]