diff --git a/README.md b/README.md
index c70bfaf1..fcddcc08 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@ Canopy has two flows: knowledge base creation and chat. In the knowledge base cr
 1. **Canopy Core Library** - The library has 3 main classes that are responsible for different parts of the RAG workflow:
     * **ChatEngine** - Exposes a chat interface to interact with your data. Given the history of chat messages, the `ChatEngine` formulates relevant queries to the `ContextEngine`, then uses the LLM to generate a knowledgeable response.
     * **ContextEngine** - Performs the “retrieval” part of RAG. The `ContextEngine` utilizes the underlying `KnowledgeBase` to retrieve the most relevant documents, then formulates a coherent textual context to be used as a prompt for the LLM.
-    * **KnowledgeBase** - Manages your data for the RAG workflow. It automatically chunks and transforms your text data into text embeddings, storing them in a Pinecone vector database. Given a text query - the `KnowledgeBase` will retrieve the most relevant document chunks from the database.
+    * **KnowledgeBase** - Manages your data for the RAG workflow. It automatically chunks and transforms your text data into text embeddings, storing them in a Pinecone (default) or Qdrant vector database. Given a text query, the knowledge base will retrieve the most relevant document chunks from the database.

> More information about the Core Library usage can be found in the [Library Documentation](docs/library.md)

@@ -67,11 +67,12 @@ pip install canopy-sdk

### Extras

| Name           | Description |
-|----------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|
+| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `grpc`         | To unlock some performance improvements by working with the GRPC version of the [Pinecone Client](https://github.com/pinecone-io/pinecone-python-client) |
| `torch`        | To enable embeddings provided by [sentence-transformers](https://www.sbert.net/) |
| `transformers` | If you are using Anyscale LLMs, it's recommended to use `LLamaTokenizer` tokenizer which requires transformers as dependency |
| `cohere`       | To use Cohere reranker or/and Cohere LLM |
+| `qdrant`       | To use [Qdrant](https://qdrant.tech/) as an alternative knowledge base |

diff --git a/docs/library.md b/docs/library.md
index 179c0180..9429c1af 100644
--- a/docs/library.md
+++ b/docs/library.md
@@ -1,21 +1,21 @@
# Canopy Library

-For most common use cases, users can simply deploy the fully-configurable [Canopy service](../README.md), which provides a REST API backend for your own RAG-infused Chatbot.
+For the most common use cases, users can simply deploy the fully configurable [Canopy service](../README.md), which provides a REST API backend for their own RAG-infused Chatbot.

-For advanced users, this page describes how to use `canopy` core library directly to implement their own custom applications.
+For advanced users, this page describes how to use the `canopy` core library directly to implement their own custom applications.

> **_💡 NOTE:_** You can also follow the quickstart Jupyter [notebook](../examples/canopy-lib-quickstart.ipynb)

-The idea behind Canopy library is to provide a framework to build AI applications on top of Pinecone as a long memory storage for you own data.
Canopy library designed with the following principles in mind:
+The idea behind Canopy is to provide a framework to build AI applications on top of Pinecone as a long-memory storage for your own data. Canopy is designed with the following principles in mind:

-- **Easy to use**: Canopy is designed to be easy to use. It is well packaged and can be installed with a single command.
-- **Modularity**: Canopy is built as a collection of modules that can be used together or separately. For example, you can use the `chat_engine` module to build a chatbot on top of your data, or you can use the `knowledge_base` module to directly store and search your data.
-- **Extensibility**: Canopy is designed to be extensible. You can easily add your own components and extend the functionality.
-- **Production ready**: Canopy designed to be production ready, tested, well documented, maintained and supported.
-- **Open source**: Canopy is open source and free to use. It built in partnership with the community and for the community.
+- **Easy to use**: Canopy is designed to be easy to use. It is well-packaged and can be installed with a single command.
+- **Modularity**: Canopy is built as a collection of modules that can be used together or separately. For example, you can use the `chat_engine` module to build a chatbot on top of your data or the `knowledge_base` module to store and search your data directly.
+- **Extensibility**: Canopy is designed to be extensible. You can easily add your own components and extend the functionality.
+- **Production-ready**: Canopy is designed to be production-ready, tested, well-documented, maintained, and supported.
+- **Open-source**: Canopy is open-source and free to use. It is built in partnership with the community and for the community.

-## High level architecture
+## High-level architecture

![class architecture](../.readme-content/class_architecture.png)

@@ -59,9 +59,9 @@ os.environ["OPENAI_API_KEY"] = ""

### Step 1: Initialize global Tokenizer

-The `Tokenizer` object is used for converting text into tokens, which is the basic data represntation that is used for processing.
+The `Tokenizer` object is used for converting text into tokens, which is the basic data representation that is used for processing.

-Since manny different classes rely on a tokenizer, Canopy uses a singleton `Tokenizer` object which needs to be initialized once.
+Since many different classes rely on a tokenizer, Canopy uses a singleton `Tokenizer` object which needs to be initialized once.

Before instantiating any other canopy core objects, please initialize the `Tokenizer` singleton:

@@ -84,13 +84,13 @@ tokenizer.tokenize("Hello world!")

Since the `tokenizer` object created here would be the same instance that you have initialized at the beginning of this subsection.

-By default, the global tokenizer is initialized with `OpenAITokenizer` that is based on OpenAI's tiktoken library and aligned with GPT 3 and 4 models tokenization.
+By default, the global tokenizer is initialized with `OpenAITokenizer`, which is based on OpenAI's tiktoken library and aligned with the tokenization of the GPT-3 and GPT-4 models.
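+
+For example, here is a minimal sketch of initializing and using the global tokenizer, based on the calls shown above (the exact token strings returned depend on the underlying tiktoken encoding):
+
+```python
+from canopy.tokenizer import Tokenizer
+
+Tokenizer.initialize()  # one-time, process-wide initialization
+
+tokenizer = Tokenizer()  # returns the global singleton
+print(tokenizer.tokenize("Hello world!"))  # a list of token strings
+```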
👉 Click here to understand how you can configure and customize the tokenizer The `Tokenizer` singleton is holding an inner `Tokenizer` object that implements `BaseTokenizer`. -You can create your own customized tokenizer by implementing a new class that derives from `BaseTokenizer`, then passing this class to the `Tokenizer` singleton during initialization. Example: +You can create your own customized tokenizer by implementing a new class that derives from `BaseTokenizer`, and then passing this class to the `Tokenizer` singleton during initialization. Example: ```python from canopy.tokenizer import Tokenizer, BaseTokenizer @@ -114,7 +114,7 @@ Will initialize the global tokenizer with `OpenAITokenizer` and will pass the `m ### Step 2: Create a knowledge base -Knowledge base is an object that is responsible for storing and query your data. It holds a connection to a single Pinecone index and provides a simple API to insert, delete and search textual documents. +Knowledge base is an object that is responsible for storing and querying your data. It holds a connection to a single Pinecone index and provides a simple API to insert, delete and search textual documents. To create a knowledge base, you can use the following command: @@ -140,7 +140,7 @@ To create a new Pinecone index and connect it to the knowledge base, you can use kb.create_canopy_index() ``` -Then, you will be able to mange the index in Pinecone [console](https://app.pinecone.io/). +Then, you will be able to manage the index in Pinecone [console](https://app.pinecone.io/). If you already created a Pinecone index, you can connect it to the knowledge base with the `connect` method: @@ -154,7 +154,39 @@ You can always verify the connection to the Pinecone index with the `verify_inde kb.verify_index_connection() ``` -To learn more about customizing the KnowledgeBase and its inner components, see [understanding knowledgebase workings section](#understanding-knowledgebase-workings). +#### Using Qdrant as a knowledge base + +Canopy supports [Qdrant](https://qdrant.tech) as an alternative knowledge base. To use Qdrant with Canopy, install the `qdrant` extra. + +```bash +pip install canopy-sdk[qdrant] +``` + +The Qdrant knowledge base is accessible via the `QdrantKnowledgeBase` class. + +```python +from canopy.knowledge_base import QdrantKnowledgeBase + +kb = QdrantKnowledgeBase(collection_name="") +``` + +The constructor accepts additional [options](https://github.com/qdrant/qdrant-client/blob/eda201a1dbf1bbc67415f8437a5619f6f83e8ac6/qdrant_client/qdrant_client.py#L36-L61) to customize your connection to Qdrant. + +To create a new Qdrant collection and connect it to the knowledge base, use the `create_canopy_collection` method: + +```python +kb.create_canopy_collection() +``` + +The method accepts additional [options](https://github.com/qdrant/qdrant-client/blob/c63c62e6df9763591622d1921b3dfcc486666481/qdrant_client/qdrant_remote.py#L2137-L2150) to configure the collection to be created. + +You can always verify the connection to the collection with the `verify_index_connection` method: + +```python +kb.verify_index_connection() +``` + +To learn more about customizing the KnowledgeBase and its inner components, see [understanding knowledge ebase workings section](#understanding-knowledgebase-workings). 
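+
+For example, a minimal sketch of connecting to a remote Qdrant instance and verifying the connection. The keyword arguments come from the `QdrantKnowledgeBase` constructor; the URL and API key below are placeholders:
+
+```python
+from canopy.knowledge_base import QdrantKnowledgeBase
+
+kb = QdrantKnowledgeBase(
+    collection_name="my-collection",
+    url="https://my-qdrant-node:6333",  # placeholder URL
+    api_key="<QDRANT_API_KEY>",         # placeholder key, e.g. for Qdrant Cloud
+)
+kb.verify_index_connection()  # raises if the Canopy collection does not exist
+```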
### Step 3: Upsert and query data @@ -190,9 +222,9 @@ print(f"score - {results[0].documents[0].score:.4f}") ### Step 4: Create a context engine -Context engine is an object that responsible to retrieve the most relevant context for a given query and token budget. +Context engine is an object that is responsible for retrieving the most relevant context for a given query and token budget. The context engine first uses the knowledge base to retrieve the most relevant documents. Then, it formalizes the textual context that will be presented to the LLM. This textual context might be structured or unstructured, depending on the use case and configuration. -The output of the context engine is designed to provide the LLM the most relevant context for a given query. +The output of the context engine is designed to provide the LLM with the most relevant context for a given query. To create a context engine using a knowledge base, you can use the following command: @@ -243,7 +275,7 @@ TBD ### Step 5: Create a chat engine -Chat engine is an object that implements end to end chat API with [RAG](https://www.pinecone.io/learn/retrieval-augmented-generation/). +Chat engine is an object that implements end-to-end chat API with [RAG](https://www.pinecone.io/learn/retrieval-augmented-generation/). Given chat history, the chat engine orchestrates its underlying context engine and LLM to run the following steps: 1. Generate search queries from the chat history @@ -270,8 +302,8 @@ print(response.choices[0].message.content) ``` -Canopy designed to be production ready and handle any conversation length and context length. Therefore, the chat engine uses internal components to handle long conversations and long contexts. -By default, long chat history is truncated to the latest messages that fits the token budget. It orchestrates the context engine to retrieve context that fits the token budget and then use the LLM to generate the next response. +Canopy designed to be production-ready and handle any conversation length and context length. Therefore, the chat engine uses internal components to handle long conversations and long contexts. +By default, long chat history is truncated to the latest messages that fit the token budget. It orchestrates the context engine to retrieve context that fits the token budget and then uses the LLM to generate the next response.
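+
+As a rough sketch of a multi-turn exchange, assuming the `chat_engine` created in Step 5 above (the `UserMessage` model and `chat()` method names follow the library's public API but are not shown verbatim in this section):
+
+```python
+from canopy.models.data_models import UserMessage
+
+history = [UserMessage(content="What is an index in Pinecone?")]
+response = chat_engine.chat(messages=history)
+
+# Append the assistant's reply and ask a follow-up. The engine truncates
+# older messages automatically if the history exceeds the token budget.
+history.append(response.choices[0].message)
+history.append(UserMessage(content="How do I create one?"))
+response = chat_engine.chat(messages=history)
+print(response.choices[0].message.content)
+```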
@@ -282,10 +314,10 @@ TBD ## Understanding KnowledgeBase workings -The knowledge base is an object that is responsible for storing and query your data. It holds a connection to a single Pinecone index and provides a simple API to insert, delete and search textual documents. +The knowledge base is an object that is responsible for storing and querying your data. It holds a connection to a single Pinecone index and provides a simple API to insert, delete and search textual documents. ### Upsert workflow -The `upsert` method is used to insert of update textual documents of any size into the knowledge base. For each document, the following steps are performed: +The `upsert` method is used to insert or update textual documents of any size into the knowledge base. For each document, the following steps are performed: 1. The document is chunked into smaller pieces of text, each piece is called a `Chunk`. 2. Each chunk is encoded into a vector representation. @@ -308,7 +340,7 @@ The knowledge base is composed of the following components: - **Chunker**: A `Chunker` object that is used to chunk the documents into smaller pieces of text. - **Encoder**: An `RecordEncoder` object that is used to encode the chunks and queries into vector representations. -By default the knowledge base is initialized with `OpenAIRecordEncoder` which uses OpenAI embedding API to encode the text into vector representations, and `MarkdownChunker` which is based on a cloned version of Langchain's `MarkdownTextSplitter` [chunker](https://github.com/langchain-ai/langchain/blob/95a1b598fefbdb4c28db53e493d5f3242129a5f2/libs/langchain/langchain/text_splitter.py#L1374C7-L1374C27). +By default, the knowledge base is initialized with `OpenAIRecordEncoder` which uses OpenAI embedding API to encode the text into vector representations, and `MarkdownChunker` which is based on a cloned version of Langchain's `MarkdownTextSplitter` [chunker](https://github.com/langchain-ai/langchain/blob/95a1b598fefbdb4c28db53e493d5f3242129a5f2/libs/langchain/langchain/text_splitter.py#L1374C7-L1374C27). You can customize each component by passing any instance of `Chunker` or `RecordEncoder` to the `KnowledgeBase` constructor. 
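+
+For instance, a sketch of overriding the default components. The `chunker` and `record_encoder` keyword names follow the components described above; the `index_name` and `model_name` arguments are assumptions about the `KnowledgeBase` and `OpenAIRecordEncoder` constructors, not confirmed by this page:
+
+```python
+from canopy.knowledge_base import KnowledgeBase
+from canopy.knowledge_base.chunker import MarkdownChunker
+from canopy.knowledge_base.record_encoder import OpenAIRecordEncoder
+
+kb = KnowledgeBase(
+    index_name="my-index",
+    chunker=MarkdownChunker(),  # the default chunker, stated explicitly
+    record_encoder=OpenAIRecordEncoder(model_name="text-embedding-ada-002"),  # assumed parameter
+)
+```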
diff --git a/pyproject.toml b/pyproject.toml index 1b04eeba..945d4201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ transformers = {version = "^4.35.2", optional = true} sentencepiece = "^0.1.99" pandas = "2.0.0" pyarrow = "^14.0.1" +qdrant-client = {version = "^1.8.0", optional = true} cohere = { version = "^4.37", optional = true } @@ -60,6 +61,7 @@ cohere = ["cohere"] torch = ["torch", "sentence-transformers"] transformers = ["transformers"] grpc = ["grpcio", "grpc-gateway-protoc-gen-openapiv2", "googleapis-common-protos", "lz4", "protobuf"] +qdrant = ["qdrant-client"] [tool.poetry.group.dev.dependencies] @@ -96,7 +98,9 @@ module = [ 'tokenizers.*', 'cohere.*', 'pinecone.grpc', - 'huggingface_hub.utils' + 'huggingface_hub.utils', + 'qdrant_client.*', + 'grpc.*' ] ignore_missing_imports = true diff --git a/src/canopy/knowledge_base/__init__.py b/src/canopy/knowledge_base/__init__.py index 44e4aeaa..eacb36ff 100644 --- a/src/canopy/knowledge_base/__init__.py +++ b/src/canopy/knowledge_base/__init__.py @@ -1,2 +1,3 @@ from .knowledge_base import list_canopy_indexes from .knowledge_base import KnowledgeBase +from .qdrant.qdrant_knowledge_base import QdrantKnowledgeBase diff --git a/src/canopy/knowledge_base/qdrant/constants.py b/src/canopy/knowledge_base/qdrant/constants.py new file mode 100644 index 00000000..69df41d7 --- /dev/null +++ b/src/canopy/knowledge_base/qdrant/constants.py @@ -0,0 +1,7 @@ +from canopy.knowledge_base.knowledge_base import INDEX_NAME_PREFIX + +COLLECTION_NAME_PREFIX = INDEX_NAME_PREFIX +DENSE_VECTOR_NAME = "dense" +RESERVED_METADATA_KEYS = {"document_id", "text", "source", "chunk_id"} +SPARSE_VECTOR_NAME = "sparse" +UUID_NAMESPACE = "867603e3-ba69-447d-a8ef-263dff19bda7" diff --git a/src/canopy/knowledge_base/qdrant/converter.py b/src/canopy/knowledge_base/qdrant/converter.py new file mode 100644 index 00000000..e4bc7cd0 --- /dev/null +++ b/src/canopy/knowledge_base/qdrant/converter.py @@ -0,0 +1,102 @@ +from copy import deepcopy +from typing import Dict, List, Any, Union +import uuid +from canopy.knowledge_base.models import ( + KBDocChunkWithScore, + KBEncodedDocChunk, + KBQuery, + VectorValues, +) +from pinecone_text.sparse import SparseVector + +try: + from qdrant_client import models +except ImportError: + pass + +from canopy.knowledge_base.qdrant.constants import ( + DENSE_VECTOR_NAME, + SPARSE_VECTOR_NAME, + UUID_NAMESPACE, +) + + +class QdrantConverter: + @staticmethod + def convert_id(_id: str) -> str: + """ + Converts any string into a UUID string based on a seed. + + Qdrant accepts UUID strings and unsigned integers as point ID. + We use a seed to convert each string into a UUID string deterministically. + This allows us to overwrite the same point with the original ID. 
+ """ + return str(uuid.uuid5(uuid.UUID(UUID_NAMESPACE), _id)) + + @staticmethod + def encoded_docs_to_points( + encoded_docs: List[KBEncodedDocChunk], + ) -> "List[models.PointStruct]": + points = [] + for doc in encoded_docs: + record = doc.to_db_record() + _id: str = record.pop("id") + dense_vector: VectorValues = record.pop("values", None) + sparse_vector: SparseVector = record.pop("sparse_values", None) + + vector: Dict[str, models.Vector] = {} + + if dense_vector: + vector[DENSE_VECTOR_NAME] = dense_vector + + if sparse_vector: + vector[SPARSE_VECTOR_NAME] = models.SparseVector( + indices=sparse_vector["indices"], + values=sparse_vector["values"], + ) + + points.append( + models.PointStruct( + id=QdrantConverter.convert_id(_id), + vector=vector, + payload={**record["metadata"], "chunk_id": _id}, + ) + ) + return points + + @staticmethod + def scored_point_to_scored_doc( + scored_point, + ) -> "KBDocChunkWithScore": + metadata: Dict[str, Any] = deepcopy(scored_point.payload or {}) + _id = metadata.pop("chunk_id") + text = metadata.pop("text", "") + document_id = metadata.pop("document_id") + return KBDocChunkWithScore( + id=_id, + text=text, + document_id=document_id, + score=scored_point.score, + source=metadata.pop("source", ""), + metadata=metadata, + ) + + @staticmethod + def kb_query_to_search_vector( + query: KBQuery, + ) -> "Union[models.NamedVector, models.NamedSparseVector]": + # Use dense vector if available, otherwise use sparse vector + query_vector: Union[models.NamedVector, models.NamedSparseVector] + if query.values: + query_vector = models.NamedVector(name=DENSE_VECTOR_NAME, vector=query.values) # noqa: E501 + elif query.sparse_values: + query_vector = models.NamedSparseVector( + name=SPARSE_VECTOR_NAME, + vector=models.SparseVector( + indices=query.sparse_values["indices"], + values=query.sparse_values["values"], + ), + ) + else: + raise ValueError("Query should have either dense or sparse vector.") + return query_vector diff --git a/src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py b/src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py new file mode 100644 index 00000000..43c3915a --- /dev/null +++ b/src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py @@ -0,0 +1,739 @@ +from copy import deepcopy +from typing import List, Optional, Dict, Any + +from canopy.knowledge_base.base import BaseKnowledgeBase +from canopy.knowledge_base.chunker import Chunker, MarkdownChunker +from canopy.knowledge_base.qdrant.constants import ( + COLLECTION_NAME_PREFIX, + DENSE_VECTOR_NAME, + RESERVED_METADATA_KEYS, + SPARSE_VECTOR_NAME, +) +from canopy.knowledge_base.qdrant.converter import QdrantConverter +from canopy.knowledge_base.qdrant.utils import ( + batched, + generate_clients, + sync_fallback, +) +from canopy.knowledge_base.record_encoder import RecordEncoder, OpenAIRecordEncoder +from canopy.knowledge_base.models import ( + KBEncodedDocChunk, + KBQueryResult, + KBQuery, + QueryResult, + KBDocChunkWithScore, + DocumentWithScore, +) +from canopy.knowledge_base.reranker import Reranker, TransparentReranker +from canopy.models.data_models import Query, Document + +from tqdm import tqdm + +try: + from qdrant_client import models + from qdrant_client.http.exceptions import UnexpectedResponse + from grpc import RpcError + + _qdrant_installed = True +except ImportError: + _qdrant_installed = False + + +class QdrantKnowledgeBase(BaseKnowledgeBase): + """ + `QdrantKnowledgeBase` is used to store/retrieve documents using a Qdrant collection. 
+    Every document is chunked into multiple text snippets based on the text structure.
+    Then, each chunk is encoded into a vector using an embedding model.
+    The resulting vectors are inserted into the Qdrant collection.
+    After insertion, the `QdrantKnowledgeBase` can be queried with a textual query.
+    The query will be encoded to a vector to retrieve the top-k document chunks.
+
+    Note: Since Canopy defines its own data format,
+    you cannot use a pre-existing Qdrant collection with Canopy's `QdrantKnowledgeBase`.
+    The collection must be created using `knowledge_base.create_canopy_collection()`.
+
+    Example:
+        >>> from canopy.knowledge_base.knowledge_base import QdrantKnowledgeBase
+        >>> kb = QdrantKnowledgeBase(collection_name="my_collection")
+        >>> # Defaults to a Qdrant instance at localhost:6333
+        >>> kb.create_canopy_collection()
+
+    In any future interactions, the same collection name can be used
+    without the need to create the collection again:
+
+        >>> kb = QdrantKnowledgeBase(collection_name="my_collection")
+    """
+
+    _DEFAULT_COMPONENTS = {
+        "record_encoder": OpenAIRecordEncoder,
+        "chunker": MarkdownChunker,
+        "reranker": TransparentReranker,
+    }
+
+    def __init__(
+        self,
+        collection_name: str,
+        *,
+        record_encoder: Optional[RecordEncoder] = None,
+        chunker: Optional[Chunker] = None,
+        reranker: Optional[Reranker] = None,
+        default_top_k: int = 5,
+        location: Optional[str] = None,
+        url: Optional[str] = None,
+        port: Optional[int] = 6333,
+        grpc_port: int = 6334,
+        prefer_grpc: bool = False,
+        https: Optional[bool] = None,
+        api_key: Optional[str] = None,
+        prefix: Optional[str] = None,
+        timeout: Optional[int] = None,
+        host: Optional[str] = None,
+        path: Optional[str] = None,
+        force_disable_check_same_thread: bool = False,
+    ):
+        """
+        Instantiates a new `QdrantKnowledgeBase` object.
+
+        If the collection does not exist,
+        create it by calling `create_canopy_collection()`.
+
+        Note: Canopy will add the prefix 'canopy--' to the collection name.
+              You can retrieve the full collection name using
+              `knowledge_base.collection_name`.
+
+        Example:
+            Create a new collection:
+
+            >>> from canopy.knowledge_base.knowledge_base import QdrantKnowledgeBase
+            >>> kb = QdrantKnowledgeBase(collection_name="my_collection")
+            >>> # Defaults to a Qdrant instance at localhost:6333
+            >>> kb.create_canopy_collection()
+
+            In any future interactions, the same collection name can be used
+            without having to create it again:
+
+            >>> kb = QdrantKnowledgeBase(collection_name="my_collection")
+
+        Args:
+            collection_name: The name of the Qdrant collection to connect to.
+            record_encoder: An instance of RecordEncoder to use for encoding documents and queries.
+                            Defaults to OpenAIRecordEncoder.
+            chunker: An instance of Chunker to use for chunking documents. Defaults to MarkdownChunker.
+            reranker: An instance of Reranker to use for reranking query results. Defaults to TransparentReranker.
+            default_top_k: The default number of document chunks to return per query. Defaults to 5.
+            location:
+                If ':memory:' - use in-memory Qdrant instance.
+                If 'str' - use it as a `url` parameter.
+                If 'None' - use default values for `host` and `port`.
+            url: either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]".
+                Default: `None`
+            port: Port of the REST API interface. Default: 6333
+            grpc_port: Port of the gRPC interface. Default: 6334
+            prefer_grpc: If `true` - use gRPC interface whenever possible in custom methods.
+            https: If `true` - use HTTPS (SSL) protocol.
Default: `None` + api_key: API key for authentication in Qdrant Cloud. Default: `None` + prefix: + If not `None` - add `prefix` to the REST URL path. + Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API. + Default: `None` + timeout: + Timeout for REST and gRPC API requests. + Default: 5.0 seconds for REST and unlimited for gRPC + host: Host name of Qdrant service. If url and host are None, set to 'localhost'. + Default: `None` + path: Persistence path for QdrantLocal. Default: `None` + force_disable_check_same_thread: + For QdrantLocal, force disable check_same_thread. Default: `False` + Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient. + + Raises: + ValueError: If default_top_k is not a positive integer. + TypeError: If record_encoder is not an instance of RecordEncoder. + TypeError: If chunker is not an instance of Chunker. + TypeError: If reranker is not an instance of Reranker. + """ # noqa: E501 + + if not _qdrant_installed: + raise ImportError( + "Failed to import 'qdrant-client'. " + "Try installing the 'qdrant' extra by running: " + "pip install canopy-sdk[qdrant]" + ) + + if default_top_k < 1: + raise ValueError("default_top_k must be greater than 0") + + self._collection_name = self._get_full_collection_name(collection_name) + self._default_top_k = default_top_k + + if record_encoder: + if not isinstance(record_encoder, RecordEncoder): + raise TypeError( + f"record_encoder must be an instance of RecordEncoder, " + f"not {type(record_encoder)}" + ) + self._encoder = record_encoder + else: + self._encoder = self._DEFAULT_COMPONENTS["record_encoder"]() + + if chunker: + if not isinstance(chunker, Chunker): + raise TypeError( + f"chunker must be an instance of Chunker, not {type(chunker)}" + ) + self._chunker = chunker + else: + self._chunker = self._DEFAULT_COMPONENTS["chunker"]() + + if reranker: + if not isinstance(reranker, Reranker): + raise TypeError( + f"reranker must be an instance of Reranker, not {type(reranker)}" + ) + self._reranker = reranker + else: + self._reranker = self._DEFAULT_COMPONENTS["reranker"]() + + self._client, self._async_client = generate_clients( + location=location, + url=url, + port=port, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + https=https, + api_key=api_key, + prefix=prefix, + timeout=timeout, + host=host, + path=path, + force_disable_check_same_thread=force_disable_check_same_thread, + ) + + def verify_index_connection(self) -> None: + """ + Verify that the knowledge base is referencing an existing Canopy collection. + + Returns: + None + + Raises: + RuntimeError: If the knowledge base is not referencing an existing Canopy collection. + """ # noqa: E501 + try: + self._client.get_collection(self.collection_name) + except (UnexpectedResponse, RpcError, ValueError) as e: + raise RuntimeError( + f"Collection {self.collection_name} does not exist!" + ) from e + + def query( + self, + queries: List[Query], + global_metadata_filter: Optional[dict] = None, + namespace: Optional[str] = None, + ) -> List[QueryResult]: + """ + Query the knowledge base to retrieve document chunks. + + This operation includes several steps: + 1. Encode the queries to vectors using the underlying encoder. + 2. Query the underlying Qdrant collection to retrieve the top-k chunks for each query. + 3. Rerank the results using the underlying reranker. + 4. Return the results for each query as a list of QueryResult objects. 
+ + Args: + queries: A list of queries to run against the knowledge base. + global_metadata_filter: A payload filter to apply to all queries, in addition to any query-specific filters. + Reference: https://qdrant.tech/documentation/concepts/filtering/ + namespace: This argument is not used by Qdrant. + Returns: + A list of QueryResult objects. + + Examples: + >>> from canopy.knowledge_base.knowledge_base import QdrantKnowledgeBase + >>> kb = QdrantKnowledgeBase(collection_name="my_collection") + >>> queries = [Query(text="How to make a cake"), + Query(text="How to make a pizza", + top_k=10, + metadata_filter={ + "must": [ + {"key": "website", "match": {"value": "wiki"}}, + ] + } + )] + >>> results = kb.query(queries) + """ # noqa: E501 + queries = self._encoder.encode_queries(queries) + results = [self._query_collection(q, global_metadata_filter) for q in queries] + results = self._reranker.rerank(results) + + return [ + QueryResult( + query=r.query, + documents=[ + DocumentWithScore( + **d.dict(exclude={"values", "sparse_values", "document_id"}) + ) + for d in r.documents + ], + ) + for r in results + ] + + @sync_fallback + async def aquery( + self, queries: List[Query], global_metadata_filter: Optional[dict] = None + ) -> List[QueryResult]: + """ + Query the knowledge base to retrieve document chunks asynchronously. + + This operation includes several steps: + 1. Encode the queries to vectors using the underlying encoder. + 2. Query the underlying Qdrant collection to retrieve the top-k chunks for each query. + 3. Rerank the results using the underlying reranker. + 4. Return the results for each query as a list of QueryResult objects. + + Args: + queries: A list of queries to run against the knowledge base. + global_metadata_filter: A payload filter to apply to all queries, in addition to any query-specific filters. + Reference: https://qdrant.tech/documentation/concepts/filtering/ + Returns: + A list of QueryResult objects. + + Examples: + >>> from canopy.knowledge_base.knowledge_base import QdrantKnowledgeBase + >>> kb = QdrantKnowledgeBase(collection_name="my_collection") + >>> queries = [Query(text="How to make a cake"), + Query(text="How to make a pizza", + top_k=10, + metadata_filter={ + "must": [ + {"key": "website", "match": {"value": "wiki"}}, + ] + } + )] + >>> results = await kb.aquery(queries) + """ # noqa: E501 + # TODO: Use aencode_queries() when implemented for the defaults + queries = self._encoder.encode_queries(queries) + results = [ + await self._aquery_collection(q, global_metadata_filter) for q in queries + ] + results = self._reranker.rerank(results) + + return [ + QueryResult( + query=r.query, + documents=[ + DocumentWithScore( + **d.dict(exclude={"values", "sparse_values", "document_id"}) + ) + for d in r.documents + ], + ) + for r in results + ] + + def upsert( + self, + documents: List[Document], + namespace: str = "", + batch_size: int = 200, + show_progress_bar: bool = False, + ): + """ + Add documents into the Qdrant collection. + If a document with the same id already exists in the collection, it will be overwritten with the new document. + Otherwise, a new document will be inserted. + + This operation includes several steps: + 1. Split the documents into smaller chunks. + 2. Encode the chunks to vectors. + 3. Delete any existing chunks belonging to the same documents. + 4. Upsert the chunks to the collection. + + Args: + documents: A list of documents to upsert. + namespace: This argument is not used by Qdrant. 
+ batch_size: The number of chunks (multiple piecies of text per document) to upsert in each batch. + Defaults to 100. + show_progress_bar: Whether to show a progress bar while upserting the documents. + + + Example: + >>> from canopy.knowledge_base.knowledge_base import QdrantKnowledgeBase + >>> kb = QdrantKnowledgeBase(collection_name="my_collection") + >>> documents = [Document(id="doc1", + text="This is a document", + source="my_source", + metadata={"website": "wiki"}), + Document(id="doc2", + text="This is another document", + source="my_source", + metadata={"website": "wiki"})] + >>> kb.upsert(documents) + """ # noqa: E501 + for doc in documents: + metadata_keys = set(doc.metadata.keys()) + forbidden_keys = metadata_keys.intersection(RESERVED_METADATA_KEYS) + if forbidden_keys: + raise ValueError( + f"Document with id {doc.id} contains reserved metadata keys: " + f"{forbidden_keys}. Please remove them and try again." + ) + + # TODO: Use achunk_documents, encode_documents when implemented for the defaults + chunks = self._chunker.chunk_documents(documents) + encoded_chunks = self._encoder.encode_documents(chunks) + + self._upsert_collection(encoded_chunks, batch_size, show_progress_bar) + + @sync_fallback + async def aupsert( + self, + documents: List[Document], + namespace: str = "", + batch_size: int = 200, + show_progress_bar: bool = False, + ): + """ + Add documents into the Qdrant collection asynchronously. + If a document with the same id already exists in the collection, it will be overwritten with the new document. + Otherwise, a new document will be inserted. + + This operation includes several steps: + 1. Split the documents into smaller chunks. + 2. Encode the chunks to vectors. + 3. Delete any existing chunks belonging to the same documents. + 4. Upsert the chunks to the collection. + + Args: + documents: A list of documents to upsert. + namespace: This argument is not used by Qdrant. + batch_size: The number of chunks (multiple piecies of text per document) to upsert in each batch. + Defaults to 100. + show_progress_bar: Whether to show a progress bar while upserting the documents. + + + Example: + >>> from canopy.knowledge_base.knowledge_base import QdrantKnowledgeBase + >>> kb = QdrantKnowledgeBase(collection_name="my_collection") + >>> documents = [Document(id="doc1", + text="This is a document", + source="my_source", + metadata={"website": "wiki"}), + Document(id="doc2", + text="This is another document", + source="my_source", + metadata={"website": "wiki"})] + >>> await kb.aupsert(documents) + """ # noqa: E501 + for doc in documents: + metadata_keys = set(doc.metadata.keys()) + forbidden_keys = metadata_keys.intersection(RESERVED_METADATA_KEYS) + if forbidden_keys: + raise ValueError( + f"Document with id {doc.id} contains reserved metadata keys: " + f"{forbidden_keys}. Please remove them and try again." + ) + + chunks = self._chunker.chunk_documents(documents) + encoded_chunks = self._encoder.encode_documents(chunks) + + await self._aupsert_collection(encoded_chunks, batch_size, show_progress_bar) + + def delete(self, document_ids: List[str], namespace: str = "") -> None: + """ + Delete documents from the Qdrant collection. + Since each document is chunked into multiple chunks, this operation will delete all chunks belonging to the given document ids. + This operation does not raise an exception if the document does not exist. + + Args: + document_ids: A list of document ids to delete from the Qdrant collection. + namespace: This argument is not used by Qdrant. 
+ + Returns: + None + + Example: + >>> from canopy.knowledge_base.knowledge_base import QdrantKnowledgeBase + >>> kb = QdrantKnowledgeBase(collection_name="my_collection") + >>> kb.delete(document_ids=["doc1", "doc2"]) + """ # noqa: E501 + self._client.delete( + self.collection_name, + points_selector=models.Filter( + must=[ + models.FieldCondition( + key="document_id", match=models.MatchAny(any=document_ids) + ) + ] + ), + ) + + @sync_fallback + async def adelete(self, document_ids: List[str], namespace: str = "") -> None: + """ + Delete documents from the Qdrant collection asynchronously. + Since each document is chunked into multiple chunks, this operation will delete all chunks belonging to the given document ids. + This operation does not raise an exception if the document does not exist. + + Args: + document_ids: A list of document ids to delete from the Qdrant collection. + namespace: This argument is not used by Qdrant. + + Returns: + None + + Example: + >>> from canopy.knowledge_base.knowledge_base import QdrantKnowledgeBase + >>> kb = QdrantKnowledgeBase(collection_name="my_collection") + >>> await kb.adelete(document_ids=["doc1", "doc2"]) + """ # noqa: E501 + # @sync_fallback will call the sync method if the async client is None + self._async_client and await self._async_client.delete( + self.collection_name, + points_selector=models.Filter( + must=[ + models.FieldCondition( + key="document_id", match=models.MatchAny(any=document_ids) + ) + ] + ), + ) + + def create_canopy_collection( + self, + dimension: Optional[int] = None, + indexed_keyword_fields: List[str] = ["document_id"], + distance: Any = "Cosine", + vectors_on_disk: Optional[bool] = None, + **kwargs, + ): + """ + Creates a collection with the appropriate config that will be used by the QdrantKnowledgeBase. + This is a one time set-up operation that only needs to be done once for every new Canopy service. + + Since Canopy defines its own data format with some named vectors configurations, + you can not use a pre-existing Qdrant collection with Canopy's QdrantKnowledgeBase. + + Note: Canopy will add the prefix 'canopy--' to the collection name. + You can retrieve the full collection name using `knowledge_base.collection_name`. + + Args: + dimension: The dimension of the dense vectors to be used. + If `dimension` isn't explicitly provided, + Canopy would try to infer the embedding's dimension based on the configured `Encoder` + indexed_keyword_fields: List of metadata fields to create Qdrant keyword payload index for. + Defaults to ["document_id"]. + distance: Distance function to use for the vectors. + Defaults to "Cosine". + Reference: https://qdrant.tech/documentation/concepts/search/#metrics + vectors_on_disk: Whethers to store vectors on disk. Defaults to None. + **kwargs: Additional arguments to pass to the `QdrantClient#create_collection()` method. + Reference: https://qdrant.tech/documentation/concepts/collections/#create-a-collection + + """ # noqa: E501 + if dimension is None: + try: + encoder_dimension = self._encoder.dimension + if encoder_dimension is None: + raise RuntimeError( + f"The selected encoder {self._encoder.__class__.__name__} does " + f"not support inferring the vectors' dimensionality." + ) + dimension = encoder_dimension + except Exception as e: + raise RuntimeError( + f"Canopy has failed to infer vectors' dimensionality using the " + f"selected encoder: {self._encoder.__class__.__name__}. 
You can " + f"provide the dimension manually, try using a different encoder, or" + f" fix the underlying error:\n{e}" + ) from e + + try: + self._client.get_collection(self.collection_name) + + raise RuntimeError( + f"Collection {self.collection_name} already exists!" + "To delete it call `knowledge_base.delete_canopy_collection()`. " + ) + + except (UnexpectedResponse, RpcError, ValueError): + self._client.create_collection( + collection_name=self.collection_name, + vectors_config={ + DENSE_VECTOR_NAME: models.VectorParams( + size=dimension, distance=distance, on_disk=vectors_on_disk + ) + }, + sparse_vectors_config={ + SPARSE_VECTOR_NAME: models.SparseVectorParams( + index=models.SparseIndexParams( + on_disk=vectors_on_disk, + ) + ) + }, + **kwargs, + ) + + for field in indexed_keyword_fields: + self._client.create_payload_index( + self.collection_name, field_name=field, field_schema="keyword" + ) + + def list_canopy_collections(self) -> List[str]: + collections = [ + collection.name + for collection in self._client.get_collections().collections + if collection.name.startswith(COLLECTION_NAME_PREFIX) + ] + return collections + + def delete_canopy_collection(self): + successful = self._client.delete_collection(self.collection_name) + + if not successful: + raise RuntimeError(f"Failed to delete collection {self.collection_name}") + + @property + def collection_name(self) -> str: + """ + The name of the collection the knowledge base is connected to. + """ + return self._collection_name + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QdrantKnowledgeBase": + """ + Create a QdrantKnowledgeBase object from a configuration dictionary. + + Args: + config: A dictionary containing the configuration for the Qdrant knowledge base. + Returns: + A QdrantKnowledgeBase object. 
+ """ # noqa: E501 + + config = deepcopy(config) + config["params"] = config.get("params", {}) + # TODO: Add support for collection creation config for use in the CLI + kb = cls._from_config(config) + return kb + + @staticmethod + def _get_full_collection_name(collection_name: str) -> str: + if collection_name.startswith(COLLECTION_NAME_PREFIX): + return collection_name + else: + return COLLECTION_NAME_PREFIX + collection_name + + def _query_collection( + self, query: KBQuery, global_metadata_filter: Optional[dict] + ) -> KBQueryResult: + metadata_filter = deepcopy(query.metadata_filter) + if global_metadata_filter is not None: + if metadata_filter is None: + metadata_filter = {} + metadata_filter.update(global_metadata_filter) + top_k = query.top_k if query.top_k else self._default_top_k + + query_params = deepcopy(query.query_params) + + query_vector = QdrantConverter.kb_query_to_search_vector(query) + + results = self._client.search( + self.collection_name, + query_vector=query_vector, + limit=top_k, + query_filter=metadata_filter, + with_payload=True, + **query_params, + ) + + documents: List[KBDocChunkWithScore] = [] + + for result in results: + documents.append(QdrantConverter.scored_point_to_scored_doc(result)) + return KBQueryResult(query=query.text, documents=documents) + + async def _aquery_collection( + self, query: KBQuery, global_metadata_filter: Optional[dict] + ) -> KBQueryResult: + metadata_filter = deepcopy(query.metadata_filter) + if global_metadata_filter is not None: + if metadata_filter is None: + metadata_filter = {} + metadata_filter.update(global_metadata_filter) + top_k = query.top_k if query.top_k else self._default_top_k + + query_params = deepcopy(query.query_params) + + # Use dense vector if available, otherwise use sparse vector + query_vector = QdrantConverter.kb_query_to_search_vector(query) + + # @sync_fallback will call the sync method if the async client is None + results = ( + await self._async_client.search( + self.collection_name, + query_vector=query_vector, + limit=top_k, + query_filter=metadata_filter, + with_payload=True, + **query_params, + ) + if self._async_client + else [] + ) + documents: List[KBDocChunkWithScore] = [] + for result in results: + documents.append(QdrantConverter.scored_point_to_scored_doc(result)) + return KBQueryResult(query=query.text, documents=documents) + + def _upsert_collection( + self, + encoded_chunks: List[KBEncodedDocChunk], + batch_size: int, + show_progress_bar: bool, + ) -> None: + batched_documents = batched(encoded_chunks, batch_size) + with tqdm( + total=len(encoded_chunks), disable=not show_progress_bar + ) as progress_bar: + for document_batch in batched_documents: + batch = QdrantConverter.encoded_docs_to_points( + document_batch, + ) + + self._client.upsert( + collection_name=self.collection_name, + points=batch, + ) + + progress_bar.update(batch_size) + + async def _aupsert_collection( + self, + encoded_chunks: List[KBEncodedDocChunk], + batch_size: int, + show_progress_bar: bool, + ) -> None: + batched_documents = batched(encoded_chunks, batch_size) + with tqdm( + total=len(encoded_chunks), disable=not show_progress_bar + ) as progress_bar: + for document_batch in batched_documents: + batch = QdrantConverter.encoded_docs_to_points( + document_batch, + ) + + # @sync_fallback will call the sync method if the async client is None + self._async_client and await self._async_client.upsert( + collection_name=self.collection_name, + points=batch, + ) + + progress_bar.update(batch_size) + + async def close(self) 
-> None: + self._client.close() + if self._async_client: + await self._async_client.close() diff --git a/src/canopy/knowledge_base/qdrant/utils.py b/src/canopy/knowledge_base/qdrant/utils.py new file mode 100644 index 00000000..f96bc76a --- /dev/null +++ b/src/canopy/knowledge_base/qdrant/utils.py @@ -0,0 +1,100 @@ +import asyncio +import functools +from itertools import islice +from typing import Any, Callable, Optional + +import logging + +try: + from qdrant_client import AsyncQdrantClient, QdrantClient + from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal +except ImportError: + pass +logger = logging.getLogger(__name__) + + +def sync_fallback(method: Callable) -> Callable: + @functools.wraps(method) + async def wrapper(self, *args, **kwargs): + if self._async_client is None or isinstance( + self._async_client._client, AsyncQdrantLocal + ): + sync_method_name = method.__name__[1:] + + logger.warning( + f"{method.__name__}() cannot be used for QdrantLocal. " + f"Falling back to {sync_method_name}()" + ) + loop = asyncio.get_event_loop() + + call = functools.partial(getattr(self, sync_method_name), *args, **kwargs) + return await loop.run_in_executor(None, call) + else: + return await method(self, *args, **kwargs) + + return wrapper + + +def generate_clients( + location: Optional[str] = None, + url: Optional[str] = None, + port: Optional[int] = 6333, + grpc_port: int = 6334, + prefer_grpc: bool = False, + https: Optional[bool] = None, + api_key: Optional[str] = None, + prefix: Optional[str] = None, + timeout: Optional[int] = None, + host: Optional[str] = None, + path: Optional[str] = None, + force_disable_check_same_thread: bool = False, + **kwargs: Any, +): + sync_client = QdrantClient( + location=location, + url=url, + port=port, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + https=https, + api_key=api_key, + prefix=prefix, + timeout=timeout, + host=host, + path=path, + force_disable_check_same_thread=force_disable_check_same_thread, + **kwargs, + ) + + if location == ":memory:" or path is not None: + # In-memory Qdrant doesn't interoperate with Sync and Async clients + # We fallback to sync operations in this case using @utils.sync_fallback + async_client = None + else: + async_client = AsyncQdrantClient( + location=location, + url=url, + port=port, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + https=https, + api_key=api_key, + prefix=prefix, + timeout=timeout, + host=host, + path=path, + force_disable_check_same_thread=force_disable_check_same_thread, + **kwargs, + ) + + return sync_client, async_client + + +def batched(iterable, n): + """ + Batch elements of an iterable into fixed-length chunks or blocks. 
+ Based on itertools.batched() from Python 3.12 + """ + it = iter(iterable) + while batch := tuple(islice(it, n)): + yield batch diff --git a/tests/system/knowledge_base/qdrant/__init__.py b/tests/system/knowledge_base/qdrant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/system/knowledge_base/qdrant/common.py b/tests/system/knowledge_base/qdrant/common.py new file mode 100644 index 00000000..a0621fb5 --- /dev/null +++ b/tests/system/knowledge_base/qdrant/common.py @@ -0,0 +1,80 @@ +import numpy as np +import requests +from canopy.knowledge_base.qdrant.constants import DENSE_VECTOR_NAME +from canopy.knowledge_base.qdrant.converter import QdrantConverter +from canopy.knowledge_base.qdrant.qdrant_knowledge_base import QdrantKnowledgeBase + +import logging +from typing import List + +logger = logging.getLogger(__name__) + + +def total_vectors_in_collection(knowledge_base: QdrantKnowledgeBase): + return knowledge_base._client.count(knowledge_base.collection_name).count + + +def assert_chunks_in_collection(knowledge_base: QdrantKnowledgeBase, encoded_chunks): + ids = [QdrantConverter.convert_id(c.id) for c in encoded_chunks] + fetch_result = knowledge_base._client.retrieve( + knowledge_base.collection_name, ids=ids, with_payload=True, with_vectors=True + ) + points = {p.id: p for p in fetch_result} + for chunk in encoded_chunks: + id = QdrantConverter.convert_id(chunk.id) + assert id in points + point = points[id] + assert np.allclose( + point.vector[DENSE_VECTOR_NAME], + np.array(chunk.values, dtype=np.float32), + atol=1e-8, + ) + + assert point.payload["text"] == chunk.text + assert point.payload["document_id"] == chunk.document_id + assert point.payload["source"] == chunk.source + for key, value in chunk.metadata.items(): + assert point.payload[key] == value + + +def assert_ids_in_collection(knowledge_base, ids): + fetch_result = knowledge_base._client.retrieve( + knowledge_base.collection_name, + ids=ids, + ) + assert len(fetch_result) == len( + ids + ), f"Expected {len(ids)} ids, got {len(fetch_result)}" + + +def assert_num_points_in_collection(knowledge_base, num_vectors): + points_in_index = total_vectors_in_collection(knowledge_base) + assert ( + points_in_index == num_vectors + ), f"Expected {num_vectors} vectors in index, got {points_in_index}" + + +def assert_ids_not_in_collection(knowledge_base, ids): + fetch_result = knowledge_base._client.retrieve( + knowledge_base.collection_name, + ids=ids, + ) + assert len(fetch_result) == 0, f"Found {len(fetch_result)} unexpected ids" + + +def qdrant_server_running() -> bool: + """Check if Qdrant server is running.""" + + try: + response = requests.get("http://localhost:6333", timeout=10.0) + response_json = response.json() + return response_json.get("title") == "qdrant - vector search engine" + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + return False + + +def qdrant_locations() -> List[str]: + if not qdrant_server_running(): + logger.warning("Running Qdrant tests in memory mode only.") + return [":memory:"] + return ["http://localhost:6333", ":memory:"] diff --git a/tests/system/knowledge_base/qdrant/conftest.py b/tests/system/knowledge_base/qdrant/conftest.py new file mode 100644 index 00000000..d446768e --- /dev/null +++ b/tests/system/knowledge_base/qdrant/conftest.py @@ -0,0 +1,118 @@ +import pytest +from canopy.knowledge_base.qdrant.constants import COLLECTION_NAME_PREFIX +from canopy.knowledge_base.qdrant.qdrant_knowledge_base import QdrantKnowledgeBase +from 
canopy.models.data_models import Document +from tests.system.knowledge_base.qdrant.common import qdrant_locations +from tests.system.knowledge_base.test_knowledge_base import _generate_text +from tests.unit.stubs.stub_chunker import StubChunker +from tests.unit.stubs.stub_dense_encoder import StubDenseEncoder +from tests.unit.stubs.stub_record_encoder import StubRecordEncoder +from tests.util import create_system_tests_index_name + + +@pytest.fixture(scope="module") +def collection_name(testrun_uid): + return create_system_tests_index_name(testrun_uid) + + +@pytest.fixture(scope="module") +def collection_full_name(collection_name): + return COLLECTION_NAME_PREFIX + collection_name + + +@pytest.fixture(scope="module") +def chunker(): + return StubChunker(num_chunks_per_doc=2) + + +@pytest.fixture(scope="module") +def encoder(): + return StubRecordEncoder(StubDenseEncoder()) + + +@pytest.fixture(scope="module", autouse=True, params=qdrant_locations()) +def knowledge_base(collection_name, chunker, encoder, request): + kb = QdrantKnowledgeBase( + collection_name=collection_name, + record_encoder=encoder, + chunker=chunker, + location=request.param, + ) + kb.create_canopy_collection() + + return kb + + +@pytest.fixture +def documents_large(): + return [ + Document( + id=f"doc_{i}_large", + text=f"Sample document {i}", + metadata={"my-key-large": f"value-{i}"}, + ) + for i in range(1000) + ] + + +@pytest.fixture +def encoded_chunks_large(documents_large, chunker, encoder): + chunks = chunker.chunk_documents(documents_large) + return encoder.encode_documents(chunks) + + +@pytest.fixture +def documents_with_datetime_metadata(): + return [ + Document( + id="doc_1_metadata", + text="document with datetime metadata", + source="source_1", + metadata={ + "datetime": "2021-01-01T00:00:00Z", + "datetime_other_format": "January 1, 2021 00:00:00", + "datetime_other_format_2": "2210.03945", + }, + ), + Document(id="2021-01-01T00:00:00Z", text="id is datetime", source="source_1"), + ] + + +@pytest.fixture +def datetime_metadata_encoded_chunks( + documents_with_datetime_metadata, chunker, encoder +): + chunks = chunker.chunk_documents(documents_with_datetime_metadata) + return encoder.encode_documents(chunks) + + +@pytest.fixture +def encoded_chunks(documents, chunker, encoder): + chunks = chunker.chunk_documents(documents) + return encoder.encode_documents(chunks) + + +@pytest.fixture(scope="module", autouse=True) +def teardown_knowledge_base(collection_full_name, knowledge_base): + yield + + knowledge_base._client.delete_collection(collection_full_name) + knowledge_base.close() + + +@pytest.fixture(scope="module") +def random_texts(): + return [_generate_text(10) for _ in range(5)] + + +@pytest.fixture +def documents(random_texts): + return [ + Document( + id=f"doc_{i}", + text=random_texts[i], + source=f"source_{i}", + metadata={"my-key": f"value-{i}"}, + ) + for i in range(5) + ] diff --git a/tests/system/knowledge_base/qdrant/test_async_qdrant_knowledge_base.py b/tests/system/knowledge_base/qdrant/test_async_qdrant_knowledge_base.py new file mode 100644 index 00000000..0aaf7f96 --- /dev/null +++ b/tests/system/knowledge_base/qdrant/test_async_qdrant_knowledge_base.py @@ -0,0 +1,213 @@ +import random + +import pytest + +from canopy.knowledge_base.knowledge_base import KnowledgeBase +from canopy.knowledge_base.qdrant.qdrant_knowledge_base import ( + QdrantConverter, + QdrantKnowledgeBase, +) +from tests.system.knowledge_base.qdrant.common import ( + assert_chunks_in_collection, + assert_ids_in_collection, 
+    assert_ids_not_in_collection,
+    assert_num_points_in_collection,
+    total_vectors_in_collection,
+)
+from canopy.knowledge_base.models import DocumentWithScore
+from canopy.models.data_models import Query
+from tests.unit import random_words
+from tests.unit.stubs.stub_chunker import StubChunker
+
+qdrant_client = pytest.importorskip(
+    "qdrant_client", reason="'qdrant_client' is not installed"
+)
+
+
+async def execute_and_assert_queries(
+    knowledge_base: QdrantKnowledgeBase, chunks_to_query
+):
+    queries = [Query(text=chunk.text, top_k=2) for chunk in chunks_to_query]
+
+    query_results = await knowledge_base.aquery(queries)
+
+    assert len(query_results) == len(queries)
+
+    for i, q_res in enumerate(query_results):
+        assert queries[i].text == q_res.query
+        assert len(q_res.documents) == 2
+        q_res.documents[0].score = round(q_res.documents[0].score, 1)
+        assert q_res.documents[0] == DocumentWithScore(
+            id=chunks_to_query[i].id,
+            text=chunks_to_query[i].text,
+            metadata=chunks_to_query[i].metadata,
+            source=chunks_to_query[i].source,
+            score=1.0,
+        ), (
+            f"query {i} - expected: {chunks_to_query[i]}, "
+            f"actual: {q_res.documents[0]}"
+        )
+
+
+async def assert_query_metadata_filter(
+    knowledge_base: KnowledgeBase,
+    metadata_filter: dict,
+    num_vectors_expected: int,
+    top_k: int = 100,
+):
+    assert (
+        top_k > num_vectors_expected
+    ), "the test might return false positive if top_k is not > num_vectors_expected"
+
+    query = Query(text="test", top_k=top_k, metadata_filter=metadata_filter)
+    query_results = await knowledge_base.aquery([query])
+    assert len(query_results) == 1
+    assert len(query_results[0].documents) == num_vectors_expected
+
+
+def _generate_text(num_words: int):
+    return " ".join(random.choices(random_words, k=num_words))
+
+
+@pytest.mark.asyncio
+async def test_upsert_happy_path(
+    knowledge_base: QdrantKnowledgeBase, documents, encoded_chunks
+):
+    await knowledge_base.aupsert(documents)
+    assert_num_points_in_collection(knowledge_base, len(encoded_chunks))
+    assert_chunks_in_collection(knowledge_base, encoded_chunks)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("key", ["document_id", "text", "source"])
+async def test_upsert_forbidden_metadata(knowledge_base, documents, key):
+    doc = random.choice(documents)
+    doc.metadata[key] = "bla"
+
+    with pytest.raises(ValueError) as e:
+        await knowledge_base.aupsert(documents)
+
+    assert "reserved metadata keys" in str(e.value)
+    assert doc.id in str(e.value)
+    assert key in str(e.value)
+
+
+@pytest.mark.asyncio
+async def test_query(knowledge_base, encoded_chunks):
+    await execute_and_assert_queries(knowledge_base, encoded_chunks)
+
+
+@pytest.mark.asyncio
+async def test_query_with_metadata_filter(knowledge_base):
+    if knowledge_base._async_client is None or not isinstance(
+        knowledge_base._async_client._client,
+        qdrant_client.async_qdrant_remote.AsyncQdrantRemote,  # noqa: E501
+    ):
+        pytest.skip(
+            "Dict filter is not supported for QdrantLocal. "
+            "Use qdrant_client.models.Filter instead"
+        )
+
+    await assert_query_metadata_filter(
+        knowledge_base,
+        {
+            "must": [
+                {"key": "my-key", "match": {"value": "value-1"}},
+            ]
+        },
+        2,
+    )
+
+
+@pytest.mark.asyncio
+async def test_delete_documents(knowledge_base: QdrantKnowledgeBase, encoded_chunks):
+    chunk_ids = [QdrantConverter.convert_id(chunk.id) for chunk in encoded_chunks[-4:]]
+    doc_ids = set(doc.document_id for doc in encoded_chunks[-4:])
+
+    assert_ids_in_collection(knowledge_base, chunk_ids)
+
+    before_vector_cnt = total_vectors_in_collection(knowledge_base)
+
+    await knowledge_base.adelete(document_ids=list(doc_ids))
+
+    assert_num_points_in_collection(knowledge_base, before_vector_cnt - len(chunk_ids))
+    assert_ids_not_in_collection(knowledge_base, chunk_ids)
+
+
+@pytest.mark.asyncio
+async def test_update_documents(encoder, documents, encoded_chunks, knowledge_base):
+    # chunker/kb that produces fewer chunks per doc
+    chunker = StubChunker(num_chunks_per_doc=1)
+    docs = documents[:2]
+    doc_ids = [doc.id for doc in docs]
+    chunk_ids = [
+        QdrantConverter.convert_id(chunk.id)
+        for chunk in encoded_chunks
+        if chunk.document_id in doc_ids
+    ]
+
+    assert_ids_in_collection(knowledge_base, chunk_ids)
+
+    docs[0].metadata["new_key"] = "new_value"
+    await knowledge_base.aupsert(docs)
+
+    updated_chunks = encoder.encode_documents(chunker.chunk_documents(docs))
+    expected_chunks = [QdrantConverter.convert_id(chunk.id) for chunk in updated_chunks]
+    assert_chunks_in_collection(knowledge_base, updated_chunks)
+
+    unexpected_chunks = [
+        c_id
+        for c_id in chunk_ids
+        if c_id not in expected_chunks
+    ]
+    assert len(unexpected_chunks) > 0, "bug in the test itself"
+
+    assert_ids_not_in_collection(knowledge_base, unexpected_chunks)
+
+
+@pytest.mark.asyncio
+async def test_upsert_large_list_happy_path(
+    knowledge_base, documents_large, encoded_chunks_large
+):
+    await knowledge_base.aupsert(documents_large)
+
+    chunks_for_validation = encoded_chunks_large[:10] + encoded_chunks_large[-10:]
+    assert_ids_in_collection(
+        knowledge_base,
+        [QdrantConverter.convert_id(chunk.id) for chunk in chunks_for_validation],
+    )
+
+
+@pytest.mark.asyncio
+async def test_delete_large_df_happy_path(
+    knowledge_base, documents_large, encoded_chunks_large
+):
+    await knowledge_base.adelete([doc.id for doc in documents_large])
+
+    chunks_for_validation = encoded_chunks_large[:10] + encoded_chunks_large[-10:]
+    assert_ids_not_in_collection(
+        knowledge_base,
+        [QdrantConverter.convert_id(chunk.id) for chunk in chunks_for_validation],
+    )
+
+
+@pytest.mark.asyncio
+async def test_upsert_documents_with_datetime_metadata(
+    knowledge_base, documents_with_datetime_metadata, datetime_metadata_encoded_chunks
+):
+    await knowledge_base.aupsert(documents_with_datetime_metadata)
+
+    assert_ids_in_collection(
+        knowledge_base,
+        [
+            QdrantConverter.convert_id(chunk.id)
+            for chunk in datetime_metadata_encoded_chunks
+        ],
+    )
+
+
+@pytest.mark.asyncio
+async def test_query_edge_case_documents(
+    knowledge_base, datetime_metadata_encoded_chunks
+):
+    await execute_and_assert_queries(knowledge_base, datetime_metadata_encoded_chunks)
diff --git a/tests/system/knowledge_base/qdrant/test_config.yml b/tests/system/knowledge_base/qdrant/test_config.yml
new file mode 100644
index 00000000..fb1a1ac1
--- /dev/null
+++ b/tests/system/knowledge_base/qdrant/test_config.yml
@@ -0,0 +1,8 @@
+# ===========================================================
+#         QdrantKnowledgeBase test configuration file
+# ===========================================================
+
+knowledge_base:
+  params:
+    collection_name: test-config-collection
+    default_top_k: 10
\ No newline at end of file
diff --git a/tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py b/tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py
new file mode 100644
index 00000000..7d9c08fe
--- /dev/null
+++ b/tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py
@@ -0,0 +1,324 @@
+import random
+from copy import copy
+from pathlib import Path
+
+import pytest
+from canopy.knowledge_base.chunker.base import Chunker
+
+from canopy.knowledge_base.knowledge_base import KnowledgeBase
+from canopy.knowledge_base.qdrant.qdrant_knowledge_base import (
+    DENSE_VECTOR_NAME,
+    QdrantConverter,
+    QdrantKnowledgeBase,
+)
+from canopy.knowledge_base.qdrant.qdrant_knowledge_base import COLLECTION_NAME_PREFIX
+from canopy.knowledge_base.models import DocumentWithScore
+from canopy.knowledge_base.record_encoder.base import RecordEncoder
+from canopy.knowledge_base.reranker.reranker import Reranker
+from canopy.models.data_models import Query
+
+from canopy_cli.cli import _load_kb_config
+from tests.system.knowledge_base.qdrant.common import (
+    assert_chunks_in_collection,
+    assert_ids_in_collection,
+    assert_ids_not_in_collection,
+    assert_num_points_in_collection,
+    total_vectors_in_collection,
+)
+from tests.unit.stubs.stub_chunker import StubChunker
+from tests.unit.stubs.stub_dense_encoder import StubDenseEncoder
+from tests.unit.stubs.stub_record_encoder import StubRecordEncoder
+
+qdrant_client = pytest.importorskip(
+    "qdrant_client", reason="'qdrant_client' is not installed"
+)
+
+
+def execute_and_assert_queries(knowledge_base: QdrantKnowledgeBase, chunks_to_query):
+    queries = [Query(text=chunk.text, top_k=2) for chunk in chunks_to_query]
+
+    query_results = knowledge_base.query(queries)
+
+    assert len(query_results) == len(queries)
+
+    for i, q_res in enumerate(query_results):
+        assert queries[i].text == q_res.query
+        assert len(q_res.documents) == 2
+        q_res.documents[0].score = round(q_res.documents[0].score, 1)
+        assert q_res.documents[0] == DocumentWithScore(
+            id=chunks_to_query[i].id,
+            text=chunks_to_query[i].text,
+            metadata=chunks_to_query[i].metadata,
+            source=chunks_to_query[i].source,
+            score=1.0,
+        ), (
+            f"query {i} - expected: {chunks_to_query[i]}, "
+            f"actual: {q_res.documents[0]}"
+        )
+
+
+def assert_query_metadata_filter(
+    knowledge_base: KnowledgeBase,
+    metadata_filter: dict,
+    num_vectors_expected: int,
+    top_k: int = 100,
+):
+    assert (
+        top_k > num_vectors_expected
+    ), "the test might return false positive if top_k is not > num_vectors_expected"
+    query = Query(text="test", top_k=top_k, metadata_filter=metadata_filter)
+    query_results = knowledge_base.query([query])
+    assert len(query_results) == 1
+    assert len(query_results[0].documents) == num_vectors_expected
+
+
+def test_create_collection(collection_full_name, knowledge_base: QdrantKnowledgeBase):
+    assert knowledge_base.collection_name == collection_full_name
+    collection_info = knowledge_base._client.get_collection(collection_full_name)
+    assert (
+        collection_info.config.params.vectors[DENSE_VECTOR_NAME].size
+        == knowledge_base._encoder.dimension
+    )
+
+
+def test_list_collections(collection_full_name, knowledge_base: QdrantKnowledgeBase):
+    collections_list = knowledge_base.list_canopy_collections()
+
+    assert len(collections_list) > 0
+    for item in collections_list:
+        assert COLLECTION_NAME_PREFIX in item
+
+    assert collection_full_name in collections_list
+
+
+def test_is_verify_connection_happy_path(knowledge_base):
+    knowledge_base.verify_index_connection()
+
+
+def test_init_with_context_engine_prefix(collection_full_name, chunker, encoder):
+    kb = QdrantKnowledgeBase(
+        collection_name=collection_full_name,
+        record_encoder=encoder,
+        chunker=chunker,
+    )
+    assert kb.collection_name == collection_full_name
+
+
+def test_upsert_happy_path(
+    knowledge_base: QdrantKnowledgeBase, documents, encoded_chunks
+):
+    knowledge_base.upsert(documents)
+
+    assert_num_points_in_collection(knowledge_base, len(encoded_chunks))
+    assert_chunks_in_collection(knowledge_base, encoded_chunks)
+
+
+@pytest.mark.parametrize("key", ["document_id", "text", "source", "chunk_id"])
+def test_upsert_forbidden_metadata(knowledge_base, documents, key):
+    doc = random.choice(documents)
+    doc.metadata[key] = "bla"
+
+    with pytest.raises(ValueError) as e:
+        knowledge_base.upsert(documents)
+
+    assert "reserved metadata keys" in str(e.value)
+    assert doc.id in str(e.value)
+    assert key in str(e.value)
+
+
+def test_query(knowledge_base, encoded_chunks):
+    execute_and_assert_queries(knowledge_base, encoded_chunks)
+
+
+def test_query_with_metadata_filter(knowledge_base):
+    if not isinstance(
+        knowledge_base._client._client, qdrant_client.qdrant_remote.QdrantRemote
+    ):  # noqa: E501
+        pytest.skip(
+            "Dict filter is not supported for QdrantLocal. "
+            "Use qdrant_client.models.Filter instead"
+        )
+
+    assert_query_metadata_filter(
+        knowledge_base,
+        {
+            "must": [
+                {"key": "my-key", "match": {"value": "value-1"}},
+            ]
+        },
+        2,
+    )
+
+
+def test_delete_documents(knowledge_base: QdrantKnowledgeBase, encoded_chunks):
+    chunk_ids = [QdrantConverter.convert_id(chunk.id) for chunk in encoded_chunks[-4:]]
+    doc_ids = set(doc.document_id for doc in encoded_chunks[-4:])
+
+    assert_ids_in_collection(knowledge_base, chunk_ids)
+
+    before_vector_cnt = total_vectors_in_collection(knowledge_base)
+
+    knowledge_base.delete(document_ids=list(doc_ids))
+
+    assert_num_points_in_collection(knowledge_base, before_vector_cnt - len(chunk_ids))
+    assert_ids_not_in_collection(knowledge_base, chunk_ids)
+
+
+def test_update_documents(encoder, documents, encoded_chunks, knowledge_base):
+    # chunker/kb that produces fewer chunks per doc
+    chunker = StubChunker(num_chunks_per_doc=1)
+
+    docs = documents[:2]
+    doc_ids = [doc.id for doc in docs]
+    chunk_ids = [
+        QdrantConverter.convert_id(chunk.id)
+        for chunk in encoded_chunks
+        if chunk.document_id in doc_ids
+    ]
+
+    assert_ids_in_collection(knowledge_base, chunk_ids)
+
+    docs[0].metadata["new_key"] = "new_value"
+    knowledge_base.upsert(docs)
+
+    updated_chunks = encoder.encode_documents(chunker.chunk_documents(docs))
+    expected_chunks = [QdrantConverter.convert_id(chunk.id) for chunk in updated_chunks]
+    assert_chunks_in_collection(knowledge_base, updated_chunks)
+
+    unexpected_chunks = [
+        c_id
+        for c_id in chunk_ids
+        if c_id not in expected_chunks
+    ]
+    assert len(unexpected_chunks) > 0, "bug in the test itself"
+
+    assert_ids_not_in_collection(knowledge_base, unexpected_chunks)
+
+
+def test_upsert_large_list_happy_path(
+    knowledge_base, documents_large, encoded_chunks_large
+):
+    knowledge_base.upsert(documents_large)
+
+    chunks_for_validation = encoded_chunks_large[:10] + encoded_chunks_large[-10:]
+    assert_ids_in_collection(
+        knowledge_base,
+        [QdrantConverter.convert_id(chunk.id) for chunk in chunks_for_validation],
+    )
+
+
+def test_delete_large_df_happy_path(
+    knowledge_base, documents_large, encoded_chunks_large
+):
+    knowledge_base.delete([doc.id for doc in documents_large])
+
+    chunks_for_validation = encoded_chunks_large[:10] + encoded_chunks_large[-10:]
+    assert_ids_not_in_collection(
+        knowledge_base,
+        [QdrantConverter.convert_id(chunk.id) for chunk in chunks_for_validation],
+    )
+
+
+def test_upsert_documents_with_datetime_metadata(
+    knowledge_base, documents_with_datetime_metadata, datetime_metadata_encoded_chunks
+):
+    knowledge_base.upsert(documents_with_datetime_metadata)
+
+    assert_ids_in_collection(
+        knowledge_base,
+        [
+            QdrantConverter.convert_id(chunk.id)
+            for chunk in datetime_metadata_encoded_chunks
+        ],
+    )
+
+
+def test_query_edge_case_documents(knowledge_base, datetime_metadata_encoded_chunks):
+    execute_and_assert_queries(knowledge_base, datetime_metadata_encoded_chunks)
+
+
+def test_create_existing_collection(collection_full_name, knowledge_base):
+    with pytest.raises(RuntimeError) as e:
+        knowledge_base.create_canopy_collection()
+
+    assert f"Collection {collection_full_name} already exists" in str(e.value)
+
+
+def test_kb_non_existing_collection(knowledge_base):
+    kb = copy(knowledge_base)
+
+    kb._collection_name = f"{COLLECTION_NAME_PREFIX}non-existing-collection"
+
+    with pytest.raises(RuntimeError) as e:
+        kb.verify_index_connection()
+    expected_msg = (
+        f"Collection {COLLECTION_NAME_PREFIX}non-existing-collection does not exist!"
+    )
+    assert expected_msg in str(e.value)
+
+
+def test_init_defaults(collection_name, collection_full_name):
+    new_kb = QdrantKnowledgeBase(collection_name)
+    assert isinstance(new_kb._client, qdrant_client.QdrantClient)
+    assert new_kb.collection_name == collection_full_name
+    assert isinstance(new_kb._chunker, Chunker)
+    assert isinstance(
+        new_kb._chunker, QdrantKnowledgeBase._DEFAULT_COMPONENTS["chunker"]
+    )
+    assert isinstance(new_kb._encoder, RecordEncoder)
+    assert isinstance(
+        new_kb._encoder, QdrantKnowledgeBase._DEFAULT_COMPONENTS["record_encoder"]
+    )
+    assert isinstance(new_kb._reranker, Reranker)
+    assert isinstance(new_kb._reranker, KnowledgeBase._DEFAULT_COMPONENTS["reranker"])
+
+
+def test_init_defaults_with_override(knowledge_base, chunker):
+    collection_name = knowledge_base.collection_name
+    new_kb = QdrantKnowledgeBase(collection_name=collection_name, chunker=chunker)
+    assert isinstance(new_kb._client, qdrant_client.QdrantClient)
+    assert new_kb.collection_name == collection_name
+    assert isinstance(new_kb._chunker, Chunker)
+    assert isinstance(new_kb._chunker, StubChunker)
+    assert new_kb._chunker is chunker
+    assert isinstance(new_kb._encoder, RecordEncoder)
+    assert isinstance(
+        new_kb._encoder, KnowledgeBase._DEFAULT_COMPONENTS["record_encoder"]
+    )
+    assert isinstance(new_kb._reranker, Reranker)
+    assert isinstance(new_kb._reranker, KnowledgeBase._DEFAULT_COMPONENTS["reranker"])
+
+
+def test_init_raise_wrong_type(knowledge_base, chunker):
+    collection_name = knowledge_base.collection_name
+    with pytest.raises(TypeError) as e:
+        QdrantKnowledgeBase(
+            collection_name=collection_name,
+            record_encoder=chunker,
+        )
+
+    assert "record_encoder must be an instance of RecordEncoder" in str(e.value)
+
+
+def test_create_with_collection_encoder_dimension_none(collection_name, chunker):
+    encoder = StubRecordEncoder(StubDenseEncoder(dimension=3))
+    encoder._dense_encoder.dimension = None
+    with pytest.raises(RuntimeError) as e:
+        kb = QdrantKnowledgeBase(
+            collection_name=collection_name,
+            record_encoder=encoder,
+            chunker=chunker,
+        )
+        kb.create_canopy_collection()
+
+    assert "failed to infer" in str(e.value)
+    assert "dimension" in str(e.value)
+    assert f"{encoder.__class__.__name__} does not support" in str(e.value)
+
+
+def test_knowledge_base_from_config():
+    config_path = Path(__file__).with_name("test_config.yml")
+    kb_config = _load_kb_config(config_path)
+    kb = QdrantKnowledgeBase.from_config(kb_config)
+    assert kb.collection_name == COLLECTION_NAME_PREFIX + "test-config-collection"
+    assert kb._default_top_k == 10
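
> **_💡 NOTE:_** The `from_config` flow exercised by the last test above is the same entry point an application would use. The sketch below is illustrative and not part of this changeset: it assumes `canopy-sdk` is installed with the `qdrant` extra, that `OPENAI_API_KEY` is set for the default record encoder, and that a config file shaped like `test_config.yml` is available at the path shown.

```python
from pathlib import Path

from canopy.knowledge_base.qdrant.qdrant_knowledge_base import QdrantKnowledgeBase
from canopy.tokenizer import Tokenizer
from canopy_cli.cli import _load_kb_config

# The global tokenizer singleton must be initialized before creating
# any other canopy core object (see docs/library.md).
Tokenizer.initialize()

# Hypothetical path - point this at a config shaped like test_config.yml.
config_path = Path("tests/system/knowledge_base/qdrant/test_config.yml")
kb_config = _load_kb_config(config_path)

kb = QdrantKnowledgeBase.from_config(kb_config)
# The configured name comes back with the canopy collection prefix applied.
print(kb.collection_name)
```
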