diff --git a/README.md b/README.md index d2249717..c5f6f955 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ Slash Your LLM API Costs by 10x 💰, Boost Speed by 100x ⚡ 📔 This project is undergoing swift development, and as such, the API may be subject to change at any time. For the most up-to-date information, please refer to the latest [documentation]( https://gptcache.readthedocs.io/en/latest/) and [release note](https://github.com/zilliztech/GPTCache/blob/main/docs/release_note.md). +**NOTE:** As the number of large models is growing explosively and their API shapes are constantly evolving, we no longer add support for new APIs or models. We encourage the use of the get and set API in GPTCache; here is the demo code: https://github.com/zilliztech/GPTCache/blob/main/examples/adapter/api.py + ## Quick Install `pip install gptcache` diff --git a/docs/configure_it.md b/docs/configure_it.md index de7cabc9..e0d16262 100644 --- a/docs/configure_it.md +++ b/docs/configure_it.md @@ -224,6 +224,7 @@ For the similar cache of text, only cache store and vector store are needed. 
If - docarray - usearch - redis +- lancedb ### object store diff --git a/examples/README.md b/examples/README.md index b3b3cbc4..fc25353c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -274,6 +274,7 @@ Support vector database - Zilliz Cloud - FAISS - ChromaDB +- LanceDB > [Example code](https://github.com/zilliztech/GPTCache/blob/main/examples/data_manager/vector_store.py) diff --git a/examples/data_manager/vector_store.py b/examples/data_manager/vector_store.py index 4d804d38..194010ed 100644 --- a/examples/data_manager/vector_store.py +++ b/examples/data_manager/vector_store.py @@ -20,6 +20,7 @@ def run(): 'docarray', 'redis', 'weaviate', + 'lancedb', ] for vector_store in vector_stores: cache_base = CacheBase('sqlite') diff --git a/gptcache/manager/vector_data/lancedb.py b/gptcache/manager/vector_data/lancedb.py new file mode 100644 index 00000000..708d9aa9 --- /dev/null +++ b/gptcache/manager/vector_data/lancedb.py @@ -0,0 +1,81 @@ +from typing import List, Optional + +import numpy as np +import pyarrow as pa +import lancedb +from gptcache.manager.vector_data.base import VectorBase, VectorData +from gptcache.utils import import_lancedb, import_torch + +import_torch() +import_lancedb() + + +class LanceDB(VectorBase): + """Vector store: LanceDB + :param persist_directory: The directory to persist, defaults to '/tmp/lancedb'. + :type persist_directory: str + :param table_name: The name of the table in LanceDB, defaults to 'gptcache'. + :type table_name: str + :param top_k: The number of the vectors results to return, defaults to 1. 
+ :type top_k: int + """ + + def __init__( + self, + persist_directory: Optional[str] = "/tmp/lancedb", + table_name: str = "gptcache", + top_k: int = 1, + ): + self._persist_directory = persist_directory + self._table_name = table_name + self._top_k = top_k + + # Initialize LanceDB database + self._db = lancedb.connect(self._persist_directory) + + # Initialize or open table + if self._table_name not in self._db.table_names(): + self._table = None # Table will be created with the first insertion + else: + self._table = self._db.open_table(self._table_name) + + def mul_add(self, datas: List[VectorData]): + """Add multiple vectors to the LanceDB table""" + vectors, vector_ids = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas))) + # Infer the dimension of the vectors + vector_dim = len(vectors[0]) if vectors else 0 + + # Create table with the inferred schema if it doesn't exist + if self._table is None: + schema = pa.schema([ + pa.field("id", pa.string()), + pa.field("vector", pa.list_(pa.float32(), list_size=vector_dim)) + ]) + self._table = self._db.create_table(self._table_name, schema=schema) + + # Prepare and add data to the table + self._table.add(({"id": vector_id, "vector": vector} for vector_id, vector in zip(vector_ids, vectors))) + + def search(self, data: np.ndarray, top_k: int = -1): + """Search for the most similar vectors in the LanceDB table""" + if len(self._table) == 0: + return [] + + if top_k == -1: + top_k = self._top_k + + results = self._table.search(data.tolist()).limit(top_k).to_list() + return [(result["_distance"], int(result["id"])) for result in results] + + def delete(self, ids: List[int]): + """Delete vectors from the LanceDB table based on IDs""" + for vector_id in ids: + self._table.delete(f"id = '{vector_id}'") + + def rebuild(self, ids: Optional[List[int]] = None): + """Rebuild the index, if applicable""" + return True + + def count(self): + """Return the total number of vectors in the table""" + return 
len(self._table) diff --git a/gptcache/manager/vector_data/manager.py b/gptcache/manager/vector_data/manager.py index fe21b6f9..3ac02c0c 100644 --- a/gptcache/manager/vector_data/manager.py +++ b/gptcache/manager/vector_data/manager.py @@ -42,6 +42,7 @@ class VectorBase: `Chromadb` (with `top_k`, `client_settings`, `persist_directory`, `collection_name` params), `Hnswlib` (with `index_file_path`, `dimension`, `top_k`, `max_elements` params). `pgvector` (with `url`, `collection_name`, `index_params`, `top_k`, `dimension` params). + `lancedb` (with `persist_directory`, `table_name`, `top_k` params). :param name: the name of the vectorbase, it is support 'milvus', 'faiss', 'chromadb', 'hnswlib' now. :type name: str @@ -91,6 +92,14 @@ class VectorBase: :param persist_directory: the directory to persist, defaults to '.chromadb/' in the current directory. :type persist_directory: str + The following params are for LanceDB: + :param persist_directory: The directory to persist, defaults to '/tmp/lancedb'. + :type persist_directory: str + :param table_name: The name of the table in LanceDB, defaults to 'gptcache'. + :type table_name: str + :param top_k: The number of the vectors results to return, defaults to 1. + :type top_k: int + :param index_path: the path to hnswlib index, defaults to 'hnswlib_index.bin'. :type index_path: str :param max_elements: max_elements of hnswlib, defaults 100000. 
@@ -293,6 +302,20 @@ def get(name, **kwargs): class_schema=class_schema, top_k=top_k, ) + + elif name == "lancedb": + from gptcache.manager.vector_data.lancedb import LanceDB + + persist_directory = kwargs.get("persist_directory", None) + table_name = kwargs.get("table_name", COLLECTION_NAME) + top_k: int = kwargs.get("top_k", TOP_K) + + vector_base = LanceDB( + persist_directory=persist_directory, + table_name=table_name, + top_k=top_k, + ) + else: raise NotFoundError("vector store", name) return vector_base diff --git a/gptcache/utils/__init__.py b/gptcache/utils/__init__.py index 1fec7c56..877e23f6 100644 --- a/gptcache/utils/__init__.py +++ b/gptcache/utils/__init__.py @@ -43,6 +43,7 @@ "import_redis", "import_qdrant", "import_weaviate", + "import_lancedb", ] import importlib.util @@ -152,6 +153,8 @@ def import_duckdb(): _check_library("duckdb", package="duckdb") _check_library("duckdb-engine", package="duckdb-engine") +def import_lancedb(): + _check_library("lancedb", package="lancedb") def import_sql_client(db_name): if db_name == "postgresql": diff --git a/tests/unit_tests/manager/test_lancedb.py b/tests/unit_tests/manager/test_lancedb.py new file mode 100644 index 00000000..f90688b1 --- /dev/null +++ b/tests/unit_tests/manager/test_lancedb.py @@ -0,0 +1,24 @@ +import unittest +import numpy as np +from gptcache.manager import VectorBase +from gptcache.manager.vector_data.base import VectorData + +class TestLanceDB(unittest.TestCase): + def test_normal(self): + + db = VectorBase("lancedb", persist_directory="/tmp/test_lancedb", top_k=3) + + # Add 100 vectors to the LanceDB + db.mul_add([VectorData(id=i, data=np.random.sample(10)) for i in range(100)]) + + # Perform a search with a random query vector + search_res = db.search(np.random.sample(10)) + + # Check that the search returns 3 results + self.assertEqual(len(search_res), 3) + + # Delete vectors with specific IDs + db.delete([1, 3, 5, 7]) + + # Check that the count of vectors in the table is now 96 + 
self.assertEqual(db.count(), 96) \ No newline at end of file