From a8befc758e35b9193539e2a96e4a1d64134b463f Mon Sep 17 00:00:00 2001 From: Anush008 Date: Wed, 10 Jan 2024 13:18:31 +0530 Subject: [PATCH] chore: added QdrantKnowledgeBase.from_config() --- .../qdrant/qdrant_knowledge_base.py | 21 +++++++++++++++++-- .../knowledge_base/qdrant/test_config.yml | 9 ++++++++ .../qdrant/test_qdrant_knowledge_base.py | 10 +++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 tests/system/knowledge_base/qdrant/test_config.yml diff --git a/src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py b/src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py index 3e4a90ad..f6944895 100644 --- a/src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py +++ b/src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py @@ -175,8 +175,6 @@ def __init__( else: self._reranker = self._DEFAULT_COMPONENTS["reranker"]() - self._collection_params: Dict[str, Any] = {} - self._client, self._async_client = generate_clients( location=location, url=url, @@ -293,6 +291,7 @@ async def aquery( )] >>> results = await kb.aquery(queries) """ # noqa: E501 + # TODO: Use aencode_queries() when implemented for the defaults queries = self._encoder.encode_queries(queries) results = [ await self._aquery_collection(q, global_metadata_filter) for q in queries @@ -360,6 +359,7 @@ def upsert( f"{forbidden_keys}. Please remove them and try again." ) + # TODO: Use achunk_documents, encode_documents when implemented for the defaults chunks = self._chunker.chunk_documents(documents) encoded_chunks = self._encoder.encode_documents(chunks) @@ -625,6 +625,23 @@ def collection_name(self) -> str: """ return self._collection_name + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QdrantKnowledgeBase": + """ + Create a QdrantKnowledgeBase object from a configuration dictionary. + + Args: + config: A dictionary containing the configuration for the Qdrant knowledge base. + Returns: + A QdrantKnowledgeBase object. + """ # noqa: E501 + + config = deepcopy(config) + config["params"] = config.get("params", {}) + # TODO: Add support for collection creation config for use in the CLI + kb = cls._from_config(config) + return kb + @staticmethod def _get_full_collection_name(collection_name: str) -> str: if collection_name.startswith(COLLECTION_NAME_PREFIX): diff --git a/tests/system/knowledge_base/qdrant/test_config.yml b/tests/system/knowledge_base/qdrant/test_config.yml new file mode 100644 index 00000000..fb1a1ac1 --- /dev/null +++ b/tests/system/knowledge_base/qdrant/test_config.yml @@ -0,0 +1,9 @@ +# =========================================================== +# QdrantKnowledgeBase test configuration file +# =========================================================== + +knowledge_base: + params: + default_top_k: 5 + collection_name: test-config-collection + default_top_k: 10 \ No newline at end of file diff --git a/tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py b/tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py index 80d3ee22..0e02e5d9 100644 --- a/tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py +++ b/tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py @@ -1,5 +1,6 @@ import copy import random +from pathlib import Path import pytest from dotenv import load_dotenv @@ -19,6 +20,7 @@ from canopy.models.data_models import Query from qdrant_client.qdrant_remote import QdrantRemote +from canopy_cli.cli import _load_kb_config from tests.system.knowledge_base.qdrant.common import ( assert_chunks_in_collection, assert_ids_in_collection, @@ -311,3 +313,11 @@ def test_create_with_collection_encoder_dimension_none(collection_name, chunker) assert "failed to infer" in str(e.value) assert "dimension" in str(e.value) assert f"{encoder.__class__.__name__} does not support" in str(e.value) + + +def test_knowlege_base_from_config(): + config_path = Path(__file__).with_name("test_config.yml") + kb_config = _load_kb_config(config_path) + kb = QdrantKnowledgeBase.from_config(kb_config) + assert kb.collection_name == COLLECTION_NAME_PREFIX + "test-config-collection" + assert kb._default_top_k == 10