From a6fc4209442f795f815e4e86ac637fcdcf2d8e42 Mon Sep 17 00:00:00 2001
From: Adeshen <878600556@qq.com>
Date: Fri, 19 Apr 2024 21:59:28 +0800
Subject: [PATCH] change the faiss index method: replace IndexFlatIP with
 IndexIVFFlat

IndexFlatIP performs brute-force (exhaustive) search over every vector.
IndexIVFFlat clusters the vectors into cells and searches only the cells
nearest to the query, so it is faster but slightly less precise.
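
For reviewers, a standalone sketch of that tradeoff, not part of the patch
(the dimension, corpus size, nlist=100 and nprobe=10 are illustrative):

    import time

    import faiss
    import numpy as np

    d, n, k = 128, 10000, 10
    xb = np.random.random((n, d)).astype('float32')
    xq = np.random.random((100, d)).astype('float32')

    # Brute-force baseline: exact search over every vector.
    flat = faiss.IndexFlatIP(d)
    flat.add(xb)

    # IVF: k-means the corpus into nlist cells, then search only the
    # nprobe cells closest to the query.
    quantizer = faiss.IndexFlatIP(d)
    ivf = faiss.IndexIVFFlat(quantizer, d, 100, faiss.METRIC_INNER_PRODUCT)
    ivf.train(xb)    # required before add(); needs >= nlist vectors
    ivf.add(xb)
    ivf.nprobe = 10  # cells visited per query: more = better recall, slower

    t0 = time.time()
    _, I_flat = flat.search(xq, k)
    t_flat = time.time() - t0

    t0 = time.time()
    _, I_ivf = ivf.search(xq, k)
    t_ivf = time.time() - t0

    recall = np.mean([len(set(a) & set(b)) / k for a, b in zip(I_flat, I_ivf)])
    print(f'flat {t_flat:.4f}s  ivf {t_ivf:.4f}s  recall@{k} {recall:.2f}')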
---
 huixiangdou/service/feature_store.py |   3 +-
 huixiangdou/service/retriever.py     |   7 +-
 huixiangdou/service/vector_store.py  | 189 ++++++++++++++++++++++++++
 3 files changed, 197 insertions(+), 2 deletions(-)
 create mode 100644 huixiangdou/service/vector_store.py

diff --git a/huixiangdou/service/feature_store.py b/huixiangdou/service/feature_store.py
index 64f038e8..93dafc8a 100644
--- a/huixiangdou/service/feature_store.py
+++ b/huixiangdou/service/feature_store.py
@@ -14,7 +14,8 @@
 from langchain.text_splitter import (MarkdownHeaderTextSplitter,
                                      MarkdownTextSplitter,
                                      RecursiveCharacterTextSplitter)
-from langchain.vectorstores.faiss import FAISS as Vectorstore
+# from langchain.vectorstores.faiss import FAISS as Vectorstore
+from .vector_store import Vectorstore
 from langchain_core.documents import Document
 from loguru import logger
 from torch.cuda import empty_cache
diff --git a/huixiangdou/service/retriever.py b/huixiangdou/service/retriever.py
index 24e2249a..9ca59a22 100644
--- a/huixiangdou/service/retriever.py
+++ b/huixiangdou/service/retriever.py
@@ -8,7 +8,8 @@
 from BCEmbedding.tools.langchain import BCERerank
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.retrievers import ContextualCompressionRetriever
-from langchain.vectorstores.faiss import FAISS as Vectorstore
+# from langchain.vectorstores.faiss import FAISS as Vectorstore
+from .vector_store import Vectorstore
 from langchain_community.vectorstores.utils import DistanceStrategy
 from loguru import logger
 from sklearn.metrics import precision_recall_curve
@@ -118,6 +119,9 @@ def query(self,
             str: The best matching chunk, or None.
             str: The best matching text, or None
         """
+
+        import time
+        start_time = time.time()
         if question is None or len(question) < 1:
             return None, None, []
 
@@ -183,6 +187,7 @@
             context = context[0:context_max_length]
 
         logger.debug('query:{} top1 file:{}'.format(question, references[0]))
+        logger.debug(f"query total time: {time.time() - start_time}")
         return '\n'.join(chunks), context, [
             os.path.basename(r) for r in references
         ]
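
One caveat with the hard-coded nlist = 100 in the new file below: faiss
refuses to train an IVF index on fewer vectors than clusters, so a very
small feature store would now fail to build. A minimal reproduction
(illustrative, not part of the patch; the error text is paraphrased):

    import faiss
    import numpy as np

    d = 32
    quantizer = faiss.IndexFlatIP(d)
    ivf = faiss.IndexIVFFlat(quantizer, d, 100, faiss.METRIC_INNER_PRODUCT)
    # 50 training vectors < nlist = 100 raises a RuntimeError saying the
    # number of training points should be at least the number of clusters.
    ivf.train(np.random.random((50, d)).astype('float32'))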
diff --git a/huixiangdou/service/vector_store.py b/huixiangdou/service/vector_store.py
new file mode 100644
index 00000000..9889a59a
--- /dev/null
+++ b/huixiangdou/service/vector_store.py
@@ -0,0 +1,189 @@
+import os
+import uuid
+from typing import Any, Iterable, List, Optional, Sized
+
+import numpy as np
+from langchain.vectorstores.faiss import FAISS
+from langchain_community.docstore.base import AddableMixin
+from langchain_community.docstore.in_memory import InMemoryDocstore
+from langchain_community.vectorstores.utils import DistanceStrategy
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+
+
+def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
+    """Import faiss if available, otherwise raise an error.
+
+    If the FAISS_NO_AVX2 environment variable is set, FAISS is loaded
+    without AVX2 optimization.
+
+    Args:
+        no_avx2: Load FAISS strictly with no AVX2 optimization so that
+            the vectorstore is portable and compatible with other devices.
+    """
+    if no_avx2 is None and "FAISS_NO_AVX2" in os.environ:
+        no_avx2 = bool(os.getenv("FAISS_NO_AVX2"))
+
+    try:
+        if no_avx2:
+            from faiss import swigfaiss as faiss
+        else:
+            import faiss
+    except ImportError:
+        raise ImportError(
+            "Could not import faiss python package. "
+            "Please install it with `pip install faiss-gpu` (for CUDA supported GPU) "
+            "or `pip install faiss-cpu` (depending on Python version).")
+    return faiss
+
+
+def _len_check_if_sized(x: Any, y: Any, x_name: str, y_name: str) -> None:
+    if isinstance(x, Sized) and isinstance(y, Sized) and len(x) != len(y):
+        raise ValueError(
+            f"{x_name} and {y_name} expected to be equal length but "
+            f"len({x_name})={len(x)} and len({y_name})={len(y)}")
+    return
+
+
+def get_faiss_index(dimension, distance_strategy):
+    """Build a trainable IVF index for the given embedding dimension.
+
+    Args:
+        dimension: int, the dimension of an embedding.
+        distance_strategy: DistanceStrategy, selects the metric
+            (inner product or L2).
+            Reference: https://github.com/facebookresearch/faiss/wiki/The-index-factory
+
+    Returns:
+        A faiss IndexIVFFlat instance. It must be trained with
+        `index.train()` before vectors can be added.
+    """
+    faiss = dependable_faiss_import()
+    # Number of clusters (Voronoi cells). faiss needs at least `nlist`
+    # training vectors and recommends roughly 39 * nlist or more.
+    nlist = 100
+    if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+        quantizer = faiss.IndexFlatIP(dimension)
+        index = faiss.IndexIVFFlat(quantizer, dimension, nlist,
+                                   faiss.METRIC_INNER_PRODUCT)
+    else:
+        quantizer = faiss.IndexFlatL2(dimension)
+        index = faiss.IndexIVFFlat(quantizer, dimension, nlist,
+                                   faiss.METRIC_L2)
+    # Number of cells visited per query; higher values raise recall at
+    # the cost of speed. The faiss default of 1 is usually too aggressive.
+    index.nprobe = 10
+    return index
+
+
+class Vectorstore(FAISS):
+
+    @classmethod
+    def from_texts(
+        cls,
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> FAISS:
+        """Construct FAISS wrapper from raw documents.
+
+        This is a user friendly interface that:
+            1. Embeds documents.
+            2. Creates an in memory docstore.
+            3. Initializes the FAISS database.
+
+        This is intended to be a quick way to get started.
+
+        Example:
+            .. code-block:: python
+
+                from langchain_community.vectorstores import FAISS
+                from langchain_community.embeddings import OpenAIEmbeddings
+
+                embeddings = OpenAIEmbeddings()
+                faiss = FAISS.from_texts(texts, embeddings)
+        """
+        embeddings = embedding.embed_documents(texts)
+        return cls.__from(
+            texts,
+            embeddings,
+            embedding,
+            metadatas=metadatas,
+            ids=ids,
+            **kwargs,
+        )
+
+    @classmethod
+    def __from(
+        cls,
+        texts: Iterable[str],
+        embeddings: List[List[float]],
+        embedding: Embeddings,
+        metadatas: Optional[Iterable[dict]] = None,
+        ids: Optional[List[str]] = None,
+        normalize_L2: bool = False,
+        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
+        **kwargs: Any,
+    ) -> FAISS:
+        index = get_faiss_index(len(embeddings[0]), distance_strategy)
+        # Unlike IndexFlat*, an IVF index must be trained (k-means over
+        # the embeddings) before vectors can be added to it.
+        data = np.array(embeddings, dtype="float32")
+        index.train(data)
+
+        docstore = kwargs.pop("docstore", InMemoryDocstore())
+        index_to_docstore_id = kwargs.pop("index_to_docstore_id", {})
+        vecstore = cls(
+            embedding,
+            index,
+            docstore,
+            index_to_docstore_id,
+            normalize_L2=normalize_L2,
+            distance_strategy=distance_strategy,
+            **kwargs,
+        )
+        vecstore.__add(texts, embeddings, metadatas=metadatas, ids=ids)
+        return vecstore
+
+    def __add(
+        self,
+        texts: Iterable[str],
+        embeddings: Iterable[List[float]],
+        metadatas: Optional[Iterable[dict]] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[str]:
+        faiss = dependable_faiss_import()
+
+        if not isinstance(self.docstore, AddableMixin):
+            raise ValueError(
+                "If trying to add texts, the underlying docstore should support "
+                f"adding items, which {self.docstore} does not")
+
+        _len_check_if_sized(texts, metadatas, "texts", "metadatas")
+        _metadatas = metadatas or ({} for _ in texts)
+        documents = [
+            Document(page_content=t, metadata=m)
+            for t, m in zip(texts, _metadatas)
+        ]
+
+        _len_check_if_sized(documents, embeddings, "documents", "embeddings")
+        _len_check_if_sized(documents, ids, "documents", "ids")
+
+        if ids and len(ids) != len(set(ids)):
+            raise ValueError("Duplicate ids found in the ids list.")
+
+        # Add the (optionally L2-normalized) vectors to the faiss index.
+        vector = np.array(embeddings, dtype=np.float32)
+        if self._normalize_L2:
+            faiss.normalize_L2(vector)
+        self.index.add(vector)
+
+        # Record the id -> document mapping in the docstore.
+        ids = ids or [str(uuid.uuid4()) for _ in texts]
+        self.docstore.add({id_: doc for id_, doc in zip(ids, documents)})
+        starting_len = len(self.index_to_docstore_id)
+        index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)}
+        self.index_to_docstore_id.update(index_to_id)
+        return ids
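
As a usage note: the class keeps the langchain FAISS surface, so callers
only swap the import. A hypothetical smoke test (the embedding model name
and toy corpus are assumptions, not taken from this patch; the corpus must
hold at least nlist = 100 chunks for IVF training to succeed):

    from langchain.embeddings import HuggingFaceEmbeddings

    from huixiangdou.service.vector_store import Vectorstore

    embedder = HuggingFaceEmbeddings(
        model_name='maidalun1020/bce-embedding-base_v1')
    texts = [f'knowledge chunk number {i}' for i in range(500)]
    vs = Vectorstore.from_texts(texts, embedder)

    # If recall drops too far on a given knowledge base, visiting more
    # cells per query trades some of the speedup back for precision.
    vs.index.nprobe = 32
    print(vs.similarity_search('chunk number 42', k=3))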