Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
chore: optional import qdrant_client
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Feb 23, 2024
1 parent c10416e commit d9f80e7
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 77 deletions.
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ sentencepiece = "^0.1.99"
pandas = "2.0.0"
pyarrow = "^14.0.1"
cohere = { version = ">=4.37", optional = true }
qdrant-client = "^1.7.2"
qdrant-client = {version = "^1.7.3", optional = true}


pinecone-text = "^0.8.0"
Expand Down Expand Up @@ -61,6 +61,7 @@ cohere = ["cohere"]
torch = ["torch", "sentence-transformers"]
transformers = ["transformers"]
grpc = ["grpcio", "grpc-gateway-protoc-gen-openapiv2", "googleapis-common-protos", "lz4", "protobuf"]
qdrant = ["qdrant-client"]


[tool.poetry.group.dev.dependencies]
Expand Down Expand Up @@ -97,7 +98,9 @@ module = [
'tokenizers.*',
'cohere.*',
'pinecone.grpc',
'huggingface_hub.utils'
'huggingface_hub.utils',
'qdrant_client.*',
'grpc.*'
]
ignore_missing_imports = true

Expand Down
4 changes: 3 additions & 1 deletion src/canopy/knowledge_base/qdrant/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
COLLECTION_NAME_PREFIX = "canopy--"
from canopy.knowledge_base.knowledge_base import INDEX_NAME_PREFIX

COLLECTION_NAME_PREFIX = INDEX_NAME_PREFIX
DENSE_VECTOR_NAME = "dense"
RESERVED_METADATA_KEYS = {"document_id", "text", "source", "chunk_id"}
SPARSE_VECTOR_NAME = "sparse"
Expand Down
8 changes: 6 additions & 2 deletions src/canopy/knowledge_base/qdrant/converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from copy import deepcopy
from typing import Dict, List, Any, Union
import uuid
from qdrant_client import models
from canopy.knowledge_base.models import (
KBDocChunkWithScore,
KBEncodedDocChunk,
Expand All @@ -10,6 +9,11 @@
)
from pinecone_text.sparse import SparseVector

try:
from qdrant_client import models
except ImportError:
pass

from canopy.knowledge_base.qdrant.constants import (
DENSE_VECTOR_NAME,
SPARSE_VECTOR_NAME,
Expand Down Expand Up @@ -62,7 +66,7 @@ def encoded_docs_to_points(

@staticmethod
def scored_point_to_scored_doc(
scored_point: models.ScoredPoint,
scored_point,
) -> "KBDocChunkWithScore":
metadata: Dict[str, Any] = deepcopy(scored_point.payload or {})
_id = metadata.pop("chunk_id")
Expand Down
93 changes: 32 additions & 61 deletions src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
SPARSE_VECTOR_NAME,
)
from canopy.knowledge_base.qdrant.converter import QdrantConverter
from canopy.knowledge_base.qdrant.utils import batched, generate_clients, sync_fallback
from canopy.knowledge_base.qdrant.utils import (
batched,
generate_clients,
sync_fallback,
)
from canopy.knowledge_base.record_encoder import RecordEncoder, OpenAIRecordEncoder
from canopy.knowledge_base.models import (
KBEncodedDocChunk,
Expand All @@ -23,11 +27,17 @@
from canopy.knowledge_base.reranker import Reranker, TransparentReranker
from canopy.models.data_models import Query, Document

from qdrant_client import models as models
from qdrant_client.http.exceptions import UnexpectedResponse
from grpc import RpcError # type: ignore[import-untyped]
from tqdm import tqdm

try:
from qdrant_client import models
from qdrant_client.http.exceptions import UnexpectedResponse
from grpc import RpcError

_qdrant_installed = True
except ImportError:
_qdrant_installed = False


class QdrantKnowledgeBase(BaseKnowledgeBase):
"""
Expand Down Expand Up @@ -141,6 +151,14 @@ def __init__(
TypeError: If chunker is not an instance of Chunker.
TypeError: If reranker is not an instance of Reranker.
""" # noqa: E501

if not _qdrant_installed:
raise ImportError(
"Failed to import 'qdrant-client'. "
"Try installing the 'qdrant' extra by running: "
"pip install canopy-sdk[qdrant]"
)

if default_top_k < 1:
raise ValueError("default_top_k must be greater than 0")

Expand Down Expand Up @@ -487,18 +505,9 @@ def create_canopy_collection(
self,
dimension: Optional[int] = None,
indexed_keyword_fields: List[str] = ["document_id"],
distance: models.Distance = models.Distance.COSINE,
shard_number: Optional[int] = None,
sharding_method: Optional[models.ShardingMethod] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None,
hnsw_config: Optional[models.HnswConfigDiff] = None,
optimizers_config: Optional[models.OptimizersConfigDiff] = None,
wal_config: Optional[models.WalConfigDiff] = None,
quantization_config: Optional[models.QuantizationConfig] = None,
timeout: Optional[int] = None,
on_disk: Optional[bool] = None,
distance: str = "COSINE",
vectors_on_disk: Optional[bool] = None,
**kwargs,
):
"""
Creates a collection with the appropriate config that will be used by the QdrantKnowledgeBase.
Expand All @@ -517,39 +526,10 @@ def create_canopy_collection(
indexed_keyword_fields: List of metadata fields to create Qdrant keyword payload index for.
Defaults to ["document_id"].
distance: Distance function to use for the vectors.
Defaults to COSINE.
shard_number: Number of shards in collection. Default is 1, minimum is 1.
sharding_method:
Defines strategy for shard creation.
Option `auto` (default) creates defined number of shards automatically.
Data will be distributed between shards automatically.
After creation, shards could be additionally replicated, but new shards could not be created.
Option `custom` allows to create shards manually, each shard should be created with assigned
unique `shard_key`. Data will be distributed between based on `shard_key` value.
replication_factor:
Replication factor for collection. Default is 1, minimum is 1.
Defines how many copies of each shard will be created.
Have effect only in distributed mode.
write_consistency_factor:
Write consistency factor for collection. Default is 1, minimum is 1.
Defines how many replicas should apply the operation for us to consider it successful.
Increasing this number will make the collection more resilient to inconsistencies, but will
also make it fail if not enough replicas are available.
Does not have any performance impact.
Has effect only in distributed mode.
on_disk_payload:
If true - point`s payload will not be stored in memory.
It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time.
Note: those payload values that are involved in filtering and are indexed - remain in RAM.
hnsw_config: Params for HNSW index
optimizers_config: Params for optimizer
wal_config: Params for Write-Ahead-Log
quantization_config: Params for quantization, if None - quantization will be disabled
timeout:
Wait for operation commit timeout in seconds.
If timeout is reached - request will return with service error.
on_disk: Whethers to store vectors on disk. Defaults to None.
Defaults to "COSINE" (must be the name of a `models.Distance` member).
vectors_on_disk: Whether to store vectors on disk. Defaults to None.
**kwargs: Additional arguments to pass to the `QdrantClient#create_collection()` method.
Reference: https://qdrant.tech/documentation/concepts/collections/#create-a-collection
""" # noqa: E501
if dimension is None:
Expand Down Expand Up @@ -582,26 +562,17 @@ def create_canopy_collection(
collection_name=self.collection_name,
vectors_config={
DENSE_VECTOR_NAME: models.VectorParams(
size=dimension, distance=distance, on_disk=on_disk
size=dimension, distance=getattr(models.Distance, distance), on_disk=vectors_on_disk # noqa: E501
)
},
sparse_vectors_config={
SPARSE_VECTOR_NAME: models.SparseVectorParams(
index=models.SparseIndexParams(
on_disk=on_disk,
on_disk=vectors_on_disk,
)
)
},
shard_number=shard_number,
sharding_method=sharding_method,
replication_factor=replication_factor,
write_consistency_factor=write_consistency_factor,
on_disk_payload=on_disk_payload,
hnsw_config=hnsw_config,
optimizers_config=optimizers_config,
wal_config=wal_config,
quantization_config=quantization_config,
timeout=timeout,
**kwargs,
)

for field in indexed_keyword_fields:
Expand Down
11 changes: 7 additions & 4 deletions src/canopy/knowledge_base/qdrant/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
import functools
from itertools import islice
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional

from qdrant_client import AsyncQdrantClient, QdrantClient
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
import logging

try:
from qdrant_client import AsyncQdrantClient, QdrantClient
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
except ImportError:
pass
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -46,7 +49,7 @@ def generate_clients(
path: Optional[str] = None,
force_disable_check_same_thread: bool = False,
**kwargs: Any,
) -> Tuple[QdrantClient, Union[AsyncQdrantClient, None]]:
):
sync_client = QdrantClient(
location=location,
url=url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
)
from canopy.knowledge_base.models import DocumentWithScore
from canopy.models.data_models import Query
from qdrant_client.async_qdrant_remote import AsyncQdrantRemote
from tests.unit import random_words
from tests.unit.stubs.stub_chunker import StubChunker

qdrant_client = pytest.importorskip("qdrant_client")


async def execute_and_assert_queries(
knowledge_base: QdrantKnowledgeBase, chunks_to_query
Expand Down Expand Up @@ -97,7 +98,7 @@ async def test_query(knowledge_base, encoded_chunks):
@pytest.mark.asyncio
async def test_query_with_metadata_filter(knowledge_base):
if knowledge_base._async_client is None or not isinstance(
knowledge_base._async_client._client, AsyncQdrantRemote
knowledge_base._async_client._client, qdrant_client.async_qdrant_remote.AsyncQdrantRemote # noqa: E501
):
pytest.skip(
"Dict filter is not supported for QdrantLocal"
Expand Down
10 changes: 5 additions & 5 deletions tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path

import pytest
from qdrant_client import QdrantClient
from canopy.knowledge_base.chunker.base import Chunker

from canopy.knowledge_base.knowledge_base import KnowledgeBase
Expand All @@ -18,7 +17,6 @@
from canopy.knowledge_base.reranker.reranker import Reranker
from canopy.models.data_models import Query

from qdrant_client.qdrant_remote import QdrantRemote
from canopy_cli.cli import _load_kb_config
from tests.system.knowledge_base.qdrant.common import (
assert_chunks_in_collection,
Expand All @@ -31,6 +29,8 @@
from tests.unit.stubs.stub_dense_encoder import StubDenseEncoder
from tests.unit.stubs.stub_record_encoder import StubRecordEncoder

qdrant_client = pytest.importorskip("qdrant_client")


def execute_and_assert_queries(knowledge_base: QdrantKnowledgeBase, chunks_to_query):
queries = [Query(text=chunk.text, top_k=2) for chunk in chunks_to_query]
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_query(knowledge_base, encoded_chunks):


def test_query_with_metadata_filter(knowledge_base):
if not isinstance(knowledge_base._client._client, QdrantRemote):
if not isinstance(knowledge_base._client._client, qdrant_client.qdrant_remote.QdrantRemote): # noqa: E501
pytest.skip(
"Dict filter is not supported for QdrantLocal"
"Use qdrant_client.models.Filter instead"
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_kb_non_existing_collection(knowledge_base):

def test_init_defaults(collection_name, collection_full_name):
new_kb = QdrantKnowledgeBase(collection_name)
assert isinstance(new_kb._client, QdrantClient)
assert isinstance(new_kb._client, qdrant_client.QdrantClient)
assert new_kb.collection_name == collection_full_name
assert isinstance(new_kb._chunker, Chunker)
assert isinstance(
Expand All @@ -272,7 +272,7 @@ def test_init_defaults(collection_name, collection_full_name):
def test_init_defaults_with_override(knowledge_base, chunker):
collection_name = knowledge_base.collection_name
new_kb = QdrantKnowledgeBase(collection_name=collection_name, chunker=chunker)
assert isinstance(new_kb._client, QdrantClient)
assert isinstance(new_kb._client, qdrant_client.QdrantClient)
assert new_kb.collection_name == collection_name
assert isinstance(new_kb._chunker, Chunker)
assert isinstance(new_kb._chunker, StubChunker)
Expand Down

0 comments on commit d9f80e7

Please sign in to comment.